Source code for roicat.tracking.clustering

import warnings
from typing import Union, Tuple, List, Dict, Optional, Any

import numpy as np
import scipy
import scipy.optimize
import scipy.sparse
import sklearn
import sklearn.isotonic
import matplotlib.pyplot as plt
import torch
from tqdm.auto import tqdm

from .. import helpers, util
from .similarity_graph import SimilarityMetric



[docs] class Clusterer(util.ROICaT_Module): """ Class for clustering algorithms. Performs: * Optimal mixing and pruning of similarity matrices: * self.find_optimal_parameters_for_pruning() * self.make_pruned_similarity_graphs() * Clustering: * self.fit(): Which uses a modified HDBSCAN * self.fit_sequentialHungarian: Which uses a method similar to CaImAn's clustering method. * Quality control: * self.compute_cluster_quality_metrics() Initialization ingests and stores similarity matrices. RH 2023 / 2025 Args: similarities (Dict[str, scipy.sparse.csr_array]): Dict mapping metric name to sparse similarity matrix. All matrices must share the same nonzero pattern (same nnz). Example: ``{'sf': csr_array, 'nn': csr_array, 'swt': csr_array}``. metric_configs (List[SimilarityMetric]): List of ``SimilarityMetric`` dataclass instances describing each metric's optimization behavior (sparsity source, sigmoid, power). s_sesh (scipy.sparse.csr_array): Inter-session mask. Shape: *(n_rois, n_rois)*. Boolean, with 1s where the two ROIs are from different sessions. n_bins (Optional[int]): Number of bins to use for the pairwise similarity distribution. If ``None``, then a heuristic is used to estimate the value based on the number of nonzero pairs. (Default is ``None``) smoothing_window_bins (Optional[int]): Number of bins to use when smoothing the distribution. If ``None``, then a heuristic is used. (Default is ``None``) session_bool (Optional[np.ndarray]): Boolean array indicating which ROIs belong to which session. Shape: *(n_rois, n_sessions)*. (Default is ``None``) verbose (bool): Specifies whether to print out information about the clustering process. (Default is ``True``) Attributes: similarities (Dict[str, scipy.sparse.csr_array]): Dict of similarity matrices keyed by metric name. s_sesh (scipy.sparse.csr_array): The inter-session similarity matrix. Shape: *(n_rois, n_rois)*. s_sesh_inv (scipy.sparse.csr_array): Intra-session mask (True where ROIs are from the SAME session). Shape: *(n_rois, n_rois)*. n_bins (Optional[int]): Number of bins to use for the pairwise similarity distribution. smooth_window (Optional[int]): Number of bins to use when smoothing the distribution. verbose (bool): Specifies how much information to print out: \n * 0/False: Warnings only * 1/True: Basic info, progress bar * 2: All info """ def __init__( self, similarities: Dict[str, scipy.sparse.csr_array], metric_configs: List[SimilarityMetric], s_sesh: scipy.sparse.csr_array, n_bins: Optional[int] = None, smoothing_window_bins: Optional[int] = None, session_bool: Optional[np.ndarray] = None, verbose: bool = True, ): """ Initializes the Clusterer with similarity matrices and metric configs. Args: similarities (Dict[str, scipy.sparse.csr_array]): Dict mapping metric name to similarity matrix. All matrices must share the same nonzero pattern (same nnz). Example: ``{'sf': csr_array, 'nn': csr_array, 'swt': csr_array}``. metric_configs (List[SimilarityMetric]): List of metric configurations describing each metric's role in optimization (sparsity source, sigmoid, power, etc.). s_sesh (scipy.sparse.csr_array): Inter-session mask. Shape: *(n_rois, n_rois)*. Boolean, with 1s where two ROIs are from different sessions. n_bins (Optional[int]): Number of bins for pairwise similarity distributions. If ``None``, heuristic based on nnz. (Default is ``None``) smoothing_window_bins (Optional[int]): Smoothing window for distributions. If ``None``, heuristic. (Default is ``None``) session_bool (Optional[np.ndarray]): Boolean array indicating which ROIs belong to which session. Shape: *(n_rois, n_sessions)*. (Default is ``None``) verbose (bool): Verbosity level. (Default is ``True``) """ super().__init__() ## Store parameter (but not data) args as attributes self.params['__init__'] = self._locals_to_params( locals_dict=locals(), keys=['n_bins', 'smoothing_window_bins', 'verbose'], ) ## Store similarities dict and metric configs self.similarities = similarities ## Store metric configs directly as SimilarityMetric objects keyed by name. ## RichFile_ROICaT has a registered type handler for SimilarityMetric ## that serializes them as JSON via to_dict()/from_dict(). self._metric_configs_stored = {m.name: m for m in metric_configs} ## Identify sparsity source sparsity_sources = [name for name, cfg in self._metric_configs.items() if cfg.is_sparsity_source] assert len(sparsity_sources) > 0, "At least one metric must have is_sparsity_source=True" self._sparsity_name = sparsity_sources[0] self._s_sparsity = self.similarities[self._sparsity_name] ## Validate all similarities have same nnz nnz_values = {name: s.nnz for name, s in self.similarities.items()} first_nnz = next(iter(nnz_values.values())) assert all(v == first_nnz for v in nnz_values.values()), ( f"All similarity matrices must have same nnz. Got: {nnz_values}" ) ## Build session masks. s_sesh marks inter-session pairs (different ## sessions). s_sesh_inv marks intra-session pairs (same session) ## within the sparsity pattern — used by _precompute_intra_mask. self.s_sesh = s_sesh ## shape: (n_rois, n_rois), sparse bool self.s_sesh_inv = (self._s_sparsity != 0).astype(bool) ## all nonzero pairs self.s_sesh_inv[self.s_sesh.astype(bool)] = False ## remove inter-session → leaves intra only self.s_sesh_inv.eliminate_zeros() ## Zero out diagonal of s_sesh (self-pairs are neither inter nor intra) self.s_sesh = self.s_sesh.tolil() self.s_sesh[range(self.s_sesh.shape[0]), range(self.s_sesh.shape[1])] = 0 self.s_sesh = self.s_sesh.tocsr() self._verbose = verbose self.n_bins = max(min(first_nnz // 10000, 200), 20) if n_bins is None else n_bins self.smooth_window = helpers.make_odd(self.n_bins // 10, mode='up') if smoothing_window_bins is None else smoothing_window_bins self._session_bool = session_bool def __repr__(self): has_sim = hasattr(self, 'similarities') and self.similarities is not None if has_sim: first_mat = next(iter(self.similarities.values())) n_roi = first_mat.shape[0] nnz = first_mat.nnz else: n_roi, nnz = 0, 0 has_labels = hasattr(self, 'labels') and self.labels is not None n_clusters = len(set(self.labels) - {-1}) if has_labels else 0 return ( f"Clusterer(n_roi={n_roi}, nnz={nnz}, " f"metrics={list(self.similarities.keys()) if has_sim else []}, " f"n_clusters={n_clusters if has_labels else 'not fitted'})" ) @property def _metric_configs(self) -> Dict[str, SimilarityMetric]: """Return stored SimilarityMetric instances keyed by name.""" stored = self._metric_configs_stored ## After legacy pickle load with dicts, reconstruct if len(stored) > 0 and isinstance(next(iter(stored.values())), dict): return {name: SimilarityMetric.from_dict(d) for name, d in stored.items()} return stored
[docs] def find_optimal_parameters_for_pruning( self, bounds_findParameters: Optional[Dict[str, List[float]]] = None, de_kwargs: Dict[str, Any] = { 'maxiter': 100, 'tol': 1e-6, 'popsize': 15, 'mutation': (0.5, 1.5), 'recombination': 0.7, 'polish': True, }, n_bins: Optional[int] = None, smoothing_window_bins: Optional[int] = None, subsample_pairs: Optional[int] = None, seed: Optional[int] = None, ) -> Dict: """ Find optimal mixing parameters for pruning the similarity graph. Two-stage approach: 1. **Naive Bayes calibration**: For each similarity feature, estimates ``P(same | s_k)`` from histogram subtraction. The resulting per-feature calibration curves are used to analytically estimate optimal sigmoid parameters ``(mu, b)`` via Fisher's linear discriminant. 2. **Differential evolution**: With sigmoid parameters frozen from stage 1, optimizes the remaining parameters (one ``power_<name>`` per metric with ``optimize_power=True``, plus ``p_norm``) by minimizing the histogram overlap loss. This method replaces the original Optuna TPE search (see :meth:`_find_optimal_parameters_for_pruning_optuna` in the legacy section). The two-stage approach achieves better separation quality (lower histogram overlap) on typical datasets. RH 2023 / 2025 Args: bounds_findParameters (Dict[str, List[float]]): Bounds for the optimized parameters. Keys are ``power_<name>`` for each metric with ``optimize_power=True``, plus ``p_norm``. Auto-constructed from metric configs if ``None``. de_kwargs (Dict[str, Any]): Keyword arguments for ``scipy.optimize.differential_evolution``: \n * ``maxiter`` (int): Maximum number of DE generations. * ``tol`` (float): Convergence tolerance on the loss. * ``popsize`` (int): Population size multiplier (actual population = ``popsize * n_params``). * ``mutation`` (Tuple[float, float]): Differential weight range ``(min, max)`` for dithering. * ``recombination`` (float): Crossover probability in ``[0, 1]``. * ``polish`` (bool): If ``True``, run L-BFGS-B from the best DE solution. Often has no effect on piecewise-constant histogram loss. n_bins (Optional[int]): Overwrites ``n_bins`` from ``__init__``. smoothing_window_bins (Optional[int]): Overwrites ``smoothing_window_bins`` from ``__init__``. subsample_pairs (Optional[int]): If not ``None``, subsample this many pairs for histogram loss evaluation. Maintains intra/inter ratio. If ``None``, auto-computed based on pair counts. seed (Optional[int]): Random seed for reproducibility. Returns: (Dict): kwargs_makeConjunctiveDistanceMatrix_best (Dict): Optimal parameters for :meth:`make_conjunctive_distance_matrix`. """ ## Store parameter (but not data) args as attributes self.params['find_optimal_parameters_for_pruning'] = self._locals_to_params( locals_dict=locals(), keys=[ 'bounds_findParameters', 'de_kwargs', 'n_bins', 'smoothing_window_bins', 'subsample_pairs', 'seed', ], ) ## Auto-construct bounds from metric configs if not provided if bounds_findParameters is None: bounds_findParameters = {} for name, cfg in self._metric_configs.items(): if cfg.optimize_power: bounds_findParameters[f'power_{name}'] = list(cfg.power_bounds) bounds_findParameters['p_norm'] = [-5, -0.1] ## NB calibration → Fisher sigmoid estimation → N-param DE. return self._find_optimal_parameters_DE( bounds_findParameters=bounds_findParameters, de_kwargs=de_kwargs, n_bins=n_bins, smoothing_window_bins=smoothing_window_bins, subsample_pairs=subsample_pairs, seed=seed, freeze_sigmoid=True, )
#################################################################### ## Differential evolution parameter optimization #################################################################### def _precompute_intra_mask(self) -> np.ndarray: """ Build a boolean mask of shape ``(nnz,)`` indicating which nonzero entries in the sparsity source matrix correspond to intra-session (known-different) ROI pairs. Uses an index-mapping trick: assigns each nonzero entry a unique 1-based index, multiplies by the inverse session matrix to isolate intra-session entries, then reads back which indices survived. The result is stored as a numpy bool array (so that serialization via ``serializable_dict`` preserves it). Call sites that need a torch tensor should wrap with ``torch.as_tensor(self._intra_mask)``. RH 2025 Returns: (np.ndarray): intra_mask (np.ndarray): Boolean numpy array, shape ``(self._s_sparsity.nnz,)``. """ ## Map each nonzero entry to a unique 1-based index idx_mat = self._s_sparsity.copy().astype(np.float64) idx_mat.data = np.arange(1, self._s_sparsity.nnz + 1, dtype=np.float64) ## Elementwise multiply with s_sesh_inv to keep only intra-session entries masked = idx_mat.multiply(self.s_sesh_inv.astype(np.float64)) masked.eliminate_zeros() ## Convert back to 0-based indices and build boolean mask intra_indices = (masked.data - 1).astype(np.int64) mask = np.zeros(self._s_sparsity.nnz, dtype=bool) mask[intra_indices] = True ## Store as a numpy bool array; callers convert with torch.as_tensor() self._intra_mask = mask return self._intra_mask def _subsample_pairs( self, n_subsample: int, intra_mask: torch.Tensor, seed: Optional[int] = None, ) -> torch.Tensor: """ Subsample pair indices while preserving the intra/inter-session ratio. Returns indices into the full nnz-length arrays. Intra-session indices come first in the returned tensor, so the intra mask for the subsample is simply ``True`` for the first ``n_sample_intra`` entries. RH 2025 Args: n_subsample (int): Target number of pairs to sample. intra_mask (torch.Tensor): Boolean tensor, shape ``(nnz,)``. ``True`` for intra-session (known-different) pairs. seed (Optional[int]): Random seed for reproducibility. Returns: (torch.Tensor): sample_idx (torch.Tensor): 1D int64 tensor of sampled pair indices, length ``<= n_subsample``. Intra-session pairs come first. """ ## Accept numpy or torch; convert to torch for all internal operations. intra_mask = torch.as_tensor(intra_mask) nnz = intra_mask.shape[0] n_intra = int(intra_mask.sum().item()) n_inter = nnz - n_intra frac = n_subsample / nnz n_sample_intra = max(int(n_intra * frac), 100) n_sample_inter = max(int(n_inter * frac), 100) intra_idx = torch.where(intra_mask)[0] inter_idx = torch.where(~intra_mask)[0] rng = torch.Generator() rng.manual_seed(seed if seed is not None else 42) perm_intra = torch.randperm(n_intra, generator=rng)[:n_sample_intra] perm_inter = torch.randperm(n_inter, generator=rng)[:n_sample_inter] ## Intra first, then inter — intra_mask for subsample is True for ## the first n_sample_intra entries return torch.cat([intra_idx[perm_intra], inter_idx[perm_inter]]) def _find_optimal_parameters_DE( self, bounds_findParameters: Optional[Dict[str, List[float]]] = None, de_kwargs: Dict[str, Any] = { 'maxiter': 100, 'tol': 1e-6, 'popsize': 15, 'mutation': (0.5, 1.5), 'recombination': 0.7, 'polish': True, }, n_bins: Optional[int] = None, smoothing_window_bins: Optional[int] = None, subsample_pairs: Optional[int] = None, seed: Optional[int] = None, freeze_sigmoid: bool = True, ) -> Dict: """ Find optimal mixing parameters using scipy differential evolution. When ``freeze_sigmoid=True`` (default), sigmoid parameters ``(mu, b)`` are estimated from NB calibration curves via Fisher's linear discriminant and held fixed. The search is then over ``power_<name>`` for each metric with ``optimize_power=True``, plus ``p_norm``. When ``False``, sigmoid params are also optimized. The inner loop operates entirely on precomputed torch tensors — no scipy sparse operations per evaluation. When subsampling is active, the subsample is redrawn each DE generation to reduce overfitting to a specific pair subset. RH 2025 Args: bounds_findParameters (Dict[str, List[float]]): Bounds for each parameter, keyed by ``power_<name>`` and ``p_norm``. Auto-constructed from metric configs if ``None``. When ``False``, all 7 keys are needed. de_kwargs (Dict[str, Any]): Keyword arguments for ``scipy.optimize.differential_evolution``: \n * ``maxiter`` (int): Maximum number of DE generations. * ``tol`` (float): Convergence tolerance on the loss. * ``popsize`` (int): Population size multiplier (actual population = ``popsize * n_params``). * ``mutation`` (Tuple[float, float]): Differential weight range ``(min, max)`` for dithering. * ``recombination`` (float): Crossover probability in ``[0, 1]``. * ``polish`` (bool): If ``True``, run L-BFGS-B from the best DE solution. Often has no effect on piecewise-constant histogram loss. n_bins (Optional[int]): Overwrites ``n_bins`` from __init__. smoothing_window_bins (Optional[int]): Overwrites ``smoothing_window_bins`` from __init__. subsample_pairs (Optional[int]): If not ``None``, subsample this many pairs for histogram loss. Maintains intra/inter ratio. If ``None``, auto-computed: subsamples to 1.1M (100k intra + 1M inter) when there are enough pairs, otherwise uses all pairs. seed (Optional[int]): Random seed for reproducibility. freeze_sigmoid (bool): If ``True``, fix sigmoid params from NB calibration, reducing DE to 3 parameters. If ``False``, optimize all 7 parameters jointly. Returns: (Dict): kwargs_makeConjunctiveDistanceMatrix_best (Dict): Optimal parameters for :meth:`make_conjunctive_distance_matrix`. """ ## Store parameter (but not data) args as attributes self.params['_find_optimal_parameters_DE'] = self._locals_to_params( locals_dict=locals(), keys=[ 'bounds_findParameters', 'de_kwargs', 'n_bins', 'smoothing_window_bins', 'subsample_pairs', 'seed', 'freeze_sigmoid', ], ) self.n_bins = self.n_bins if n_bins is None else n_bins self.smooth_window = self.smooth_window if smoothing_window_bins is None else smoothing_window_bins ## Auto-construct bounds from metric configs if not provided if bounds_findParameters is None: bounds_findParameters = {} for name, cfg in self._metric_configs.items(): if cfg.optimize_power: bounds_findParameters[f'power_{name}'] = list(cfg.power_bounds) bounds_findParameters['p_norm'] = [-5, -0.1] ## Add sigmoid bounds for unfrozen case for name, cfg in self._metric_configs.items(): if cfg.optimize_sigmoid: bounds_findParameters[f'sig_{name}_kwargs_mu'] = [0., 1.0] bounds_findParameters[f'sig_{name}_kwargs_b'] = [0.1, 1.5] self.bounds_findParameters = bounds_findParameters self._seed = seed ################################################################ ## Auto-compute subsample size if not specified ################################################################ if not hasattr(self, '_intra_mask') or self._intra_mask is None: self._precompute_intra_mask() if subsample_pairs is None: n_intra = int(self._intra_mask.sum()) n_inter = self._s_sparsity.nnz - n_intra ## Subsample only if we have enough pairs to meet minimums min_intra = 100_000 min_inter = 1_000_000 if n_intra >= min_intra and n_inter >= min_inter: subsample_pairs = min_intra + min_inter ## Otherwise use all pairs ################################################################ ## Determine parameter layout: 3-param (frozen) or 7-param (full) ################################################################ _frozen_sig = None if freeze_sigmoid: if not hasattr(self, 'calibrations_naive_bayes') or self.calibrations_naive_bayes is None: self.make_naive_bayes_distance_matrix() sig_params = self._estimate_sigmoid_params() _frozen_sig = sig_params ## Dict[metric_name, {'mu': float, 'b': float}] if self._verbose: parts = [f'{n}(mu={p["mu"]:.3f}, b={p["b"]:.1f})' for n, p in _frozen_sig.items()] print(f' Freezing sigmoid: {", ".join(parts)}') ## Cache metric configs once. The property reconstructs SimilarityMetric ## objects from dicts on every access; caching avoids that overhead ## inside the objective function which runs thousands of times. _cached_configs = self._metric_configs ## Dict[str, SimilarityMetric] ## Build parameter layout: [power_<m1>, power_<m2>, ..., p_norm] ## This list is immutable and referenced by both bounds and objective self._de_param_layout = [] for name, cfg in _cached_configs.items(): if cfg.optimize_power: self._de_param_layout.append(('power', name)) self._de_param_layout.append(('p_norm', None)) ## If not freezing sigmoid, add sigmoid params too if not freeze_sigmoid: for name, cfg in _cached_configs.items(): if cfg.optimize_sigmoid: self._de_param_layout.append(('sig_mu', name)) self._de_param_layout.append(('sig_b', name)) param_keys = [] for ptype, pname in self._de_param_layout: if ptype == 'power': param_keys.append(f'power_{pname}') elif ptype == 'p_norm': param_keys.append('p_norm') elif ptype == 'sig_mu': param_keys.append(f'sig_{pname}_kwargs_mu') elif ptype == 'sig_b': param_keys.append(f'sig_{pname}_kwargs_b') scipy_bounds = [ tuple(bounds_findParameters[k]) for k in param_keys if k in bounds_findParameters ] ## list of (lo, hi) tuples, one per DE dimension ################################################################ ## Precompute tensors — eliminates all sparse operations from ## the inner loop. ################################################################ ## Extract .data arrays from sparse matrices into contiguous tensors. ## Each tensor has shape (nnz,) — one value per nonzero pair. tensors_full = { name: torch.as_tensor( np.ascontiguousarray(sim.data), dtype=torch.float32, ).clone() for name, sim in self.similarities.items() } ## Dict[str, Tensor(nnz,)] ## Boolean mask for intra-session (known-different) pairs if not hasattr(self, '_intra_mask') or self._intra_mask is None: self._precompute_intra_mask() intra_mask_full = torch.as_tensor(self._intra_mask) ## shape (nnz,), bool ################################################################ ## Helper: build working tensors (with optional subsampling) ################################################################ def _build_working_tensors(resample_seed: Optional[int]): """Return (tensors_dict, intra_mask) after optional subsampling.""" nnz_full = next(iter(tensors_full.values())).shape[0] if subsample_pairs is not None and subsample_pairs < nnz_full: sidx = self._subsample_pairs( n_subsample=subsample_pairs, intra_mask=intra_mask_full, seed=resample_seed if resample_seed is not None else 77777, ) ## Intra pairs come first in sidx. Compute actual intra count n_intra_full = int(intra_mask_full.sum().item()) frac = subsample_pairs / nnz_full n_si = min(max(int(n_intra_full * frac), 100), n_intra_full) im = torch.zeros(sidx.shape[0], dtype=torch.bool) im[:n_si] = True return {name: t[sidx] for name, t in tensors_full.items()}, im else: return dict(tensors_full), intra_mask_full ## Initial working set working_tensors, intra_mask = _build_working_tensors( resample_seed=(seed + 77777) if seed is not None else 77777, ) first_t = next(iter(working_tensors.values())) print( f' Working set: {first_t.shape[0]} pairs ' f'({int(intra_mask.sum().item())} intra, ' f'{first_t.shape[0] - int(intra_mask.sum().item())} inter)' ) if self._verbose and subsample_pairs is not None else None ################################################################ ## Build shared histogram infrastructure from current tensors ################################################################ n_bins_val = self.n_bins edges = torch.linspace(0, 1, n_bins_val + 1, dtype=torch.float32) smooth_window = helpers.make_odd(n_bins_val // 10, mode='up') smoother = helpers.Convolver_1d( kernel=torch.ones(smooth_window), length_x=n_bins_val, pad_mode='same', correct_edge_effects=True, device='cpu', ) ## Mutable containers so resample callback can update them in-place _state = { 'tensors': working_tensors, 'intra_mask': intra_mask, 'generation': 0, } _state['intra_indices'] = torch.where(_state['intra_mask'])[0] _state['n_all'] = first_t.shape[0] _state['n_intra'] = int(_state['intra_mask'].sum().item()) _state['scale_factor'] = _state['n_all'] / max(_state['n_intra'], 1) ################################################################ ## Resample callback — redraws subsampled tensors each generation ################################################################ _generation_counter = [0] def _resample_callback(xk, convergence=None): """Redraw subsample at the start of each generation.""" gen = _generation_counter[0] _generation_counter[0] += 1 new_seed = (seed + gen * 1000 + 99999) if seed is not None else (gen * 1000 + 99999) new_tensors, im_new = _build_working_tensors(resample_seed=new_seed) _state['tensors'] = new_tensors _state['intra_mask'] = im_new _state['intra_indices'] = torch.where(im_new)[0] first_new = next(iter(new_tensors.values())) _state['n_all'] = first_new.shape[0] _state['n_intra'] = int(im_new.sum().item()) _state['scale_factor'] = _state['n_all'] / max(_state['n_intra'], 1) ################################################################ ## Scalar objective — evaluates one parameter vector at a time ################################################################ def objective_scalar(x): ii = _state['intra_indices'] sc = _state['scale_factor'] ## Unpack parameter vector using the frozen layout param_idx = 0 sig_params_live = {} ## for unfrozen sigmoid params ## Parse sigmoid params from x. Must advance param_idx for ## every entry in the layout (power and p_norm included) so ## that sigmoid entries read from the correct positions. for ptype, pname in self._de_param_layout: if ptype in ('power', 'p_norm'): param_idx += 1 ## skip; handled in the loop below elif ptype == 'sig_mu': if pname not in sig_params_live: sig_params_live[pname] = {} sig_params_live[pname]['mu'] = float(x[param_idx]) param_idx += 1 elif ptype == 'sig_b': sig_params_live[pname]['b'] = float(x[param_idx]) param_idx += 1 ## Build activated list for all metrics activated = [] param_idx = 0 for name, cfg in _cached_configs.items(): s_w = _state['tensors'][name] ## Apply sigmoid if configured if cfg.optimize_sigmoid: if _frozen_sig is not None and name in _frozen_sig: mu = _frozen_sig[name]['mu'] b = _frozen_sig[name]['b'] elif name in sig_params_live: mu = sig_params_live[name]['mu'] b = sig_params_live[name]['b'] else: mu, b = 0.0, 1.0 s_w = torch.sigmoid(b * (s_w - mu)) ## Apply power if optimized if cfg.optimize_power: power = float(x[param_idx]) param_idx += 1 s_w = torch.clamp(s_w, min=1e-8).pow(power) else: s_w = torch.clamp(s_w, min=1e-8) activated.append(s_w) ## p-norm is the parameter after all power params p_idx = sum(1 for pt, _ in self._de_param_layout if pt == 'power') p = float(x[p_idx]) p = p if abs(p) > 1e-9 else 1e-9 ## p-norm mixing: distance = 1 - (mean(s_k^p))^(1/p) N = len(activated) running_sum = torch.zeros_like(activated[0]) for a in activated: running_sum += a.pow(p) dist = 1.0 - (running_sum / N).pow(1.0 / p) ## Histogram overlap loss via shared helper loss, _, _ = self._compute_histogram_overlap( distances=dist, intra_indices=ii, edges=edges, smoother=smoother, scale_factor=sc, ) return loss ################################################################ ## Configure and run differential evolution ################################################################ print('Finding mixing parameters using differential evolution...') if self._verbose else None de_kwargs_use = dict(de_kwargs) nnz_full = next(iter(tensors_full.values())).shape[0] ## Always resample each generation when subsampling if subsample_pairs is not None and subsample_pairs < nnz_full: existing_cb = de_kwargs_use.pop('callback', None) def _combined_callback(xk, convergence=None): _resample_callback(xk, convergence) if existing_cb is not None: return existing_cb(xk, convergence) ## propagate stop signal de_kwargs_use['callback'] = _combined_callback ## Coerce seed to int for scipy DE; leave None for random behavior de_seed = int(seed) if seed is not None else None self._de_result = scipy.optimize.differential_evolution( func=objective_scalar, bounds=scipy_bounds, seed=de_seed, **de_kwargs_use, ) ## Extract best parameters from DE result x_best = self._de_result.x self.best_params = {} param_idx = 0 for name, cfg in _cached_configs.items(): if cfg.optimize_power: self.best_params[f'power_{name}'] = float(x_best[param_idx]) param_idx += 1 else: ## Non-optimized metrics get identity power (no transform) self.best_params[f'power_{name}'] = None ## p_norm is after all power params p_idx = sum(1 for pt, _ in self._de_param_layout if pt == 'power') self.best_params['p_norm'] = float(x_best[p_idx]) ## Sigmoid params (frozen or from DE) sig_param_idx = p_idx + 1 ## sigmoid params start after p_norm for name, cfg in _cached_configs.items(): if cfg.optimize_sigmoid: if _frozen_sig is not None and name in _frozen_sig: self.best_params[f'sig_{name}_kwargs'] = { 'mu': _frozen_sig[name]['mu'], 'b': _frozen_sig[name]['b'], } else: ## Extract unfrozen sigmoid params from DE result self.best_params[f'sig_{name}_kwargs'] = { 'mu': float(x_best[sig_param_idx]), 'b': float(x_best[sig_param_idx + 1]), } sig_param_idx += 2 else: ## Non-sigmoid metrics get None (no sigmoid applied) self.best_params[f'sig_{name}_kwargs'] = None self.kwargs_makeConjunctiveDistanceMatrix_best = dict(self.best_params) print( f'Completed DE parameter search. ' f'Best value: {self._de_result.fun:.2f}, ' f'evaluations: {self._de_result.nfev}, ' f'params: {self.best_params}' ) if self._verbose else None return self.kwargs_makeConjunctiveDistanceMatrix_best #################################################################### ## Naive Bayes calibration and sigmoid estimation #################################################################### def _calibrate_feature_1d( self, s_data: torch.Tensor, intra_mask: torch.Tensor, n_bins: int, smoother: 'helpers.Convolver_1d', prob_clip: Tuple[float, float] = (1e-4, 1 - 1e-4), ) -> Dict[str, torch.Tensor]: """ Estimate P(same | s_k) for a single similarity feature using histogram subtraction. Given raw similarity values for all pairs and the intra-session (known-different) subset, estimates the "same" distribution as the residual after subtracting the scaled intra-session distribution from the overall distribution. Monotonicity is enforced (higher similarity → higher P(same)). RH 2025 Args: s_data (torch.Tensor): 1D tensor of raw similarity values, shape ``(nnz,)``. intra_mask (torch.Tensor): Boolean tensor, shape ``(nnz,)``. ``True`` for intra-session (known-different) pairs. n_bins (int): Number of histogram bins. smoother (helpers.Convolver_1d): 1D convolver for smoothing histogram counts. prob_clip (Tuple[float, float]): Clamp P(same) to ``[lo, hi]`` to avoid logit divergence. Returns: (Dict[str, torch.Tensor]): calibration (Dict[str, torch.Tensor]): Dictionary with keys ``'edges'``, ``'counts_all'``, ``'counts_diff'``, ``'counts_diff_smooth'``, ``'counts_same'``, ``'p_same_bins'``. """ n_all = s_data.shape[0] n_intra = int(intra_mask.sum().item()) scale = n_all / n_intra ## Bin edges spanning the data range with small margin lo = float(s_data.min()) - 1e-6 hi = float(s_data.max()) + 1e-6 edges = torch.linspace(lo, hi, n_bins + 1, dtype=torch.float32) ## Histogram all values and intra-session values counts_all, _ = torch.histogram(s_data, edges) counts_intra, _ = torch.histogram(s_data[intra_mask], edges) ## Scale intra counts to estimate the full "different" distribution counts_diff = counts_intra * scale ## "Same" distribution = residual, clamped non-negative. ## Do NOT smooth counts_same — the smoothing kernel bleeds nonzero ## mass into the left tail where the true same-count is zero, creating ## an artificial P(same) floor. Raw residual preserves zeros. counts_same = torch.clamp(counts_all - counts_diff, min=0) ## Smooth only the "different" distribution (estimating the smooth ## population envelope). counts_same stays raw/clamped. counts_diff_smooth = smoother.convolve(counts_diff) ## P(same | bin) = counts_same / (counts_same + counts_diff_smooth) p_same_bins = counts_same / (counts_same + counts_diff_smooth + 1e-10) ## Enforce monotonic increasing via isotonic regression, weighted by ## total evidence per bin. Isotonic regression finds the best ## monotonically-increasing fit — sparse tail bins (with few ## observations) get negligible influence, avoiding artificial ## P(same) floors. ir = sklearn.isotonic.IsotonicRegression( increasing=True, y_min=float(prob_clip[0]), y_max=float(prob_clip[1]), ) evidence_weights = (counts_all + counts_diff_smooth).numpy() evidence_weights = np.maximum(evidence_weights, 1e-10) p_same_np = ir.fit_transform( X=np.arange(n_bins, dtype=np.float64), y=p_same_bins.numpy().astype(np.float64), sample_weight=evidence_weights.astype(np.float64), ) p_same_bins = torch.as_tensor(p_same_np, dtype=torch.float32) ## Store all tensors as numpy so that serializable_dict can ## preserve them (torch tensors are not in the allowed library list). ## Call sites that need torch tensors convert with torch.as_tensor(). return { 'edges': edges.numpy(), 'counts_all': counts_all.numpy(), 'counts_diff': counts_diff.numpy(), 'counts_diff_smooth': counts_diff_smooth.numpy(), 'counts_same': counts_same.numpy(), 'p_same_bins': p_same_bins.numpy(), }
[docs] def make_naive_bayes_distance_matrix( self, n_bins: Optional[int] = None, smoothing_window_bins: Optional[int] = None, prob_clip: Tuple[float, float] = (1e-4, 1 - 1e-4), ) -> Tuple[scipy.sparse.csr_array, scipy.sparse.csr_array, Dict[str, Any]]: """ Compute pairwise distance matrix using independent per-feature calibration combined via naive Bayes. For each similarity feature k (SF, NN, SWT), estimates the posterior ``P(same | s_k)`` from a 1D histogram of similarity values, using the intra-session (known-different) distribution as reference. The per-feature posteriors are combined under conditional independence: .. math:: \\text{logit}(P(\\text{same} | \\mathbf{s})) = \\sum_k \\text{logit}(P(\\text{same} | s_k)) - (K-1) \\cdot \\text{logit}(\\pi) where :math:`\\pi` is the estimated prior P(same) and K is the number of features. **No iterative optimization** — just histogram + lookup. Typically completes in under 1 second even on large datasets. RH 2025 Args: n_bins (Optional[int]): Number of histogram bins per feature. If ``None``, uses ``self.n_bins``. smoothing_window_bins (Optional[int]): Smoothing window. If ``None``, uses ``self.smooth_window``. prob_clip (Tuple[float, float]): Clamp P(same|s_k) to ``[lo, hi]`` before logit. Returns: (Tuple[scipy.sparse.csr_array, scipy.sparse.csr_array, Dict]): dConj (scipy.sparse.csr_array): Distance matrix ``d = 1 - P(same|all)``. sConj (scipy.sparse.csr_array): Similarity matrix ``s = P(same|all)``. calibrations (Dict[str, Any]): Diagnostic dict with per-feature calibrations, prior, and combined P(same). """ self.params['make_naive_bayes_distance_matrix'] = self._locals_to_params( locals_dict=locals(), keys=['n_bins', 'smoothing_window_bins', 'prob_clip'], ) n_bins = self.n_bins if n_bins is None else n_bins smooth_window = self.smooth_window if smoothing_window_bins is None else smoothing_window_bins print('Computing naive Bayes distance matrix...') if self._verbose else None ## Precompute intra-session mask if not hasattr(self, '_intra_mask') or self._intra_mask is None: self._precompute_intra_mask() intra_mask = torch.as_tensor(self._intra_mask) ## Build smoother (shared across features) smoother = helpers.Convolver_1d( kernel=torch.ones(smooth_window), length_x=n_bins, pad_mode='same', correct_edge_effects=True, device='cpu', ) ## Features to calibrate: iterate over all similarity metrics features_raw = { name: torch.as_tensor(sim.data, dtype=torch.float32) for name, sim in self.similarities.items() } calibrations = {'features': {}} nnz = self._s_sparsity.nnz logit_sum = torch.zeros(nnz, dtype=torch.float32) ## Calibrate each feature independently for name, s_data in features_raw.items(): cal = self._calibrate_feature_1d( s_data=s_data, intra_mask=intra_mask, n_bins=n_bins, smoother=smoother, prob_clip=prob_clip, ) ## Look up P(same) for each pair from its histogram bin. ## cal values are numpy arrays; convert to torch for the lookup ## then store result as numpy for serialization safety. edges_t = torch.as_tensor(cal['edges']) p_same_bins_t = torch.as_tensor(cal['p_same_bins']) bin_idx = torch.searchsorted( edges_t[1:-1].contiguous(), s_data, ) bin_idx = torch.clamp(bin_idx, 0, n_bins - 1) p_same_per_pair = p_same_bins_t[bin_idx] ## torch, shape (nnz,) ## Accumulate logit for naive Bayes combination logit_p = torch.log(p_same_per_pair / (1.0 - p_same_per_pair)) logit_sum += logit_p ## Store as numpy so serializable_dict preserves it cal['p_same_per_pair'] = p_same_per_pair.numpy() calibrations['features'][name] = cal print( f' {name}: P(same) range ' f'[{p_same_per_pair.min():.4f}, {p_same_per_pair.max():.4f}], ' f'mean={p_same_per_pair.mean():.4f}' ) if self._verbose else None ## Estimate prior P(same) K = len(features_raw) prior_estimates = [] for cal in calibrations['features'].values(): total_same = cal['counts_same'].sum().item() total = total_same + cal['counts_diff_smooth'].sum().item() if total > 0: prior_estimates.append(total_same / total) prior = float(np.mean(prior_estimates)) if prior_estimates else 0.5 prior = float(np.clip(prior, prob_clip[0], prob_clip[1])) calibrations['prior'] = prior ## Naive Bayes log-odds combination logit_prior = float(np.log(prior / (1.0 - prior))) logit_combined = logit_sum - (K - 1) * logit_prior ## Convert back to probability p_same_combined = torch.sigmoid(logit_combined).numpy() calibrations['p_same_combined'] = p_same_combined ## Build sparse similarity and distance matrices sConj = self._s_sparsity.copy() sConj.data = p_same_combined.astype(np.float64) dConj = sConj.copy() dConj.data = 1.0 - dConj.data ## Store for downstream use self.dConj = dConj self.sConj = sConj self.calibrations_naive_bayes = calibrations print( f' Combined P(same): mean={p_same_combined.mean():.4f}, ' f'prior={prior:.4f}, ' f'P(same)>0.5: {(p_same_combined > 0.5).sum()}/{nnz} ' f'({(p_same_combined > 0.5).mean() * 100:.1f}%)' ) if self._verbose else None return dConj, sConj, calibrations
def _estimate_sigmoid_params(self) -> Dict[str, Dict[str, float]]: """ Estimate sigmoid parameters (mu, b) for NN and SWT from NB calibration curves using Fisher's linear discriminant. For each feature, finds the sigmoid ``sigma(b * (s - mu))`` that best separates "same" and "different" distributions in the calibration histogram. Uses a grid search over (mu, b) to maximize the Fisher discriminant ratio in sigmoid-transformed space. Requires :meth:`make_naive_bayes_distance_matrix` to have been called first. RH 2025 Returns: (Dict[str, Dict[str, float]]): sigmoid_params (Dict[str, Dict[str, float]]): Mapping from feature name to ``{'mu': float, 'b': float}``. """ assert hasattr(self, 'calibrations_naive_bayes') and self.calibrations_naive_bayes is not None, ( "make_naive_bayes_distance_matrix() must be called before " "_estimate_sigmoid_params()." ) result = {} for name, cfg in self._metric_configs.items(): if not cfg.optimize_sigmoid: continue cal = self.calibrations_naive_bayes['features'][name] ## cal values are numpy arrays (stored that way for serialization) edges = np.asarray(cal['edges']) counts_same = np.asarray(cal['counts_same']) counts_diff = np.asarray(cal['counts_diff_smooth']) ## Bin centers in similarity space, shape (n_bins,) centers_np = (edges[:-1] + edges[1:]) / 2.0 ## Normalized distribution weights, shape (n_bins,) w_same = counts_same / (counts_same.sum() + 1e-10) w_diff = counts_diff / (counts_diff.sum() + 1e-10) ## Vectorized grid search over (mu, b) to maximize Fisher ## discriminant in sigmoid-transformed space. ## Grid shapes: mu (M,), b (B,) → sig_vals (M, B, n_bins) mu_grid = np.linspace( float(centers_np.min()), float(centers_np.max()), 50, ) b_grid = np.linspace(0.5, 10.0, 30) ## Broadcasting: (M,1,1) * ((1,1,n_bins) - (M,1,1)) sig_vals = 1.0 / (1.0 + np.exp( -b_grid[None, :, None] * (centers_np[None, None, :] - mu_grid[:, None, None]) )) ## shape (M, B, n_bins) ## Weighted moments in sigmoid-transformed space mu_same_sig = np.sum(w_same[None, None, :] * sig_vals, axis=2) ## (M, B) mu_diff_sig = np.sum(w_diff[None, None, :] * sig_vals, axis=2) ## (M, B) var_same_sig = np.sum(w_same[None, None, :] * (sig_vals - mu_same_sig[:, :, None]) ** 2, axis=2) var_diff_sig = np.sum(w_diff[None, None, :] * (sig_vals - mu_diff_sig[:, :, None]) ** 2, axis=2) ## Fisher discriminant ratio, shape (M, B) denom = var_same_sig + var_diff_sig + 1e-12 fisher_grid = (mu_same_sig - mu_diff_sig) ** 2 / denom ## Find best (mu, b) best_idx = np.unravel_index(fisher_grid.argmax(), fisher_grid.shape) best_mu = float(mu_grid[best_idx[0]]) best_b = float(b_grid[best_idx[1]]) best_fisher = float(fisher_grid[best_idx]) result[name] = {'mu': best_mu, 'b': best_b} print( f' Sigmoid estimate for {name}: ' f'mu={best_mu:.4f}, b={best_b:.2f} ' f'(Fisher={best_fisher:.4f})' ) if self._verbose else None return result
[docs] def make_pruned_similarity_graphs( self, convert_to_probability: bool = False, stringency: float = 1.0, mixing_params: Optional[Dict] = None, d_cutoff: Optional[float] = None, ) -> None: """ Constructs pruned similarity graphs. RH 2023 Args: convert_to_probability (bool): Whether to convert the distance and similarity graphs to probability, *p(different)* and *p(same)*, respectively. (Default is ``False``) stringency (float): Modifies the threshold for pruning the distance matrix. A higher value results in less pruning, a lower value leads to more pruning. This value is multiplied by the inferred threshold to generate a new one. (Default is *1.0*) mixing_params (Optional[Dict]): Mixing parameters for ``self.make_conjunctive_distance_matrix``. If ``None``, the best parameters found using ``self.find_optimal_parameters`` are used. Use ``'precomputed'`` to use a previously stored ``self.dConj``. (Default is ``None``) d_cutoff (Optional[float]): The cutoff distance for pruning the distance matrix. If ``None``, then the optimal cutoff distance is inferred. (Default is ``None``) """ ## Store parameter (but not data) args as attributes self.params['make_pruned_similarity_graphs'] = self._locals_to_params( locals_dict=locals(), keys=[ 'convert_to_probability', 'stringency', 'mixing_params', ], ) ## If 'precomputed', use self.dConj/sConj set by a prior call ## (e.g. make_naive_bayes_distance_matrix). Otherwise, compute ## the conjunctive distance matrix from mixing parameters. if mixing_params == 'precomputed': assert hasattr(self, 'dConj') and self.dConj is not None, ( "mixing_params='precomputed' requires self.dConj to be set " "(call make_naive_bayes_distance_matrix first)." ) elif mixing_params is None: if hasattr(self, 'kwargs_makeConjunctiveDistanceMatrix_best'): mixing_params = self.kwargs_makeConjunctiveDistanceMatrix_best else: mixing_params = {'p_norm': -4.0} for name in self.similarities: mixing_params[f'power_{name}'] = 1.0 ## identity (no transform) mixing_params[f'sig_{name}_kwargs'] = {'mu': 0.5, 'b': 0.5} warnings.warn(f'No mixing_params provided. Using defaults: {mixing_params}') if mixing_params != 'precomputed': self.dConj, self.sConj, self._activated_data = self.make_conjunctive_distance_matrix( similarities=self.similarities, mixing_params=mixing_params, ) dens_same_crop, dens_same, dens_diff, dens_all, edges, d_crossover = self._separate_diffSame_distributions(self.dConj) if convert_to_probability: ## convert into probabilities ### first smooth dens_diff. (dens_same is already smoothed) dens_diff_smooth = self._fn_smooth(dens_diff) ### second, compute the probability of each bin prob_same = (dens_same / (dens_same + dens_diff_smooth)).numpy() ### force to be monotonic decreasing prob_same = np.maximum.accumulate(prob_same[::-1])[::-1] ### third, append 0 to the end prob_same = np.append(prob_same, 0) ### third, convert self.dConj to probabilities using interpolation import scipy.interpolate fn_interp = scipy.interpolate.interp1d(edges, prob_same, kind='linear', fill_value='extrapolate') self.sConj.data = fn_interp(self.dConj.data) self.dConj.data = 1 - self.sConj.data d_crossover = 1 - fn_interp(d_crossover) self.distributions_mixing = { 'mixing_params': mixing_params, 'dens_same_crop': dens_same_crop, 'dens_same': dens_same, 'dens_diff': dens_diff, 'dens_all': dens_all, 'edges': edges, 'd_crossover': d_crossover, } ## Copy all similarity matrices for pruning sims_copy = {name: s.copy() for name, s in self.similarities.items()} ssesh = self.s_sesh.copy() min_d = np.nanmin(self.dConj.data) if d_cutoff is None: range_d = d_crossover - min_d self.d_cutoff = min_d + range_d * stringency else: self.d_cutoff = d_cutoff print(f'Pruning similarity graphs with d_cutoff = {self.d_cutoff}...') if self._verbose else None self.graph_pruned = self.dConj.copy() self.graph_pruned.data = self.graph_pruned.data < self.d_cutoff self.graph_pruned.eliminate_zeros() def prune(s, graph_pruned): import scipy.sparse if s is None: return None s_pruned = scipy.sparse.csr_array(graph_pruned.shape, dtype=np.float32) s_pruned[graph_pruned] = s[graph_pruned] s_pruned = s_pruned.tocsr() return s_pruned ## Prune all similarity matrices self.similarities_pruned = { name: prune(s, self.graph_pruned) for name, s in sims_copy.items() } self.s_sesh_pruned = prune(ssesh, self.graph_pruned) self.dConj_pruned = prune(self.dConj, self.graph_pruned) self.sConj_pruned = prune(self.sConj, self.graph_pruned)
[docs] def apply_weighted_jaccard( self, s_conj: Optional[scipy.sparse.csr_array] = None, d_conj: Optional[scipy.sparse.csr_array] = None, alpha: float = 1.0, ) -> Tuple[scipy.sparse.csr_array, scipy.sparse.csr_array]: """ Apply weighted Jaccard preprocessing to a similarity graph. Returns new similarity and distance matrices without modifying ``self``. Can be called in two ways: 1. **From stored state** (no args): uses ``self.sConj_pruned`` from ``make_pruned_similarity_graphs()``. 2. **Purely functional**: pass ``s_conj`` or ``d_conj`` directly. Exactly one must be provided; the other is derived as ``1 - x``. The weighted Jaccard replaces each pairwise similarity with a measure of shared neighborhood structure: .. math:: J_w(i,j) = \\frac{\\sum_k \\min(s_{ik}, s_{jk})} {\\sum_k \\max(s_{ik}, s_{jk})} This amplifies within-cluster similarity and suppresses cross-cluster noise. See :func:`weighted_jaccard_similarity` for details. Args: s_conj (Optional[scipy.sparse.csr_array]): Similarity matrix to transform. If ``None``, uses ``self.sConj_pruned``. Mutually exclusive with ``d_conj``. d_conj (Optional[scipy.sparse.csr_array]): Distance matrix to transform (converted to similarity via ``1 - d``). If ``None``, uses ``self.sConj_pruned``. Mutually exclusive with ``s_conj``. alpha (float): Blending weight between Jaccard and original similarity. ``alpha=1.0`` uses pure Jaccard (default). ``alpha=0.0`` returns copies of the originals. Values in between give a linear blend: ``s_final = alpha * s_jaccard + (1-alpha) * s_original``. Returns: (Tuple[scipy.sparse.csr_array, scipy.sparse.csr_array]): sConj_jaccard (scipy.sparse.csr_array): Jaccard-refined similarity matrix. Same sparsity as input. dConj_jaccard (scipy.sparse.csr_array): Corresponding distance matrix (``1 - sConj_jaccard``). """ assert not (s_conj is not None and d_conj is not None), ( "Provide s_conj or d_conj, not both." ) assert 0.0 <= alpha <= 1.0, f"alpha must be in [0, 1], got {alpha}" ## Resolve input similarity matrix if s_conj is not None: s = s_conj elif d_conj is not None: s = d_conj.copy() s.data = 1.0 - s.data else: assert hasattr(self, 'sConj_pruned') and self.sConj_pruned is not None, ( "No input provided and self.sConj_pruned is not set. " "Either pass s_conj/d_conj or call make_pruned_similarity_graphs() first." ) s = self.sConj_pruned if alpha == 0.0: s_out = s.copy() d_out = s.copy() d_out.data = 1.0 - d_out.data return s_out, d_out s_jaccard = weighted_jaccard_similarity(s) if alpha < 1.0: ## Blend: s_final = alpha * s_jaccard + (1 - alpha) * s_original s_jaccard.data = alpha * s_jaccard.data + (1.0 - alpha) * s.data d_jaccard = s_jaccard.copy() d_jaccard.data = 1.0 - d_jaccard.data return s_jaccard, d_jaccard
[docs] def fit( self, d_conj: scipy.sparse.csr_array, session_bool: np.ndarray, min_cluster_size: int = 2, max_cluster_size: Optional[int] = None, min_samples: Optional[int] = None, n_iter_violationCorrection: int = 5, cluster_selection_method: str = 'leaf', cluster_selection_persistence: float = 0.0, d_clusterMerge: Optional[float] = None, alpha: float = 0.999, split_intraSession_clusters: bool = True, discard_failed_pruning: bool = True, n_steps_clusterSplit: int = 100, backend: str = 'fast_hdbscan', algorithm: str = 'kruskal', rescue_noise: bool = True, ) -> np.ndarray: """ Fits clustering using HDBSCAN with same-session constraint enforcement. By default (``backend='fast_hdbscan'``), uses ``fast_hdbscan.HDBSCAN`` with group-label cannot-link constraints. Each ROI's session index is passed as a group label; same-session ROIs cannot co-cluster. This uses O(N) memory (an int32 vector) instead of the O(N^2) sparse cannot-link matrix, and O(1) per-merge conflict checks via bitmask. Transitive constraints are handled correctly: if components A and B are merged and both contain session-0 ROIs, the merge is blocked. Set ``backend='legacy'`` to use the original ``hdbscan.HDBSCAN`` with iterative violation correction via dendrogram walk-back. RH 2023 / 2025 Args: d_conj (scipy.sparse.csr_array): Conjunctive distance matrix. session_bool (np.ndarray): Boolean array indicating which ROIs belong to which session. Shape: *(n_rois, n_sessions)* min_cluster_size (int): Minimum cluster size to be considered a cluster. Can be 'all'. (Default is *2*) max_cluster_size (Optional[int]): Maximum cluster size. Clusters larger than this are split. If ``None``, defaults to ``n_sessions`` (one ROI per session), which is the natural constraint for neuron tracking. Set to a larger value to allow clusters spanning a subset of sessions. (Default is ``None``) min_samples (Optional[int]): Number of neighbors a point needs to be considered "core" (non-noise) by HDBSCAN. Controls the density threshold independently of ``min_cluster_size``. Lower values → fewer noise points. If ``None``, defaults to ``min_cluster_size`` (HDBSCAN default behavior). (Default is ``None``) n_iter_violationCorrection (int): Number of iterations to correct for clusters with multiple ROIs per session. Only used with ``backend='legacy'``. (Default is *5*) cluster_selection_method (str): Cluster selection method. Either ``'leaf'`` or ``'eom'``. 'leaf' leans towards smaller clusters, 'eom' towards larger clusters. (Default is ``'leaf'``) cluster_selection_persistence (float): Minimum stability (persistence) a cluster must have to survive selection. Clusters below this threshold are folded into their parent. Higher values → fewer but more stable clusters. Only used with ``backend='fast_hdbscan'``. (Default is *0.0*) d_clusterMerge (Optional[float]): Distance threshold for merging clusters (``cluster_selection_epsilon`` in HDBSCAN). Clusters separated by less than this distance are merged. If ``None``, defaults to ``self.d_cutoff`` (the pruning threshold from ``make_pruned_similarity_graphs``), which is the inferred same/different decision boundary. Falls back to ``mean + 1*std`` of the distance data if ``d_cutoff`` is not available. (Default is ``None``) alpha (float): Alpha value. Only used with ``backend='legacy'``. (Default is *0.999*) split_intraSession_clusters (bool): If ``True``, clusters containing ROIs from multiple sessions will be split. Only used with ``backend='legacy'``. (Default is ``True``) discard_failed_pruning (bool): If ``True``, clusters failing to prune are set to -1. Only used with ``backend='legacy'``. (Default is ``True``) n_steps_clusterSplit (int): Number of steps for splitting clusters with multiple ROIs from the same session. Only used with ``backend='legacy'``. (Default is *100*) backend (str): Which HDBSCAN implementation to use: \n * ``'fast_hdbscan'``: Use fast_hdbscan with native cannot-link constraints (default). * ``'legacy'``: Use legacy hdbscan with iterative violation correction. \n (Default is ``'fast_hdbscan'``) algorithm (str): Algorithm for fast_hdbscan MST construction. Only used with ``backend='fast_hdbscan'``: \n * ``'kruskal'``: Kruskal DSU on full CSR edge list. Supports cannot-link with any metric. * ``'boruvka'``: Boruvka parallel MST. Supports cannot-link only with ``metric='precomputed'``. \n (Default is ``'kruskal'``) rescue_noise (bool): If ``True``, run a post-HDBSCAN noise rescue pass that assigns noise ROIs to nearby clusters (or nucleates new small clusters) using a Kruskal-style sorted-edge traversal with DSU bitmask cannot-link constraints. Only used with ``backend='fast_hdbscan'``. (Default is ``True``) Returns: (np.ndarray): labels (np.ndarray): Cluster labels for each ROI, shape: *(n_rois_total)* """ ## Store parameter (but not data) args as attributes self.params['fit'] = self._locals_to_params( locals_dict=locals(), keys=[ 'min_cluster_size', 'max_cluster_size', 'min_samples', 'n_iter_violationCorrection', 'cluster_selection_method', 'cluster_selection_persistence', 'd_clusterMerge', 'alpha', 'split_intraSession_clusters', 'discard_failed_pruning', 'n_steps_clusterSplit', 'backend', 'algorithm', 'rescue_noise', ], ) ## Resolve d_clusterMerge default: use d_cutoff from pruning if available if d_clusterMerge is None: if hasattr(self, 'd_cutoff') and self.d_cutoff is not None: d_clusterMerge = float(self.d_cutoff) ## Otherwise each backend falls back to mean + 1*std heuristic if backend == 'fast_hdbscan': return self._fit_fast_hdbscan( d_conj=d_conj, session_bool=session_bool, min_cluster_size=min_cluster_size, max_cluster_size=max_cluster_size, min_samples=min_samples, cluster_selection_method=cluster_selection_method, cluster_selection_persistence=cluster_selection_persistence, d_clusterMerge=d_clusterMerge, algorithm=algorithm, rescue_noise=rescue_noise, ) elif backend == 'legacy': return self._fit_legacy_hdbscan( d_conj=d_conj, session_bool=session_bool, min_cluster_size=min_cluster_size, max_cluster_size=max_cluster_size, min_samples=min_samples, n_iter_violationCorrection=n_iter_violationCorrection, cluster_selection_method=cluster_selection_method, d_clusterMerge=d_clusterMerge, alpha=alpha, split_intraSession_clusters=split_intraSession_clusters, discard_failed_pruning=discard_failed_pruning, n_steps_clusterSplit=n_steps_clusterSplit, ) else: raise ValueError( f"backend must be 'fast_hdbscan' or 'legacy'. Got: {backend!r}" )
def _fit_fast_hdbscan( self, d_conj: scipy.sparse.csr_array, session_bool: np.ndarray, min_cluster_size: int = 2, max_cluster_size: Optional[int] = None, min_samples: Optional[int] = None, cluster_selection_method: str = 'leaf', cluster_selection_persistence: float = 0.0, d_clusterMerge: Optional[float] = None, algorithm: str = 'kruskal', rescue_noise: bool = True, ) -> np.ndarray: """ Fit clustering using fast_hdbscan with group-label cannot-link constraints, optionally followed by a noise rescue pass. Each ROI is assigned its session index as a group label. Samples sharing the same group label (i.e., same session) cannot co-cluster. This uses a bitmask per DSU component for O(1) conflict checks, replacing the previous O(N^2) sparse cannot-link matrix. RH 2025 Args: d_conj (scipy.sparse.csr_array): Conjunctive distance matrix. Shape: *(n_rois, n_rois)*. session_bool (np.ndarray): Boolean array indicating which ROIs belong to which session. Shape: *(n_rois, n_sessions)*. Each row should contain exactly one ``True``. min_cluster_size (int): Minimum cluster size. (Default is *2*) max_cluster_size (Optional[int]): Maximum cluster size. If ``None``, defaults to ``n_sessions``. (Default is ``None``) min_samples (Optional[int]): Number of neighbors for core-point density estimation. If ``None``, defaults to ``min_cluster_size``. (Default is ``None``) cluster_selection_method (str): ``'leaf'`` or ``'eom'``. (Default is ``'leaf'``) cluster_selection_persistence (float): Minimum cluster stability threshold. (Default is *0.0*) d_clusterMerge (Optional[float]): Distance threshold for merging clusters. If ``None``, computed as mean + 1*std of the distance data. (Default is ``None``) algorithm (str): MST construction algorithm. ``'kruskal'`` or ``'boruvka'``. (Default is ``'kruskal'``) rescue_noise (bool): If ``True``, run a post-HDBSCAN noise rescue pass. (Default is ``True``) Returns: (np.ndarray): labels (np.ndarray): Cluster labels for each ROI, shape: *(n_rois_total)*. """ ## Mask to inter-session pairs only (same as legacy) d = d_conj.copy().multiply(self.s_sesh) if d.nnz == 0: print('No edges in graph. Returning all -1 labels.') if self._verbose else None self.labels = np.ones(d.shape[0], dtype=int) * -1 return self.labels n_sessions = session_bool.shape[1] if min_cluster_size == 'all': min_cluster_size = n_sessions print(f'Setting min_cluster_size to {min_cluster_size} (all ROIs in a session)') if self._verbose else None ## Resolve max_cluster_size default: one ROI per session if max_cluster_size is None: max_cluster_size = n_sessions ## Auto-estimate d_clusterMerge from distance statistics d_clusterMerge = float(np.mean(d.data) + 1 * np.std(d.data)) if d_clusterMerge is None else float(d_clusterMerge) ## Build group-label cannot-link vector: explicit session index per ROI. ## This avoids relying on ROIs being stored in contiguous session ## blocks; fast_hdbscan expects one group label per sample. n_sessions_per_roi = np.asarray(session_bool.sum(axis=1)).ravel() assert np.all(n_sessions_per_roi == 1), ( "session_bool must contain exactly one True per ROI when using " "fast_hdbscan cannot_link_groups" ) cannot_link_groups = np.asarray(np.argmax(session_bool, axis=1), dtype=np.int32) print('Fitting with fast_hdbscan (group-label cannot-link constraints)') if self._verbose else None ## Run fast_hdbscan with group-label cannot-link constraints import fast_hdbscan self.hdbs = fast_hdbscan.HDBSCAN( min_cluster_size=min_cluster_size, min_samples=min_samples, cluster_selection_epsilon=d_clusterMerge, cluster_selection_persistence=cluster_selection_persistence, max_cluster_size=max_cluster_size, metric='precomputed', algorithm=algorithm, cannot_link_groups=cannot_link_groups, cluster_selection_method=cluster_selection_method, ) ## fast_hdbscan's Kruskal algorithm handles disconnected components ## natively (they become separate clusters), so the ## attach_fully_connected_node hack used by the legacy backend is ## not needed here. self.hdbs.fit(d) labels = self.hdbs.labels_.copy() self._fit_used_fully_connected_node = False ## Report violations (should be zero with group-label cannot-link) unique_labels = np.unique(labels) unique_labels = unique_labels[unique_labels > -1] violations_labels = unique_labels[np.array([ np.unique(cannot_link_groups[mask := (labels == u)]).size < mask.sum() for u in unique_labels ], dtype=bool)] ## A cluster violates if it has duplicate session indices. ## Equivalently: n_unique_sessions < n_rois_in_cluster. self.violations_labels = violations_labels if self._verbose: n_violations = len(violations_labels) print(f'Session violations after fast_hdbscan: {n_violations} clusters, d_clusterMerge={d_clusterMerge:.2f}') ## Phase 2: noise rescue — assign noise ROIs to nearby clusters or ## nucleate new small clusters via Kruskal DSU with bitmask constraints. ## Operates on the same inter-session-masked distance matrix `d` and ## uses d_clusterMerge as the edge distance cutoff. if rescue_noise and np.any(labels == -1): labels = self.rescue_noise( d_conj=d, labels=labels, session_bool=session_bool, d_cutoff=d_clusterMerge, ) ## Post-processing: squeeze labels, remove too-small clusters labels = helpers.squeeze_integers(labels) ## Set clusters below min_cluster_size to noise. This catches both ## singletons and small nucleated clusters from noise rescue. u, c = np.unique(labels, return_counts=True) labels[np.isin(labels, u[c < min_cluster_size])] = -1 labels = helpers.squeeze_integers(labels) self.labels = labels return self.labels
[docs] def rescue_noise( self, d_conj: scipy.sparse.csr_array, labels: np.ndarray, session_bool: np.ndarray, d_cutoff: float, ) -> np.ndarray: """ Assign noise ROIs (``label == -1``) to nearby clusters or nucleate new small clusters, respecting same-session cannot-link constraints. Uses a Kruskal-style sorted-edge traversal with DSU and bitmask constraints (Phase 2 of two-phase clustering). See :func:`noise_rescue_kruskal` for algorithm details. Non-mutating: does not modify ``self.labels`` or any stored state. Args: d_conj (scipy.sparse.csr_array): Inter-session masked distance matrix. Shape: *(n_rois, n_rois)*. labels (np.ndarray): Phase 1 cluster labels from HDBSCAN. Shape: *(n_rois,)*. ``-1`` = noise. session_bool (np.ndarray): Boolean array, shape *(n_rois, n_sessions)*. Each row has exactly one ``True``. d_cutoff (float): Maximum edge distance to accept for noise rescue. Returns: (np.ndarray): new_labels (np.ndarray): Updated cluster labels after noise rescue. Shape: *(n_rois,)*. """ n_sessions = session_bool.shape[1] group_labels = np.asarray(np.argmax(session_bool, axis=1), dtype=np.int32) n_noise_before = np.sum(labels == -1) new_labels = noise_rescue_kruskal( d_conj=d_conj, labels=labels, group_labels=group_labels, n_groups=n_sessions, d_cutoff=d_cutoff, ) n_noise_after = np.sum(new_labels == -1) n_rescued = n_noise_before - n_noise_after if self._verbose: print(f'Noise rescue: {n_rescued}/{n_noise_before} noise ROIs rescued, {n_noise_after} remain noise') return new_labels
def _fit_legacy_hdbscan( self, d_conj: scipy.sparse.csr_array, session_bool: np.ndarray, min_cluster_size: int = 2, max_cluster_size: Optional[int] = None, min_samples: Optional[int] = None, n_iter_violationCorrection: int = 5, cluster_selection_method: str = 'leaf', d_clusterMerge: Optional[float] = None, alpha: float = 0.999, split_intraSession_clusters: bool = True, discard_failed_pruning: bool = True, n_steps_clusterSplit: int = 100, ) -> np.ndarray: """ Legacy clustering using hdbscan==0.8.41 with iterative violation correction. Preserved for backward compatibility. Fits clustering using a modified HDBSCAN clustering algorithm. The approach is to use HDBSCAN but avoid having clusters with multiple ROIs from the same session. This is achieved by repeating three steps: \n 1. Fit HDBSCAN to the data. 2. Identify clusters that have multiple ROIs from the same session and walk back down the dendrogram until those clusters are split up into non-violating clusters. 3. Disconnect graph edges between ROIs within each new cluster and all other ROIs outside the cluster that are from the same session. \n RH 2023 Args: d_conj (scipy.sparse.csr_array): Conjunctive distance matrix. session_bool (np.ndarray): Boolean array indicating which ROIs belong to which session. Shape: *(n_rois, n_sessions)* min_cluster_size (int): Minimum cluster size to be considered a cluster. Can be 'all'. (Default is *2*) max_cluster_size (Optional[int]): Maximum cluster size. If ``None``, defaults to ``n_sessions``. (Default is ``None``) min_samples (Optional[int]): Number of neighbors for core-point density estimation. If ``None``, defaults to ``min_cluster_size``. (Default is ``None``) n_iter_violationCorrection (int): Number of iterations to correct for clusters with multiple ROIs per session. (Default is *5*) cluster_selection_method (str): Cluster selection method. Either ``'leaf'`` or ``'eom'``. (Default is ``'leaf'``) d_clusterMerge (Optional[float]): Distance threshold for merging clusters. If ``None``, the distance is calculated as the mean + 1*std of the conjunctive distances. (Default is ``None``) alpha (float): Alpha value. Smaller values result in more clusters. (Default is *0.999*) split_intraSession_clusters (bool): If ``True``, clusters containing ROIs from multiple sessions will be split. (Default is ``True``) discard_failed_pruning (bool): If ``True``, clusters failing to prune are set to -1. (Default is ``True``) n_steps_clusterSplit (int): Number of steps for splitting clusters. (Default is *100*) Returns: (np.ndarray): labels (np.ndarray): Cluster labels for each ROI, shape: *(n_rois_total)* """ import hdbscan d = d_conj.copy().multiply(self.s_sesh) if d.nnz == 0: print('No edges in graph. Returning all -1 labels.') if self._verbose else None self.labels = np.ones(d.shape[0], dtype=int) * -1 return self.labels n_sessions = session_bool.shape[1] if min_cluster_size == 'all': min_cluster_size = n_sessions print(f'Setting min_cluster_size to {min_cluster_size} (all ROIs in a session)') if self._verbose else None ## Resolve max_cluster_size default: one ROI per session if max_cluster_size is None: max_cluster_size = n_sessions print('Fitting with HDBSCAN and splitting clusters with multiple ROIs per session') if self._verbose else None for ii in tqdm(range(n_iter_violationCorrection)): ## Prep parameters for splitting clusters d_clusterMerge = float(np.mean(d.data) + 1*np.std(d.data)) if d_clusterMerge is None else float(d_clusterMerge) n_steps_clusterSplit = int(n_steps_clusterSplit) max_dist=(d.max() - d.min()) * 1000 self.hdbs = hdbscan.HDBSCAN( min_cluster_size=min_cluster_size, min_samples=min_samples, cluster_selection_epsilon=d_clusterMerge, max_cluster_size=max_cluster_size, metric='precomputed', alpha=alpha, algorithm='generic', cluster_selection_method=cluster_selection_method, max_dist=max_dist, ) self.hdbs.fit(attach_fully_connected_node( d, dist_fullyConnectedNode=max_dist, n_nodes=1, )) labels = self.hdbs.labels_[:-1] self.labels = labels self._fit_used_fully_connected_node = True print(f'Initial number of violating clusters: {len(np.unique(labels)[np.array([(session_bool[labels==u].sum(0)>1).sum().item() for u in np.unique(labels)]) > 0])}, d_clusterMerge={d_clusterMerge:.2f}') if self._verbose else None ## Split up labels with multiple ROIs per session ## The below code is a bit of a mess, but it works. ## It works by iteratively reducing the cutoff distance ## and splitting up violating clusters until there are ## no more violations. if split_intraSession_clusters: labels = labels.copy() sb_t = torch.as_tensor(session_bool, dtype=torch.float32) ## (n_rois, n_sessions) n = len(d_conj.data) dcd = np.sort(d_conj.data) cuts_all = np.sort(np.unique([dcd[0]/2] + [dcd[int(n*ii)] for ii in np.linspace(0., 1., num=n_steps_clusterSplit, endpoint=False)[::-1]] + [dcd[-1]]))[::-1] for d_cut in cuts_all: labels_t = torch.as_tensor(labels, dtype=torch.int64) lab_u_t, lab_u_idx_t = torch.unique(labels_t, return_inverse=True) # (n_clusters,), (n_rois,) lab_oneHot_t = helpers.idx_to_oneHot(lab_u_idx_t, dtype=torch.float32) violations_labels = lab_u_t[((sb_t.T @ lab_oneHot_t) > 1.5).sum(0) > 0] violations_labels = violations_labels[violations_labels > -1] if len(violations_labels) == 0: break for l in violations_labels: idx = np.where(labels==l)[0] if d[idx][:,idx].nnz == 0: labels[idx] = -1 labels_new = self.hdbs.single_linkage_tree_.get_clusters( cut_distance=d_cut, min_cluster_size=min_cluster_size, )[:-1] idx_toUpdate = np.isin(labels, violations_labels) labels[idx_toUpdate] = labels_new[idx_toUpdate] + labels.max() + 5 labels[(labels_new == -1) * idx_toUpdate] = -1 if discard_failed_pruning: labels[idx_toUpdate] = -1 if ii < n_iter_violationCorrection - 1: ## Find sessions represented in each cluster and set distances to ROIs in those sessions to 1. d = d.tocsr() for ii, l in enumerate(np.unique(labels)): if l == -1: continue idx = np.where(labels==l)[0] d_sub = d[idx][:,idx] idx_grid = np.meshgrid(idx, idx) ## set distances of ROIs from same session to 0 sesh_to_exclude = 1 - (session_bool @ (session_bool[idx].max(0))) ## make a mask of sessions that are not represented in the cluster d[idx,:] = d[idx,:].multiply(sesh_to_exclude[None,:]) ## set distances to ROIs from sessions represented in the cluster to 1 d[:,idx] = d[:,idx].multiply(sesh_to_exclude[:,None]) ## set distances to ROIs from sessions represented in the cluster to 1 d[idx_grid[0], idx_grid[1]] = d_sub ## undo the above for ROIs in the cluster d = d.tocsr() d.eliminate_zeros() ## remove zeros labels = helpers.squeeze_integers(labels) violations_labels = np.unique(labels)[np.array([(session_bool[labels==u].sum(0)>1).sum().item() for u in np.unique(labels)]) > 0] violations_labels = violations_labels[violations_labels > -1] self.violations_labels = violations_labels ## Set clusters with too few ROIs to -1 u, c = np.unique(labels, return_counts=True) labels[np.isin(labels, u[c<2])] = -1 labels = helpers.squeeze_integers(labels) self.labels = labels return self.labels
[docs] def fit_sequentialHungarian( self, d_conj: scipy.sparse.csr_array, session_bool: np.ndarray, thresh_cost: float = 0.95, ) -> np.ndarray: """ Applies CaImAn's method for clustering. For further details, please refer to: * [CaImAn's paper](https://elifesciences.org/articles/38173#s4) * [CaImAn's repository](https://github.com/flatironinstitute/CaImAn) * [Relevant script in CaImAn's repository](https://github.com/flatironinstitute/CaImAn/blob/master/caiman/base/rois.py) Args: d_conj (scipy.sparse.csr_array): Distance matrix. Shape: *(n_rois, n_rois)* session_bool (np.ndarray): Boolean array indicating which ROIs are in which sessions. Shape: *(n_rois, n_sessions)* thresh_cost (float): Threshold below which ROI pairs are considered potential matches. (Default is *0.95*) Returns: (np.ndarray): labels (np.ndarray): Cluster labels. Shape: *(n_rois,)* """ ## Store parameter (but not data) args as attributes self.params['fit_sequentialHungarian'] = self._locals_to_params( locals_dict=locals(), keys=['thresh_cost',],) print(f"Clustering with CaImAn's sequential Hungarian algorithm method...") if self._verbose else None def find_matches(D_s): matches = [] costs = [] for ii, D in enumerate(D_s): # we make a copy not to set changes in the original DD = D.copy() if np.sum(np.where(np.isnan(DD))) > 0: raise Exception('Distance Matrix contains invalid value NaN') # we do the hungarian indexes = scipy.optimize.linear_sum_assignment(DD) indexes2 = [(ind1, ind2) for ind1, ind2 in zip(indexes[0], indexes[1])] matches.append(indexes) total = [] # we want to extract those informations from the hungarian algo for row, column in indexes2: value = DD[row, column] total.append(value) costs.append(total) # send back the results in the format we want return matches, costs n_roi = session_bool.sum(0) n_roi_cum = np.concatenate(([0], np.cumsum(n_roi))) matchings = [] matchings.append(list(range(n_roi[0]))) idx_union = np.arange(n_roi[0]) for i_sesh in tqdm(range(1,len(n_roi))): idx_sess = np.arange(n_roi_cum[i_sesh], n_roi_cum[i_sesh+1]) d_sub = d_conj[idx_sess][:, idx_union] D = np.ones((len(idx_sess), len(idx_union)))*np.logical_not((d_sub != 0).toarray())*1 + d_sub.toarray() D = [D] matches, costs = find_matches(D) matches = matches[0] costs = costs[0] # store indices idx_tp = np.where(np.array(costs) < thresh_cost)[0] if len(idx_tp) > 0: matched_ROIs1 = matches[0][idx_tp] # ground truth matched_ROIs2 = matches[1][idx_tp] # algorithm - comp non_matched1 = np.setdiff1d(list(range(D[0].shape[0])), matches[0][idx_tp]) non_matched2 = np.setdiff1d(list(range(D[0].shape[1])), matches[1][idx_tp]) TP = np.sum(np.array(costs) < thresh_cost) * 1. else: TP = 0. matched_ROIs1 = [] matched_ROIs2 = [] non_matched1 = list(range(D[0].shape[0])) non_matched2 = list(range(D[0].shape[1])) # compute precision and recall FN = D[0].shape[0] - TP FP = D[0].shape[1] - TP TN = 0 performance = dict() performance['recall'] = TP / (TP + FN) performance['precision'] = TP / (TP + FP) performance['accuracy'] = (TP + TN) / (TP + FP + FN + TN) performance['f1_score'] = 2 * TP / (2 * TP + FP + FN) mat_sess, mat_un, nm_sess, nm_un, performance, A2_len = matched_ROIs1, matched_ROIs2, non_matched1, non_matched2, performance, len(idx_union) idx_union[mat_un] = idx_sess[mat_sess] idx_union = np.concatenate((idx_union, idx_sess[nm_sess])) new_match = np.zeros(n_roi[i_sesh], dtype=int) new_match[mat_sess] = mat_un new_match[nm_sess] = range(A2_len, len(idx_union)) matchings.append(new_match.tolist()) self.seqHung_performance = performance labels = np.concatenate(matchings) u, c = np.unique(labels, return_counts=True) labels[np.isin(labels, u[c == 1])] = -1 labels = helpers.squeeze_integers(labels) self.labels = labels return self.labels
[docs] def make_conjunctive_distance_matrix( self, similarities: Dict[str, scipy.sparse.csr_array], mixing_params: Dict[str, Any], ) -> Tuple[scipy.sparse.csr_array, scipy.sparse.csr_array, Dict[str, torch.Tensor]]: """ Makes a conjunctive distance matrix from the similarity matrices using the given mixing parameters. RH 2023 / 2025 Args: similarities (Dict[str, scipy.sparse.csr_array]): Dict mapping metric name to sparse similarity matrix. mixing_params (Dict[str, Any]): Mixing parameters dict. Expected keys: \n * ``power_<name>`` (float): Power for each metric. Defaults to 1.0 if not present. * ``sig_<name>_kwargs`` (Dict[str, float]): Sigmoid parameters ``{'mu': float, 'b': float}`` per metric. ``None`` or absent means no sigmoid. * ``p_norm`` (float): p-norm exponent for combining activated similarities. Returns: (Tuple): Tuple containing: dConj (scipy.sparse.csr_array): Conjunctive distance matrix (1 - sConj). sConj (scipy.sparse.csr_array): Conjunctive similarity matrix. activated_data (Dict[str, torch.Tensor]): Per-metric activated similarity data arrays. """ assert len(similarities) > 0, 'At least one similarity matrix must be provided.' ## Store parameter (but not data) args as attributes self.params['make_conjunctive_distance_matrix'] = self._locals_to_params( locals_dict=locals(), keys=['mixing_params'], ) p_norm = mixing_params.get('p_norm', 1.0) ## scalar p_norm = 1e-9 if p_norm == 0 else p_norm ## avoid division by zero ## Activate each metric: sigmoid → clamp → power ## Each metric's .data is a 1D array of shape (nnz,) activated_data = {} ## Dict[str, torch.Tensor(nnz,)] s_list = [] for name, s in similarities.items(): power = mixing_params.get(f'power_{name}', 1.0) sig_kwargs = mixing_params.get(f'sig_{name}_kwargs', None) activated = self._activation_function(s.data, sig_kwargs, power) activated_data[name] = activated s_list.append(activated) ## p-norm combination: sConj = (mean(s_k^p))^(1/p) sConj_data = self._pNorm(s_list=s_list, p=p_norm) ## shape: (nnz,) ## Build sparse sConj using the first similarity's sparsity pattern template = next(iter(similarities.values())) ## csr_array (n_roi, n_roi) sConj = template.copy() sConj.data = sConj_data.numpy() ## Distance = 1 - similarity dConj = sConj.copy() dConj.data = 1 - dConj.data return dConj, sConj, activated_data
def _activation_function( self, s: Optional[torch.Tensor] = None, sig_kwargs: Optional[Dict[str, float]] = {'mu':0.0, 'b':1.0}, power: Optional[float] = 1 ) -> Optional[torch.Tensor]: """ Applies an activation function to a similarity matrix. Args: s (Optional[torch.Tensor]): The input similarity matrix. If ``None``, the function returns ``None``. (Default is ``None``) sig_kwargs (Dict[str, float]): Keyword arguments for the sigmoid function applied to the similarity matrix. See helpers.generalised_logistic_function for details. (Default is {'mu':0.0, 'b':1.0}) power (Optional[float]): Power to which to raise the similarity. If ``None``, the power operation is not applied. (Default is *1*) Returns: (Optional[torch.Tensor]): Activated similarity matrix. Returns ``None`` if the input similarity matrix is ``None``. """ if s is None: return None s = torch.as_tensor(s, dtype=torch.float32) ## make functions such that if the param is None, then no operation is applied fn_sigmoid = lambda x, params: helpers.generalised_logistic_function(x, **params) if params is not None else x fn_power = lambda x, p: x ** p if p is not None else x return fn_power(torch.clamp(fn_sigmoid(s, sig_kwargs), min=0), power) def _pNorm( self, s_list: List[Optional[torch.Tensor]], p: float ) -> torch.Tensor: """ Calculate the p-norm of a list of similarity matrices. Args: s_list (List[Optional[torch.Tensor]]): List of similarity matrices. p (float): p-norm to use. Returns: (torch.Tensor): p-norm of the list of similarity matrices. """ s_list_noNones = [s for s in s_list if s is not None] return (torch.mean(torch.stack(s_list_noNones, axis=0)**p, dim=0))**(1/p)
[docs] def plot_similarity_relationships( self, max_samples: int = 1000000, kwargs_scatter: Dict[str, Union[int, float]] = {'s': 1, 'alpha': 0.1}, mixing_params: Optional[Dict[str, Any]] = None, ) -> Tuple[plt.figure, plt.axes]: """ Plot pairwise similarity relationships for all N*(N-1)/2 metric pairs. Each subplot shows one pair of metrics, colored by conjunctive distance. Args: max_samples (int): Maximum number of samples to plot. kwargs_scatter (Dict[str, Union[int, float]]): Keyword arguments for ``matplotlib.pyplot.scatter``. mixing_params (Optional[Dict[str, Any]]): Mixing parameters for ``make_conjunctive_distance_matrix``. If ``None``, uses ``self.best_params`` if available, else defaults. Returns: (Tuple[matplotlib.pyplot.figure, matplotlib.pyplot.axes]): fig, axs: Figure and axes objects. """ if mixing_params is None: mixing_params = getattr(self, 'best_params', {'p_norm': -4.0}) dConj, sConj, activated_data = self.make_conjunctive_distance_matrix( similarities=self.similarities, mixing_params=mixing_params, ) ## Generate all N*(N-1)/2 metric pairs import itertools metric_names = list(self.similarities.keys()) pairs = list(itertools.combinations(metric_names, 2)) n_pairs = max(len(pairs), 1) ## Subsample for plotting idx_rand = np.floor( np.random.rand(min(max_samples, len(dConj.data))) * len(dConj.data) ).astype(int) d_conj_sub = dConj.data[idx_rand] fig, axs = plt.subplots(nrows=1, ncols=n_pairs, figsize=(7 * n_pairs, 4)) if n_pairs == 1: axs = [axs] fig.suptitle('Similarity relationships', fontsize=16) for i, (name_x, name_y) in enumerate(pairs): x_data = activated_data[name_x][idx_rand] y_data = activated_data[name_y][idx_rand] axs[i].scatter(x_data, y_data, c=d_conj_sub, **kwargs_scatter) axs[i].set_xlabel(f'sim {name_x}') axs[i].set_ylabel(f'sim {name_y}') return fig, axs
[docs] def plot_distSame(self, mixing_params: Optional[dict] = None) -> None: """ Plot the estimated distribution of the pairwise similarities between matched ROI pairs of ROIs. Args: mixing_params (Optional[dict]): Mixing parameters for ``make_conjunctive_distance_matrix``. If ``None``, the function uses the object's best parameters. (Default is ``None``) """ params = mixing_params if mixing_params is not None else self.best_params dConj, sConj, activated_data = self.make_conjunctive_distance_matrix( similarities=self.similarities, mixing_params=params, ) dens_same_crop, dens_same, dens_diff, dens_all, edges, d_crossover = self._separate_diffSame_distributions(dConj) if edges is None: print('No crossover found, not plotting') return None fig = plt.figure() plt.stairs(dens_same, edges, linewidth=5) plt.stairs(dens_same_crop, edges, linewidth=3) plt.stairs(dens_diff, edges) plt.stairs(dens_all, edges) plt.axvline(d_crossover, color='k', linestyle='--') plt.ylim([dens_same.max()*-0.5, dens_same.max()*1.5]) plt.title('Pairwise similarities') plt.xlabel('distance or prob(different)') plt.ylabel('counts or density') plt.legend(['same', 'same (cropped)', 'diff', 'all', 'crossover']) return fig
def _fn_smooth( self, x: torch.Tensor, ) -> torch.Tensor: """ Smooth a 1D tensor with a boxcar convolution. Args: x (torch.Tensor): 1D tensor to be smoothed. Returns: (torch.Tensor): Smoothed tensor, same shape as ``x``. """ return helpers.Convolver_1d( kernel=torch.ones(self.smooth_window), length_x=self.n_bins, pad_mode='same', correct_edge_effects=True, device='cpu', ).convolve(x) #################################################################### ## Shared histogram overlap computation #################################################################### def _compute_histogram_overlap( self, distances: torch.Tensor, intra_indices: torch.Tensor, edges: torch.Tensor, smoother: 'helpers.Convolver_1d', scale_factor: float, ) -> Tuple[float, torch.Tensor, torch.Tensor]: """ Compute the hard histogram overlap loss between the estimated 'same' and 'different' distance distributions. Shared core used by :meth:`_separate_diffSame_distributions` and :meth:`_find_optimal_parameters_DE` (scalar objective). The 'different' distribution is estimated by scaling the intra-session (known-different) histogram. The 'same' distribution is the residual after subtracting the scaled intra counts from all counts, clamped non-negative and then smoothed. The loss is the dot product of the two estimated distributions (overlap area). If no valid crossover point exists (i.e. the two distributions never separate), ``loss`` is set to ``1e6`` as a penalty. RH 2025 Args: distances (torch.Tensor): 1D float tensor of distance values, shape ``(n,)``. intra_indices (torch.Tensor): 1D int64 tensor of indices into ``distances`` corresponding to intra-session (known-different) pairs. edges (torch.Tensor): 1D float tensor of bin edges, shape ``(n_bins + 1,)``. smoother (helpers.Convolver_1d): Pre-built 1D convolver for smoothing the 'same' distribution. scale_factor (float): ``n_all / n_intra`` — multiplier that scales the intra counts up to the full population size. Returns: (Tuple[float, torch.Tensor, torch.Tensor]): loss (float): Overlap area (dot product of dens_same and dens_diff), or ``1e6`` if no valid crossover exists. dens_same (torch.Tensor): Smoothed estimated 'same' distribution, shape ``(n_bins,)``. dens_diff (torch.Tensor): Scaled intra-session distribution (estimated 'different'), shape ``(n_bins,)``. """ ## Histogram all distances and intra-session distances counts_all, _ = torch.histogram(distances, edges) counts_intra, _ = torch.histogram(distances[intra_indices], edges) ## Scale intra counts to estimate the full 'different' distribution dens_diff = counts_intra * scale_factor ## 'Same' distribution = residual, clamped non-negative, then smoothed dens_same = smoother.convolve(torch.clamp(counts_all - dens_diff, min=0)) ## Penalize if no valid crossover exists between the two distributions dens_deriv = dens_diff - dens_same dens_deriv[int(dens_diff.argmax().item()):] = 0 if not (dens_deriv < 0).any(): return 1e6, dens_same, dens_diff return float((dens_same * dens_diff).sum().item()), dens_same, dens_diff def _separate_diffSame_distributions( self, d_conj: scipy.sparse.csr_array, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, float]: """ Estimate the 'same' and 'different' distance distributions from a conjunctive distance matrix. The 'same' distribution is estimated as the residual after subtracting the scaled intra-session (known-different) distribution from the overall distribution. Delegates core histogram overlap computation to :meth:`_compute_histogram_overlap`. Args: d_conj (scipy.sparse.csr_array): Conjunctive distance matrix. Returns: (tuple): tuple containing: dens_same_crop (np.ndarray): 'Same' distribution with values below the crossover point zeroed out. dens_same (np.ndarray): Un-cropped smoothed 'same' distribution. dens_diff (np.ndarray): Scaled intra-session 'different' distribution. dens_all (np.ndarray): Raw histogram counts over all pairs. edges (np.ndarray): Bin edges used for all histograms. d_crossover (float): Distance at which the 'same' and 'different' distributions cross over. """ ## Bin edges covering the full [0, 1] distance range edges = torch.linspace(start=0, end=1, steps=self.n_bins + 1, dtype=torch.float32) ## Extract all-pairs distances and compute raw histogram dist_all = torch.as_tensor(d_conj.data, dtype=torch.float32) dens_all, _ = torch.histogram(dist_all, edges) ## Extract intra-session (known-different) distances via s_sesh_inv mask d_intra = d_conj.multiply(self.s_sesh_inv) d_intra.eliminate_zeros() if len(d_intra.data) == 0: return None, None, None, None, None, None dist_intra = torch.as_tensor(d_intra.data, dtype=torch.float32) ## Scale factor: ratio of all pairs to intra-session pairs n_all = int(dens_all.sum().item()) n_intra = dist_intra.shape[0] scale_factor = float(n_all) / max(n_intra, 1) ## Compute scaled intra (different) distribution counts_intra, _ = torch.histogram(dist_intra, edges) dens_diff = counts_intra * scale_factor ## Smoothed residual (same) — uses cached smoother in _fn_smooth dens_same = self._fn_smooth(torch.clamp(dens_all - dens_diff, min=0)) ## Locate crossover: last bin before dens_diff peak where diff > same dens_deriv = dens_diff - dens_same dens_deriv[int(dens_diff.argmax().item()):] = 0 crossover_candidates = torch.where(dens_deriv < 0)[0] if crossover_candidates.shape[0] == 0: return None, None, None, None, None, None idx_crossover = int(crossover_candidates[-1].item()) + 1 d_crossover = float(edges[idx_crossover].item()) ## Zero out 'same' distribution below the crossover point dens_same_crop = dens_same.clone() dens_same_crop[idx_crossover:] = 0 return dens_same_crop, dens_same, dens_diff, dens_all, edges, d_crossover def _extract_hdbscan_quality_metrics(self) -> Dict: """ Extract HDBSCAN quality metrics, handling differences between legacy hdbscan and fast_hdbscan backends. Legacy hdbscan (with attach_fully_connected_node) stores an extra trailing element for the fully connected node that must be stripped. fast_hdbscan does not add extra nodes but exposes additional attributes: ``_core_distances`` (per-point core distance) and ``_min_spanning_tree`` (MST edge array). Returns: (Dict): Dictionary with keys: \n * ``sample_probabilities``: list of floats, per-ROI membership strength. * ``sample_outlierScores``: list of floats (legacy) or ``None`` (fast_hdbscan). * ``sample_coreDistances``: list of floats (fast_hdbscan) or ``None`` (legacy). Per-point core distance used for mutual reachability. * ``mst_edge_weights``: list of floats (fast_hdbscan) or ``None`` (legacy). Sorted MST edge weights; useful for diagnosing cluster separation. """ def to_list_of_floats(x): return [float(i) for i in x] if x is not None else None used_fcn = getattr(self, '_fit_used_fully_connected_node', True) ## Probabilities probs = self.hdbs.probabilities_ if used_fcn: probs = probs[:-1] result = { 'sample_probabilities': to_list_of_floats(probs), } ## Outlier scores (only available in legacy hdbscan, not fast_hdbscan) if hasattr(self.hdbs, 'outlier_scores_'): scores = self.hdbs.outlier_scores_ if used_fcn: scores = scores[:-1] result['sample_outlierScores'] = to_list_of_floats(scores) else: result['sample_outlierScores'] = None ## Core distances (fast_hdbscan only) core_dists = getattr(self.hdbs, '_core_distances', None) if core_dists is not None: if used_fcn: core_dists = core_dists[:-1] result['sample_coreDistances'] = to_list_of_floats(core_dists) else: result['sample_coreDistances'] = None ## MST edge weights (fast_hdbscan only) mst = getattr(self.hdbs, '_min_spanning_tree', None) if mst is not None: ## MST is (n-1, 3) array with columns [src, dst, weight] result['mst_edge_weights'] = to_list_of_floats(np.sort(mst[:, 2])) else: result['mst_edge_weights'] = None return result
[docs] def compute_quality_metrics( self, sim_mat: Optional[object] = None, dist_mat: Optional[object] = None, labels: Optional[np.ndarray] = None, ) -> Dict: """ Computes quality metrics of the dataset. RH 2023 Args: sim_mat (Optional[object]): Similarity matrix of shape *(n_samples, n_samples)*. If ``None`` then self.sConj must exist. (Default is ``None``) dist_mat (Optional[object]): Distance matrix of shape *(n_samples, n_samples)*. If ``None`` then self.dConj must exist. (Default is ``None``) labels (Optional[np.ndarray]): Cluster labels of shape *(n_samples,)*. If ``None``, then self.labels must exist. (Default is ``None``) Returns: (Dict): quality_metrics (Dict): Quality metrics dictionary that includes: 'cluster_intra_means', 'cluster_intra_mins', 'cluster_intra_maxs', 'cluster_silhouette', 'sample_silhouette', and other metrics if available. """ if sim_mat is None: assert hasattr(self, 'sConj'), "self.sConj does not exist. Run self.find_optimal_parameters_for_pruning() first or specify sim_mat." sim_mat = self.sConj if dist_mat is None: assert hasattr(self, 'dConj'), "self.dConj does not exist. Run self.find_optimal_parameters_for_pruning() first or specify dist_mat." dist_mat = self.dConj if labels is None: assert hasattr(self, 'labels'), "self.labels does not exist. Run self.find_optimal_parameters_for_pruning() first or specify labels." labels = self.labels assert scipy.sparse.issparse(sim_mat), "sim_mat must be a scipy.sparse.csr_array." assert scipy.sparse.issparse(dist_mat), "dist_mat must be a scipy.sparse.csr_array." labels_unique, cs_intra_means, cs_intra_mins, cs_intra_maxs, cs_sil = cluster_quality_metrics( sim=sim_mat, labels=labels, ) import sklearn import sparse d_dense = sparse.COO(dist_mat.copy().tocsr()).astype(np.float16) d_dense.fill_value = (dist_mat.data.max() - dist_mat.data.min()).astype(np.float16) * 10 d_dense = d_dense.todense() np.fill_diagonal(d_dense, 0) ## Number of labels must be at least 2 if len(np.unique(labels)) < 2: warnings.warn(f"Silhouette samples calculation requires at least 2 labels. Returning None. Found {len(np.unique(labels))} labels.") rs_sil = None else: rs_sil = sklearn.metrics.silhouette_samples(X=d_dense, labels=labels, metric='precomputed') def to_list_of_floats(x): return [float(i) for i in x] if x is not None else None ## Extract HDBSCAN-specific metrics hdbscan_metrics = self._extract_hdbscan_quality_metrics() if hasattr(self, 'hdbs') else None self.quality_metrics = util.JSON_Dict({ 'cluster_labels_unique': to_list_of_floats(labels_unique), 'cluster_intra_means': to_list_of_floats(cs_intra_means), 'cluster_intra_mins': to_list_of_floats(cs_intra_mins), 'cluster_intra_maxs': to_list_of_floats(cs_intra_maxs), 'cluster_silhouette': to_list_of_floats(cs_sil), 'sample_silhouette': to_list_of_floats(rs_sil), 'sample_probabilities': hdbscan_metrics['sample_probabilities'] if hdbscan_metrics else None, 'hdbscan': hdbscan_metrics, 'sequentialHungarian': { 'performance_recall': float(self.seqHung_performance['recall']), 'performance_precision': float(self.seqHung_performance['precision']), 'performance_f1': float(self.seqHung_performance['f1_score']), 'performance_accuracy': float(self.seqHung_performance['accuracy']), } if hasattr(self, 'seqHung_performance') else None, }) return self.quality_metrics
#################################################################### #################################################################### ## ## LEGACY / COMPARISON METHODS ## ## The methods below are retained for backward compatibility and ## benchmarking. They are NOT used by the default pipeline. ## ## The recommended approach is find_optimal_parameters_for_pruning() ## which calls _find_optimal_parameters_DE(freeze_sigmoid=True). ## #################################################################### #################################################################### def _find_optimal_parameters_for_pruning_optuna( self, kwargs_findParameters: Dict[str, Union[int, float, bool]] = { 'n_patience': 100, 'tol_frac': 0.05, 'max_trials': 350, 'max_duration': 60*10, 'value_stop': 0.0, }, bounds_findParameters: Dict[str, List[float]] = { 'power_nn': [0.0, 2.], 'power_swt': [0.0, 2.], 'p_norm': [-5, -0.1], 'sig_nn_kwargs_mu': [0., 1.0], 'sig_nn_kwargs_b': [0.1, 1.5], 'sig_swt_kwargs_mu': [0., 1.0], 'sig_swt_kwargs_b': [0.1, 1.5], }, n_jobs_findParameters: int = -1, n_bins: Optional[int] = None, smoothing_window_bins: Optional[int] = None, seed=None, ) -> Dict: """ **LEGACY** — Original Optuna TPE optimizer (7-param). Superseded by :meth:`find_optimal_parameters_for_pruning` which uses NB calibration + freeze-sigmoid DE and achieves ~3.6x better separation quality. Requires ``optuna`` to be installed. RH 2023 Args: kwargs_findParameters (Dict[str, Union[int, float, bool]]): Keyword arguments for the Convergence_checker class __init__. bounds_findParameters (Dict[str, List[float]]): Bounds for the 7 parameters to be optimized. n_jobs_findParameters (int): Number of parallel Optuna jobs. ``-1`` = all cores. n_bins (Optional[int]): Overwrites ``n_bins`` from ``__init__``. smoothing_window_bins (Optional[int]): Overwrites ``smoothing_window_bins`` from ``__init__``. seed (Optional[int]): Random seed. Returns: (Dict): kwargs_makeConjunctiveDistanceMatrix_best (Dict): Optimal parameters for :meth:`make_conjunctive_distance_matrix`. """ import optuna self.params['find_optimal_parameters_for_pruning'] = self._locals_to_params( locals_dict=locals(), keys=[ 'kwargs_findParameters', 'bounds_findParameters', 'n_jobs_findParameters', 'n_bins', 'smoothing_window_bins', 'seed', ], ) self.n_bins = self.n_bins if n_bins is None else n_bins self.smooth_window = self.smooth_window if smoothing_window_bins is None else smoothing_window_bins self.bounds_findParameters = bounds_findParameters self._seed = seed np.random.seed(self._seed) print('Finding mixing parameters using automated hyperparameter tuning...') if self._verbose else None optuna.logging.set_verbosity(optuna.logging.WARNING) self.checker = helpers.Convergence_checker_optuna(verbose=self._verbose >= 2, **kwargs_findParameters) prog_bar = helpers.OptunaProgressBar( n_trials=kwargs_findParameters['max_trials'], mininterval=5.0, ) self.study = optuna.create_study(direction='minimize', sampler=optuna.samplers.TPESampler( n_startup_trials=kwargs_findParameters['n_patience'] // 2, seed=self._seed, )) self.study.optimize( func=self._objectiveFn_distSameMagnitude, n_jobs=n_jobs_findParameters, callbacks=[self.checker.check, prog_bar], n_trials=kwargs_findParameters['max_trials'], show_progress_bar=False, ) self.best_params = self.study.best_params.copy() [self.best_params.pop(p) for p in [ 'sig_nn_kwargs_mu', 'sig_nn_kwargs_b', 'sig_swt_kwargs_mu', 'sig_swt_kwargs_b', ] if p in self.best_params.keys()] self.best_params['sig_nn_kwargs'] = { 'mu': self.study.best_params['sig_nn_kwargs_mu'], 'b': self.study.best_params['sig_nn_kwargs_b'], } self.best_params['sig_swt_kwargs'] = { 'mu': self.study.best_params['sig_swt_kwargs_mu'], 'b': self.study.best_params['sig_swt_kwargs_b'], } self.kwargs_makeConjunctiveDistanceMatrix_best = { 'power_sf': None, 'power_nn': None, 'power_swt': None, 'p_norm': None, 'sig_sf_kwargs': None, 'sig_nn_kwargs': None, 'sig_swt_kwargs': None, } self.kwargs_makeConjunctiveDistanceMatrix_best.update(self.best_params) print(f'Completed automatic parameter search. Best value found: {self.study.best_value} with parameters {self.best_params}') if self._verbose else None return self.kwargs_makeConjunctiveDistanceMatrix_best def _objectiveFn_distSameMagnitude( self, trial: object, ) -> float: """ **LEGACY** — Optuna objective function for histogram overlap loss. Used by :meth:`_find_optimal_parameters_for_pruning_optuna`. RH 2023 """ power_NN = trial.suggest_float('power_nn', *self.bounds_findParameters['power_nn'], log=False) power_SWT = trial.suggest_float('power_swt', *self.bounds_findParameters['power_swt'], log=False) p_norm = trial.suggest_float('p_norm', *self.bounds_findParameters['p_norm'], log=False) sig_NN_kwargs = { 'mu': trial.suggest_float('sig_nn_kwargs_mu', *self.bounds_findParameters['sig_nn_kwargs_mu'], log=False), 'b': trial.suggest_float('sig_nn_kwargs_b', *self.bounds_findParameters['sig_nn_kwargs_b'], log=False), } sig_SWT_kwargs = { 'mu': trial.suggest_float('sig_swt_kwargs_mu', *self.bounds_findParameters['sig_swt_kwargs_mu'], log=False), 'b': trial.suggest_float('sig_swt_kwargs_b', *self.bounds_findParameters['sig_swt_kwargs_b'], log=False), } mixing_params = { 'power_sf': 1.0, 'power_nn': power_NN, 'power_swt': power_SWT, 'p_norm': p_norm, 'sig_nn_kwargs': sig_NN_kwargs, 'sig_swt_kwargs': sig_SWT_kwargs, } dConj, sConj, activated_data = self.make_conjunctive_distance_matrix( similarities=self.similarities, mixing_params=mixing_params, ) dens_same_crop, dens_same, dens_diff, dens_all, edges, d_crossover = self._separate_diffSame_distributions(dConj) if dens_same_crop is None: return 0 return (dens_same * dens_diff).sum().item()
## Lazy-compiled numba kernel for weighted Jaccard. Compiled on first call ## to avoid numba import overhead (subprocess spam on macOS) at module load. _weighted_jaccard_csr_kernel = None def _get_weighted_jaccard_kernel(): """ Lazily compile and cache the numba kernel for weighted Jaccard. Returns the compiled ``_weighted_jaccard_csr_kernel`` function. """ global _weighted_jaccard_csr_kernel if _weighted_jaccard_csr_kernel is not None: return _weighted_jaccard_csr_kernel import numba @numba.njit(cache=True) def kernel(indptr, indices, data, out_data): """ Numba kernel: weighted Jaccard (Ruzicka) similarity for each edge in a symmetric CSR similarity matrix. For each stored edge (i, j), computes: J_w(i,j) = Σ_k min(s_ik, s_jk) / Σ_k max(s_ik, s_jk) where the sum runs over all k in the union of neighbors of i and j, **excluding k = i and k = j** (the direct edge endpoints). Assumptions: - Symmetric matrix (s_ij = s_ji for all stored pairs). - Zero diagonal (no self-loops stored). - Indices within each row are sorted ascending (CSR standard). - All data values are non-negative. """ n = len(indptr) - 1 for i in range(n): for ptr_ij in range(indptr[i], indptr[i + 1]): j = indices[ptr_ij] ## Merge-scan sorted neighbor lists of rows i and j pi = indptr[i] pi_end = indptr[i + 1] pj = indptr[j] pj_end = indptr[j + 1] sum_min = 0.0 sum_max = 0.0 while pi < pi_end and pj < pj_end: ki = indices[pi] kj = indices[pj] if ki == kj: ## Shared neighbor — skip if it is one of the edge endpoints if ki != i and ki != j: vi = data[pi] vj = data[pj] if vi < vj: sum_min += vi sum_max += vj else: sum_min += vj sum_max += vi pi += 1 pj += 1 elif ki < kj: ## Neighbor of i only — skip if it is j if ki != j: sum_max += data[pi] pi += 1 else: ## Neighbor of j only — skip if it is i if kj != i: sum_max += data[pj] pj += 1 ## Drain remaining from row i while pi < pi_end: if indices[pi] != j: sum_max += data[pi] pi += 1 ## Drain remaining from row j while pj < pj_end: if indices[pj] != i: sum_max += data[pj] pj += 1 if sum_max > 0.0: out_data[ptr_ij] = sum_min / sum_max else: out_data[ptr_ij] = 0.0 _weighted_jaccard_csr_kernel = kernel return _weighted_jaccard_csr_kernel
[docs] def weighted_jaccard_similarity( s: scipy.sparse.csr_array, ) -> scipy.sparse.csr_array: """ Compute weighted Jaccard (Ruzicka) similarity from a sparse similarity graph. For each pair (i, j) in the input, replaces the direct similarity ``s_ij`` with a neighborhood-based structural similarity: .. math:: J_w(i, j) = \\frac{\\sum_{k \\neq i,j} \\min(s_{ik}, s_{jk})} {\\sum_{k \\neq i,j} \\max(s_{ik}, s_{jk})} where the sum is over all ``k`` in the union of non-zero neighbors of ``i`` and ``j``, excluding ``k = i`` and ``k = j``. This is a second-order similarity: two nodes score high if they share many strong connections to the same neighbors. Acts as a denoising step that amplifies community structure. Used in SNN clustering (Seurat), Louvain/Leiden, and UMAP local connectivity. Uses a numba-accelerated merge-scan over sorted CSR rows. Complexity: O(nnz * avg_degree). Args: s (scipy.sparse.csr_array): Sparse symmetric similarity matrix. Shape: *(n, n)*. Must have non-negative values, zero diagonal (not stored), and sorted indices within each row (standard CSR convention). Returns: (scipy.sparse.csr_array): s_jaccard (scipy.sparse.csr_array): Weighted Jaccard similarity matrix. Same sparsity pattern as input. Values in [0, 1]. Symmetric if input is symmetric. Zero diagonal. """ assert isinstance(s, scipy.sparse.csr_array), ( f"Expected scipy.sparse.csr_array, got {type(s)}" ) assert s.shape[0] == s.shape[1], "Matrix must be square." ## Ensure sorted indices (required for merge-scan) s_sorted = s if not s.has_sorted_indices: s_sorted = s.copy() s_sorted.sort_indices() data_f64 = s_sorted.data.astype(np.float64) out_data = np.empty(len(data_f64), dtype=np.float64) kernel = _get_weighted_jaccard_kernel() kernel( indptr=s_sorted.indptr, indices=s_sorted.indices, data=data_f64, out_data=out_data, ) s_jaccard = s_sorted.copy() s_jaccard.data = out_data.astype(s.data.dtype) return s_jaccard
## Lazy-compiled numba kernel for noise rescue Kruskal. Same pattern as ## weighted Jaccard: compiled on first call to avoid numba import overhead. _noise_rescue_kruskal_kernel = None def _get_noise_rescue_kernel(): """ Lazily compile and cache the numba kernel for noise rescue Kruskal. Returns the compiled ``_noise_rescue_kruskal_kernel`` function. """ global _noise_rescue_kruskal_kernel if _noise_rescue_kruskal_kernel is not None: return _noise_rescue_kruskal_kernel import numba @numba.njit(cache=True) def kernel( indptr, ## int32[:], CSR row pointers indices, ## int32[:], CSR column indices data, ## float64[:], CSR edge distances labels, ## int64[:], Phase 1 labels (-1 = noise) group_labels, ## int32[:], session index per ROI n_groups, ## int, number of sessions d_cutoff, ## float64, maximum distance for edge acceptance ): """ Kruskal-style noise rescue with DSU + bitmask cannot-link constraints. Given Phase 1 cluster labels, processes edges from the distance graph where at least one endpoint is noise (label == -1), sorted by distance ascending. Merges are accepted if: 1. d <= d_cutoff 2. endpoints are in different DSU components 3. no session-conflict (bitmask AND == 0) The DSU is pre-initialized with Phase 1 clusters: all members of the same cluster are pre-merged and their component bitmask is the OR of their session bits. After processing, labels are extracted: - Components containing Phase 1 cluster members inherit that label - Components of only ex-noise points with size >= 2 get new labels (starting from max_existing_label + 1) - Remaining singletons stay -1 Returns new_labels: int64[:] of length n_rois. """ n = len(indptr) - 1 ## n_rois ## -- DSU arrays -- parent = np.arange(n, dtype=np.int64) rank = np.zeros(n, dtype=np.int32) ## -- Bitmask arrays: comp_mask[root, word] -- n_words = (n_groups + 63) // 64 comp_mask = np.zeros((n, n_words), dtype=np.uint64) ## Initialize bitmasks from group_labels for i in range(n): g = group_labels[i] if g >= 0: w = g // 64 b = np.uint64(g % 64) comp_mask[i, w] = comp_mask[i, w] | (np.uint64(1) << b) ## -- Pre-merge Phase 1 clusters -- ## For each cluster label > -1, union all its members. ## First pass: find max label to size the bookkeeping. max_label = -1 for i in range(n): if labels[i] > max_label: max_label = labels[i] if max_label >= 0: ## For each cluster, track first member seen as the anchor cluster_anchor = np.full(max_label + 1, -1, dtype=np.int64) for i in range(n): lbl = labels[i] if lbl < 0: continue if cluster_anchor[lbl] < 0: cluster_anchor[lbl] = i else: ## Union i with the anchor anchor = cluster_anchor[lbl] ## Find root of anchor root_a = anchor while parent[root_a] != root_a: root_a = parent[root_a] curr = anchor while curr != root_a: nxt = parent[curr] parent[curr] = root_a curr = nxt ## Find root of i root_i = i while parent[root_i] != root_i: root_i = parent[root_i] curr = i while curr != root_i: nxt = parent[curr] parent[curr] = root_i curr = nxt if root_a != root_i: ## Union by rank if rank[root_a] > rank[root_i]: new_root = root_a old_root = root_i elif rank[root_a] < rank[root_i]: new_root = root_i old_root = root_a else: new_root = root_a old_root = root_i rank[new_root] += 1 parent[old_root] = new_root ## Merge bitmasks for ww in range(n_words): comp_mask[new_root, ww] = ( comp_mask[new_root, ww] | comp_mask[old_root, ww] ) ## -- Collect edges where at least one endpoint is noise -- ## Count first n_edges = 0 for i in range(n): for ptr in range(indptr[i], indptr[i + 1]): j = indices[ptr] if j <= i: continue ## upper triangle only (avoid duplicates) if labels[i] == -1 or labels[j] == -1: n_edges += 1 ## Allocate and fill edge_u = np.empty(n_edges, dtype=np.int64) edge_v = np.empty(n_edges, dtype=np.int64) edge_d = np.empty(n_edges, dtype=np.float64) idx = 0 for i in range(n): for ptr in range(indptr[i], indptr[i + 1]): j = indices[ptr] if j <= i: continue if labels[i] == -1 or labels[j] == -1: edge_u[idx] = i edge_v[idx] = j edge_d[idx] = data[ptr] idx += 1 ## -- Sort edges by distance ascending -- sort_order = np.argsort(edge_d) ## -- Kruskal traversal -- for idx in range(len(sort_order)): eidx = sort_order[idx] d_val = edge_d[eidx] ## Stop if beyond cutoff if d_val > d_cutoff: break u = edge_u[eidx] v = edge_v[eidx] ## Find root of u with path compression root_u = u while parent[root_u] != root_u: root_u = parent[root_u] curr = u while curr != root_u: nxt = parent[curr] parent[curr] = root_u curr = nxt ## Find root of v with path compression root_v = v while parent[root_v] != root_v: root_v = parent[root_v] curr = v while curr != root_v: nxt = parent[curr] parent[curr] = root_v curr = nxt ## Already same component if root_u == root_v: continue ## Conflict check: bitmask AND conflict = False for ww in range(n_words): if (comp_mask[root_u, ww] & comp_mask[root_v, ww]) != np.uint64(0): conflict = True break if conflict: continue ## Union by rank if rank[root_u] > rank[root_v]: new_root = root_u old_root = root_v elif rank[root_u] < rank[root_v]: new_root = root_v old_root = root_u else: new_root = root_u old_root = root_v rank[new_root] += 1 parent[old_root] = new_root ## Merge bitmasks for ww in range(n_words): comp_mask[new_root, ww] = ( comp_mask[new_root, ww] | comp_mask[old_root, ww] ) ## -- Extract labels from DSU -- ## Find root for every node (with path compression) for i in range(n): root_i = i while parent[root_i] != root_i: root_i = parent[root_i] curr = i while curr != root_i: nxt = parent[curr] parent[curr] = root_i curr = nxt ## Map: root → Phase 1 label (if component has one) ## Also count component sizes root_to_label = np.full(n, -1, dtype=np.int64) comp_size = np.zeros(n, dtype=np.int64) for i in range(n): root_i = parent[i] comp_size[root_i] += 1 if labels[i] >= 0: root_to_label[root_i] = labels[i] ## Assign new labels for noise-only components of size >= 2 next_label = max_label + 1 if max_label >= 0 else 0 for i in range(n): if parent[i] == i and root_to_label[i] < 0 and comp_size[i] >= 2: root_to_label[i] = next_label next_label += 1 ## Build output labels new_labels = np.empty(n, dtype=np.int64) for i in range(n): new_labels[i] = root_to_label[parent[i]] return new_labels _noise_rescue_kruskal_kernel = kernel return _noise_rescue_kruskal_kernel
[docs] def noise_rescue_kruskal( d_conj: scipy.sparse.csr_array, labels: np.ndarray, group_labels: np.ndarray, n_groups: int, d_cutoff: float, ) -> np.ndarray: """ Assign HDBSCAN noise points to nearby clusters (or nucleate new small clusters) using a Kruskal-style sorted-edge traversal with DSU and bitmask cannot-link constraints. This is Phase 2 of a two-phase clustering strategy: * **Phase 1**: HDBSCAN with ``min_samples > 1`` produces robust core clusters but marks many ROIs as noise (``label == -1``). * **Phase 2** (this function): Processes edges from the distance graph where at least one endpoint is noise, sorted by distance. Merges are accepted only if no session conflict arises (checked via ``uint64`` bitmask per DSU component). This can either assign noise points to existing Phase 1 clusters or nucleate new clusters when 2+ noise points are mutual neighbors within ``d_cutoff``. The algorithm uses the same DSU + bitmask pattern as ``fast_hdbscan``'s ``_kruskal_core_group_constrained``, but with the DSU **pre-initialized** from Phase 1's pre-formed clusters. Args: d_conj (scipy.sparse.csr_array): Sparse distance matrix (inter-session masked). Shape: *(n_rois, n_rois)*. Must be symmetric with sorted indices. labels (np.ndarray): Phase 1 cluster labels. Shape: *(n_rois,)*. ``-1`` = noise. group_labels (np.ndarray): Session index per ROI (``int32``). Shape: *(n_rois,)*. n_groups (int): Number of distinct sessions (groups). d_cutoff (float): Maximum edge distance to accept. Edges with ``d > d_cutoff`` are ignored. Returns: (np.ndarray): new_labels (np.ndarray): Updated cluster labels. Shape: *(n_rois,)*. Noise points that were rescued get their assigned cluster label; new clusters of 2+ ex-noise points get fresh label IDs; remaining singletons stay ``-1``. """ assert isinstance(d_conj, scipy.sparse.csr_array), ( f"Expected scipy.sparse.csr_array, got {type(d_conj)}" ) assert d_conj.shape[0] == d_conj.shape[1], "Distance matrix must be square." n = d_conj.shape[0] assert labels.shape == (n,), f"labels shape {labels.shape} != ({n},)" assert group_labels.shape == (n,), f"group_labels shape {group_labels.shape} != ({n},)" ## Ensure sorted indices for consistent edge iteration d_sorted = d_conj if not d_conj.has_sorted_indices: d_sorted = d_conj.copy() d_sorted.sort_indices() kernel = _get_noise_rescue_kernel() new_labels = kernel( indptr=d_sorted.indptr.astype(np.int32), indices=d_sorted.indices.astype(np.int32), data=d_sorted.data.astype(np.float64), labels=labels.astype(np.int64), group_labels=group_labels.astype(np.int32), n_groups=int(n_groups), d_cutoff=float(d_cutoff), ) return new_labels
[docs] def attach_fully_connected_node( d: object, dist_fullyConnectedNode: Optional[float] = None, n_nodes: int = 1, ) -> object: """ Appends a single node to a sparse distance graph that is weakly connected to all nodes. Args: d (object): Sparse graph with multiple components. Refer to scipy.sparse.csgraph.connected_components for details. dist_fullyConnectedNode (Optional[float]): Value used for the connection strength to all other nodes. This value will be appended as elements in a new row and column at the ends of the 'd' matrix. If ``None``, then the value will be set to 1000 times the difference between the maximum and minimum values in 'd'. (Default is ``None``) n_nodes (int): Number of nodes to append to the graph. (Default is *1*) Returns: (object): d2 (object): Sparse graph with only one component. """ if dist_fullyConnectedNode is None: dist_fullyConnectedNode = (d.max() - d.min()) * 1000 d2 = d.copy() d2 = scipy.sparse.vstack((d2, np.ones((n_nodes,d2.shape[1]), dtype=d.dtype)*dist_fullyConnectedNode)) d2 = scipy.sparse.hstack((d2, np.ones((d2.shape[0],n_nodes), dtype=d.dtype)*dist_fullyConnectedNode)) return d2.tocsr()
[docs] def score_labels( labels_test: np.ndarray, labels_true: np.ndarray, ignore_negOne: bool = False, thresh_perfect: float = 0.9999999999, ) -> Dict[str, Union[float, Tuple[int, int]]]: """ Computes the score of the clustering by finding the best match using the linear sum assignment problem. The score is bounded between 0 and 1. Note: The score is not symmetric if the number of true and test labels are not the same. I.e., switching ``labels_test`` and ``labels_true`` can lead to different scores. This is because we are scoring how well each true set is matched by an optimally assigned test set. RH 2022 Args: labels_test (np.ndarray): Labels of the test clusters/sets. (shape: *(n,)*) labels_true (np.ndarray): Labels of the true clusters/sets. (shape: *(n,)*) ignore_negOne (bool): Whether to ignore ``-1`` values in the labels. If set to ``True``, ``-1`` values will be ignored in the computation. (Default is ``False``) thresh_perfect (float): Threshold for perfect match. Mostly used for numerical stability. (Default is *0.9999999999*) Returns: (dict): dictionary containing: score_weighted_partial (float): Average correlation between the best matched sets of true and test labels, weighted by the number of elements in each true set. score_weighted_perfect (float): Fraction of perfect matches, weighted by the number of elements in each true set. score_unweighted_partial (float): Average correlation between the best matched sets of true and test labels. score_unweighted_perfect (float): Fraction of perfect matches. adj_rand_score (float): Adjusted Rand score of the labels. adj_mutual_info_score (float): Adjusted mutual info score of the labels. None if ``compute_mutual_info`` is ``False``. ignore_negOne (bool): Whether ``-1`` values were ignored in the labels. idx_hungarian (Tuple[int, int]): 'Hungarian Indices'. Indices of the best matched sets. """ assert len(labels_test) == len(labels_true), 'RH ERROR: labels_test and labels_true must be the same length.' labels_test = np.array(labels_test, dtype=int) labels_true = np.array(labels_true, dtype=int) ## convert labels to boolean uniques_test, uniques_true = np.unique(labels_test), np.unique(labels_true) bool_test = np.stack([labels_test==l for l in uniques_test], axis=0).astype(np.float32) bool_true = np.stack([labels_true==l for l in uniques_true], axis=0).astype(np.float32) ## Hungarian matching score if ignore_negOne: bool_test[uniques_test == -1, :] = 0.0 bool_true[uniques_true == -1, :] = 0.0 if bool_test.shape[0] < bool_true.shape[0]: bool_test = np.concatenate((bool_test, np.zeros((bool_true.shape[0] - bool_test.shape[0], bool_true.shape[1])))) ## compute confusion / correlation matrix with np.errstate(divide='ignore', invalid='ignore'): cc = np.nan_to_num((bool_true @ bool_test.T) / (bool_true.sum(axis=1)[:, None]), nan=0.0, posinf=0.0, neginf=0.0) ## normalize by the number of elements in each set ## find hungarian assignment matching indices hi = scipy.optimize.linear_sum_assignment(cost_matrix=cc, maximize=True) ## extract correlation scores of matches cc_matched = cc[hi[0], hi[1]] label_weights = bool_true.sum(axis=1)[hi[0]] ## reweighting vector is the number of elements in each true set label_weights_norm = label_weights / label_weights.sum() ## normalize the weights hungarian_match_score_weighted_partial = cc_matched @ label_weights_norm hungarian_match_score_unweighted_partial = cc_matched.mean() hungarian_match_score_weighted_perfect = (cc_matched > thresh_perfect).astype(float) @ label_weights_norm hungarian_match_score_unweighted_perfect = (cc_matched > thresh_perfect).mean() ## SKLEARN METRICS ### First change all -1 values to a unique value that is not present in the labels uniques = np.unique(np.concatenate((labels_true, labels_test))) ### Make a bunch of values greater than the maximum value in the labels n_minusOne_true = np.sum(labels_true == -1) n_minusOne_test = np.sum(labels_test == -1) labels_true_sk, labels_test_sk = labels_true.copy(), labels_test.copy() labels_true_sk[labels_true == -1] = uniques.max() + np.arange(1, n_minusOne_true + 1) labels_test_sk[labels_test == -1] = uniques.max() + np.arange(n_minusOne_true + 1, n_minusOne_true + n_minusOne_test + 1) ## compute adjusted rand score score_rand = sklearn.metrics.adjusted_rand_score(labels_true=labels_true_sk, labels_pred=labels_test_sk) ## compute fowlkes mallows score score_fowlkes_mallows = sklearn.metrics.fowlkes_mallows_score(labels_true=labels_true_sk, labels_pred=labels_test_sk) ## compute adjusted mutual info score score_mutual_info = sklearn.metrics.adjusted_mutual_info_score(labels_true=labels_true_sk, labels_pred=labels_test_sk) ## compute homogeneity, completeness, and v-measure homogeneity, completeness, v_measure = sklearn.metrics.homogeneity_completeness_v_measure(labels_true=labels_true_sk, labels_pred=labels_test_sk, beta=1.0) ## compute pair confusion matrix pair_confusion = sklearn.metrics.cluster.pair_confusion_matrix(labels_true=labels_true_sk, labels_pred=labels_test_sk).tolist() if pair_confusion is not None: p = np.array(pair_confusion) TP, TN, FP, FN = p[1, 1], p[0, 0], p[0, 1], p[1, 0] pc_accuracy = (TP + TN) / (TP + TN + FP + FN) N_all = TN + FP P_all = TP + FN pc_precision = TP / (TP + FP) if (TP + FP) > 0 else 0 pc_recall = TP / (TP + FN) if (TP + FN) > 0 else 0 pc_f1 = (2 * pc_precision * pc_recall) / (pc_precision + pc_recall) if (pc_precision + pc_recall) > 0 else 0 pc_accuracy_norm = (TP/P_all + TN/N_all) / (TP/P_all + TN/N_all + FP/N_all + FN/P_all) else: pc_accuracy = None pc_accuracy_norm = None out = { 'hungarian_score_weighted_partial': float(hungarian_match_score_weighted_partial), 'hungarian_score_weighted_perfect': float(hungarian_match_score_weighted_perfect), 'hungarian_score_unweighted_partial': float(hungarian_match_score_unweighted_partial), 'hungarian_score_unweighted_perfect': float(hungarian_match_score_unweighted_perfect), 'adj_rand_score': float(score_rand), 'fowlkes_mallows_score': float(score_fowlkes_mallows), 'adj_mutual_info_score': float(score_mutual_info), 'homogeneity_score': float(homogeneity), 'completeness_score': float(completeness), 'v_measure_score': float(v_measure), 'pair_confusion_matrix': pair_confusion, 'pair_confusion_accuracy_score': float(pc_accuracy), 'pair_confusion_accuracy_norm_score': float(pc_accuracy_norm), 'pair_confusion_precision_score': float(pc_precision) if pc_precision is not None else None, 'pair_confusion_recall_score': float(pc_recall) if pc_recall is not None else None, 'pair_confusion_f1_score': float(pc_f1) if pc_f1 is not None else None, 'labels_test': labels_test.tolist(), 'ignore_negOne': ignore_negOne, 'idx_hungarian': hi, } return out
[docs] def cluster_quality_metrics( sim: Union[np.ndarray, scipy.sparse.csr_array], labels: np.ndarray, ) -> Tuple: """ Computes the cluster quality metrics for a clustering solution including intra-cluster mean, minimum, maximum similarity, and cluster silhouette score. RH 2023 Args: sim (Union[np.ndarray, scipy.sparse.csr_array]): Similarity matrix. (shape: *(n_roi, n_roi)*) It can be obtained using `_, sConj, _,_,_,_ = clusterer.make_conjunctive_similarity_matrix()`. labels (np.ndarray): Cluster labels. (shape: *(n_roi,)*) Returns: (tuple): tuple containing: cs_intra_means (np.ndarray): Intra-cluster mean similarity. (shape: *(n_clusters,)*) cs_intra_mins (np.ndarray): Intra-cluster minimum similarity. (shape: *(n_clusters,)*) cs_intra_maxs (np.ndarray): Intra-cluster maximum similarity. (shape: *(n_clusters,)*) cs_sil (np.ndarray): Cluster silhouette score. (shape: *(n_clusters,)*) Describes intra_mean - inter_max_of_maxes """ import sparse labels_unique, cs_mean, cs_max, cs_min = helpers.compute_cluster_similarity_matrices(sim, labels, verbose=True) fn_sil_score = lambda intra, inter: (intra - inter) / np.maximum(intra, inter) eye_inv = 1 - sparse.eye(cs_max.shape[0]) cs_intra_means = cs_mean.diagonal() cs_inter_maxOfMaxs = (eye_inv * cs_max).max(0) cs_sil = fn_sil_score(cs_intra_means, cs_inter_maxOfMaxs) cs_intra_mins = cs_min.diagonal() cs_intra_maxs = cs_max.diagonal() return labels_unique, cs_intra_means, cs_intra_mins, cs_intra_maxs, cs_sil
[docs] def make_label_variants( labels: np.ndarray, n_roi_bySession: np.ndarray, ) -> Tuple: """ Creates convenient variants of label arrays. RH 2023 Args: labels (np.ndarray): Cluster integer labels. (shape: *(n_roi,)*) n_roi_bySession (np.ndarray): Number of ROIs in each session. Returns: (tuple): tuple containing: labels_squeezed (np.ndarray): Cluster labels squeezed into a continuous range starting from 0. labels_bySession (List[np.ndarray]): List of label arrays split by session. labels_bool (scipy.sparse.csr_array): Sparse boolean matrix representation of labels. labels_bool_bySession (List[scipy.sparse.csr_array]): List of sparse boolean matrix representations of labels split by session. labels_dict (Dict[int, np.ndarray]): Dictionary mapping unique labels to their locations in the labels array. """ import scipy.sparse ## assert that labels is a 1D np.array or list of numbers if isinstance(labels, list): labels = np.array(labels, dtype=np.int64) elif isinstance(labels, np.ndarray): labels = labels.astype(np.int64) else: raise TypeError('RH ERROR: labels must be a 1D np.array or list of numbers.') assert labels.ndim == 1, 'RH ERROR: labels must be a 1D np.array or list of numbers.' ## assert that n_roi_bySession is a 1D np.array or list of numbers if isinstance(n_roi_bySession, list): n_roi_bySession = np.array(n_roi_bySession, dtype=np.int64) elif isinstance(n_roi_bySession, np.ndarray): n_roi_bySession = n_roi_bySession.astype(np.int64) else: raise TypeError('RH ERROR: n_roi_bySession must be a 1D np.array or list of numbers.') assert n_roi_bySession.ndim == 1, 'RH ERROR: n_roi_bySession must be a 1D np.array or list of numbers.' ## assert that the number of labels adds up to the number of ROIs n_roi_total = n_roi_bySession.sum() n_roi_cumsum = np.concatenate([[0], n_roi_bySession.cumsum()]) assert labels.shape[0] == n_roi_bySession.sum(), 'RH ERROR: the number of labels must add up to the number of ROIs.' ## make session_bool session_bool = util.make_session_bool(n_roi_bySession) ## make variants labels_squeezed = helpers.squeeze_integers(labels) labels_bySession = [labels_squeezed[idx] for idx in session_bool.T] labels_bool = scipy.sparse.vstack([scipy.sparse.csr_array(labels_squeezed==u) for u in np.sort(np.unique(labels_squeezed))]).T.tocsr() labels_bool_bySession = [labels_bool[idx] for idx in session_bool.T] labels_dict = {u: np.where(labels_squeezed==u)[0] for u in np.unique(labels_squeezed)} ## testing assert np.allclose(np.concatenate(labels_bySession), labels_squeezed) assert np.allclose(labels_bool.nonzero()[1] - 1, labels_squeezed) assert np.all([np.allclose(np.where(labels_squeezed==u)[0], ldu) for u, ldu in labels_dict.items()]) ## Convert everything to native python types for JSON compatibility labels_squeezed = util.JSON_List([int(u) for u in labels_squeezed]) labels_bySession = util.JSON_List([[int(u) for u in l] for l in labels_bySession]) labels_dict = util.JSON_Dict({str(k): [int(v_i) for v_i in v] for k, v in labels_dict.items()}) ## Make keys strings for JSON compatibility return labels_squeezed, labels_bySession, labels_bool, labels_bool_bySession, labels_dict
[docs] def plot_quality_metrics(quality_metrics: dict, labels: Union[np.ndarray, list], n_sessions: int) -> None: fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15,7)) axs[0,0].hist(quality_metrics['cluster_silhouette'], 50); axs[0,0].set_xlabel('cluster_silhouette'); axs[0,0].set_ylabel('cluster counts'); axs[0,1].hist(quality_metrics['cluster_intra_means'], 50); axs[0,1].set_xlabel('cluster_intra_means'); axs[0,1].set_ylabel('cluster counts'); axs[1,0].hist(quality_metrics['sample_silhouette'], 50); axs[1,0].set_xlabel('sample_silhouette score'); axs[1,0].set_ylabel('roi sample counts'); u, c = np.unique((v:=np.array(labels))[v!=-1], return_counts=True) n_sesh = np.bincount(c) axs[1,1].bar(np.arange(len(n_sesh)), n_sesh); axs[1,1].set_xlabel('n_sessions') axs[1,1].set_ylabel('cluster counts'); # Make the title include the number of excluded (label==-1) ROIs fig.suptitle(f'Quality metrics n_excluded: {np.sum(labels==-1)}, n_included: {np.sum(labels!=-1)}, n_total: {len(labels)}, n_clusters: {len(np.unique(labels[labels!=-1]))}, n_sessions: {n_sessions}') return fig, axs