import warnings
from typing import Union, Tuple, List, Dict, Optional, Any
import numpy as np
import scipy
import scipy.optimize
import scipy.sparse
import sklearn
import sklearn.isotonic
import matplotlib.pyplot as plt
import torch
from tqdm.auto import tqdm
from .. import helpers, util
from .similarity_graph import SimilarityMetric
[docs]
class Clusterer(util.ROICaT_Module):
"""
Class for clustering algorithms. Performs:
* Optimal mixing and pruning of similarity matrices:
* self.find_optimal_parameters_for_pruning()
* self.make_pruned_similarity_graphs()
* Clustering:
* self.fit(): Which uses a modified HDBSCAN
* self.fit_sequentialHungarian: Which uses a method similar to
CaImAn's clustering method.
* Quality control:
* self.compute_cluster_quality_metrics()
Initialization ingests and stores similarity matrices. RH 2023 / 2025
Args:
similarities (Dict[str, scipy.sparse.csr_array]):
Dict mapping metric name to sparse similarity matrix. All
matrices must share the same nonzero pattern (same nnz).
Example: ``{'sf': csr_array, 'nn': csr_array, 'swt': csr_array}``.
metric_configs (List[SimilarityMetric]):
List of ``SimilarityMetric`` dataclass instances describing each
metric's optimization behavior (sparsity source, sigmoid, power).
s_sesh (scipy.sparse.csr_array):
Inter-session mask. Shape: *(n_rois, n_rois)*. Boolean, with 1s
where the two ROIs are from different sessions.
n_bins (Optional[int]):
Number of bins to use for the pairwise similarity distribution. If
``None``, then a heuristic is used to estimate the value based on
the number of nonzero pairs. (Default is ``None``)
smoothing_window_bins (Optional[int]):
Number of bins to use when smoothing the distribution. If
``None``, then a heuristic is used. (Default is ``None``)
session_bool (Optional[np.ndarray]):
Boolean array indicating which ROIs belong to which session.
Shape: *(n_rois, n_sessions)*. (Default is ``None``)
verbose (bool):
Specifies whether to print out information about the clustering
process. (Default is ``True``)
Attributes:
similarities (Dict[str, scipy.sparse.csr_array]):
Dict of similarity matrices keyed by metric name.
s_sesh (scipy.sparse.csr_array):
Inter-session mask (nonzero where the two ROIs are from different
sessions), stored with its diagonal zeroed. Shape: *(n_rois, n_rois)*.
s_sesh_inv (scipy.sparse.csr_array):
Intra-session mask (True where ROIs are from the SAME session).
Shape: *(n_rois, n_rois)*.
n_bins (Optional[int]):
Number of bins to use for the pairwise similarity distribution.
smooth_window (Optional[int]):
Number of bins to use when smoothing the distribution.
verbose (bool):
Specifies how much information to print out: \n
* 0/False: Warnings only
* 1/True: Basic info, progress bar
* 2: All info
"""
def __init__(
    self,
    similarities: Dict[str, scipy.sparse.csr_array],
    metric_configs: List[SimilarityMetric],
    s_sesh: scipy.sparse.csr_array,
    n_bins: Optional[int] = None,
    smoothing_window_bins: Optional[int] = None,
    session_bool: Optional[np.ndarray] = None,
    verbose: bool = True,
):
    """
    Ingest the similarity matrices and metric configs and derive the
    session masks and histogram settings used by the other methods.

    Args:
        similarities (Dict[str, scipy.sparse.csr_array]):
            Metric name -> sparse similarity matrix. Every matrix must
            have the same number of stored entries (``nnz``); this is
            asserted. Example:
            ``{'sf': csr_array, 'nn': csr_array, 'swt': csr_array}``.
        metric_configs (List[SimilarityMetric]):
            One configuration per metric. At least one must have
            ``is_sparsity_source=True``; the first such metric provides
            the reference sparsity pattern.
        s_sesh (scipy.sparse.csr_array):
            Inter-session mask. Shape: *(n_rois, n_rois)*. Nonzero where
            the two ROIs are from different sessions.
        n_bins (Optional[int]):
            Number of histogram bins. If ``None``, set to
            ``nnz // 10000`` clipped to the range ``[20, 200]``.
            (Default is ``None``)
        smoothing_window_bins (Optional[int]):
            Smoothing window in bins. If ``None``, set to
            ``helpers.make_odd(n_bins // 10, mode='up')``.
            (Default is ``None``)
        session_bool (Optional[np.ndarray]):
            Stored unchanged as ``self._session_bool``.
            (Default is ``None``)
        verbose (bool):
            Verbosity level, stored as ``self._verbose``.
            (Default is ``True``)
    """
    super().__init__()

    ## Record the (non-data) arguments for provenance
    self.params['__init__'] = self._locals_to_params(
        locals_dict=locals(),
        keys=['n_bins', 'smoothing_window_bins', 'verbose'],
    )

    ## Keep the raw similarity matrices, and the metric configs keyed by
    ## their name. The ``_metric_configs`` property reads this dict.
    self.similarities = similarities
    configs_by_name = {}
    for metric in metric_configs:
        configs_by_name[metric.name] = metric
    self._metric_configs_stored = configs_by_name

    ## The first metric flagged as sparsity source defines the reference
    ## nonzero pattern.
    sparsity_sources = []
    for metric_name, metric_cfg in self._metric_configs.items():
        if metric_cfg.is_sparsity_source:
            sparsity_sources.append(metric_name)
    assert len(sparsity_sources) > 0, "At least one metric must have is_sparsity_source=True"
    self._sparsity_name = sparsity_sources[0]
    self._s_sparsity = self.similarities[self._sparsity_name]

    ## Every similarity matrix must store the same number of entries
    nnz_values = {}
    for metric_name, matrix in self.similarities.items():
        nnz_values[metric_name] = matrix.nnz
    first_nnz = next(iter(nnz_values.values()))
    assert all(v == first_nnz for v in nnz_values.values()), (
        f"All similarity matrices must have same nnz. Got: {nnz_values}"
    )

    ## Session masks. ``s_sesh_inv`` starts as "every stored pair of the
    ## sparsity source" and then has the inter-session pairs removed, so
    ## only intra-session pairs remain. This must happen before the
    ## diagonal of ``s_sesh`` is cleared below.
    self.s_sesh = s_sesh
    self.s_sesh_inv = (self._s_sparsity != 0).astype(bool)
    self.s_sesh_inv[self.s_sesh.astype(bool)] = False
    self.s_sesh_inv.eliminate_zeros()

    ## Clear the diagonal of s_sesh (self-pairs are neither inter nor
    ## intra). LIL format is used for the element assignment.
    s_sesh_lil = self.s_sesh.tolil()
    s_sesh_lil[range(s_sesh_lil.shape[0]), range(s_sesh_lil.shape[1])] = 0
    self.s_sesh = s_sesh_lil.tocsr()

    self._verbose = verbose

    ## Histogram settings: explicit values win, otherwise use heuristics
    if n_bins is None:
        self.n_bins = max(min(first_nnz // 10000, 200), 20)
    else:
        self.n_bins = n_bins
    if smoothing_window_bins is None:
        self.smooth_window = helpers.make_odd(self.n_bins // 10, mode='up')
    else:
        self.smooth_window = smoothing_window_bins

    self._session_bool = session_bool
def __repr__(self):
    """
    Summarize the object: ROI count and nnz of the first similarity
    matrix, the metric names, and the number of clusters (labels other
    than ``-1``) or ``'not fitted'`` when no labels are present.
    """
    similarities = getattr(self, 'similarities', None)
    if similarities is None:
        n_roi, nnz, metric_names = 0, 0, []
    else:
        first_mat = next(iter(similarities.values()))
        n_roi, nnz = first_mat.shape[0], first_mat.nnz
        metric_names = list(similarities.keys())

    labels = getattr(self, 'labels', None)
    if labels is None:
        cluster_summary = 'not fitted'
    else:
        ## -1 is excluded from the cluster count
        cluster_summary = len(set(labels) - {-1})

    return (
        f"Clusterer(n_roi={n_roi}, nnz={nnz}, "
        f"metrics={metric_names}, "
        f"n_clusters={cluster_summary})"
    )
@property
def _metric_configs(self) -> Dict[str, SimilarityMetric]:
    """
    Metric configurations keyed by metric name.

    Normally returns ``self._metric_configs_stored`` itself. If the first
    stored value is a plain ``dict`` (the legacy-pickle case noted in the
    original code), every entry is rebuilt with
    ``SimilarityMetric.from_dict`` and a new dict is returned.
    """
    stored = self._metric_configs_stored
    if len(stored) == 0:
        return stored

    ## Only the first value is inspected to decide whether to rebuild
    first_value = next(iter(stored.values()))
    if not isinstance(first_value, dict):
        return stored

    rebuilt = {}
    for name, cfg_dict in stored.items():
        rebuilt[name] = SimilarityMetric.from_dict(cfg_dict)
    return rebuilt
[docs]
def find_optimal_parameters_for_pruning(
    self,
    bounds_findParameters: Optional[Dict[str, List[float]]] = None,
    de_kwargs: Optional[Dict[str, Any]] = None,
    n_bins: Optional[int] = None,
    smoothing_window_bins: Optional[int] = None,
    subsample_pairs: Optional[int] = None,
    seed: Optional[int] = None,
) -> Dict:
    """
    Find optimal mixing parameters for pruning the similarity graph.

    Thin public entry point: fills in default bounds and DE settings and
    delegates to :meth:`_find_optimal_parameters_DE` with
    ``freeze_sigmoid=True``. The two stages performed there are:

    1. **Naive Bayes calibration**: For each similarity feature, estimates
       ``P(same | s_k)`` from histogram subtraction. The resulting
       per-feature calibration curves are used to estimate sigmoid
       parameters ``(mu, b)``.
    2. **Differential evolution**: With sigmoid parameters frozen from
       stage 1, optimizes the remaining parameters (one ``power_<name>``
       per metric with ``optimize_power=True``, plus ``p_norm``) by
       minimizing the histogram overlap loss.

    This method replaces the original Optuna TPE search.
    RH 2023 / 2025

    Args:
        bounds_findParameters (Optional[Dict[str, List[float]]]):
            Bounds for the optimized parameters. Keys are
            ``power_<name>`` for each metric with ``optimize_power=True``,
            plus ``p_norm``. If ``None``, built from each metric's
            ``power_bounds`` with ``p_norm`` in ``[-5, -0.1]``.
        de_kwargs (Optional[Dict[str, Any]]):
            Keyword arguments for
            ``scipy.optimize.differential_evolution``. If ``None``, a
            fresh dict with the following values is used: \n
                * ``maxiter`` (int): ``100``. Maximum number of DE
                  generations.
                * ``tol`` (float): ``1e-6``. Convergence tolerance.
                * ``popsize`` (int): ``15``. Population size multiplier.
                * ``mutation`` (Tuple[float, float]): ``(0.5, 1.5)``.
                  Differential weight range for dithering.
                * ``recombination`` (float): ``0.7``. Crossover
                  probability.
                * ``polish`` (bool): ``True``. Polish the best DE
                  solution with a local optimizer.
        n_bins (Optional[int]):
            Overwrites ``n_bins`` from ``__init__``.
        smoothing_window_bins (Optional[int]):
            Overwrites ``smoothing_window_bins`` from ``__init__``.
        subsample_pairs (Optional[int]):
            If not ``None``, subsample this many pairs for histogram
            loss evaluation. If ``None``, chosen automatically by
            :meth:`_find_optimal_parameters_DE`.
        seed (Optional[int]):
            Random seed for reproducibility.

    Returns:
        (Dict):
            kwargs_makeConjunctiveDistanceMatrix_best (Dict):
                Optimal parameters for
                :meth:`make_conjunctive_distance_matrix`.
    """
    ## Build the default DE settings per call. A dict literal in the
    ## signature would be one shared object: it is stored in self.params
    ## below, so editing the stored settings of one call would silently
    ## change the defaults of every later call.
    if de_kwargs is None:
        de_kwargs = {
            'maxiter': 100,
            'tol': 1e-6,
            'popsize': 15,
            'mutation': (0.5, 1.5),
            'recombination': 0.7,
            'polish': True,
        }

    ## Store parameter (but not data) args as attributes
    self.params['find_optimal_parameters_for_pruning'] = self._locals_to_params(
        locals_dict=locals(),
        keys=[
            'bounds_findParameters', 'de_kwargs', 'n_bins',
            'smoothing_window_bins', 'subsample_pairs', 'seed',
        ],
    )

    ## Auto-construct bounds from metric configs if not provided. Sigmoid
    ## bounds are not needed here because the sigmoid is frozen.
    if bounds_findParameters is None:
        bounds_findParameters = {
            f'power_{name}': list(cfg.power_bounds)
            for name, cfg in self._metric_configs.items()
            if cfg.optimize_power
        }
        bounds_findParameters['p_norm'] = [-5, -0.1]

    ## NB calibration -> sigmoid estimation -> DE over powers and p_norm
    return self._find_optimal_parameters_DE(
        bounds_findParameters=bounds_findParameters,
        de_kwargs=de_kwargs,
        n_bins=n_bins,
        smoothing_window_bins=smoothing_window_bins,
        subsample_pairs=subsample_pairs,
        seed=seed,
        freeze_sigmoid=True,
    )
####################################################################
## Differential evolution parameter optimization
####################################################################
def _precompute_intra_mask(self) -> np.ndarray:
    """
    Flag which stored entries of the sparsity-source matrix are
    intra-session pairs.

    Each stored entry of ``self._s_sparsity`` is tagged with a unique
    1-based id (ids start at 1 so that no tag is zero). The tagged matrix
    is multiplied elementwise by ``self.s_sesh_inv``; the ids that remain
    nonzero identify the intra-session entries.

    The mask is also cached on ``self._intra_mask`` as a numpy bool array;
    callers that need a torch tensor wrap it with ``torch.as_tensor``.
    RH 2025

    Returns:
        (np.ndarray):
            intra_mask (np.ndarray):
                Boolean array, shape ``(self._s_sparsity.nnz,)``, ordered
                like ``self._s_sparsity.data``.
    """
    n_pairs = self._s_sparsity.nnz

    ## Tag every stored entry with its 1-based position in ``.data``
    pair_ids = self._s_sparsity.copy().astype(np.float64)
    pair_ids.data = np.arange(1, n_pairs + 1, dtype=np.float64)

    ## Entries outside the intra-session mask become zero and are dropped
    intra_session = self.s_sesh_inv.astype(np.float64)
    surviving = pair_ids.multiply(intra_session)
    surviving.eliminate_zeros()

    ## Back to 0-based positions, then scatter into a boolean mask
    positions = (surviving.data - 1).astype(np.int64)
    intra_mask = np.zeros(n_pairs, dtype=bool)
    intra_mask[positions] = True

    self._intra_mask = intra_mask
    return self._intra_mask
def _subsample_pairs(
    self,
    n_subsample: int,
    intra_mask: torch.Tensor,
    seed: Optional[int] = None,
) -> torch.Tensor:
    """
    Draw a random subset of pair indices, sampling intra- and
    inter-session pairs separately so their ratio is roughly preserved.

    Each group's quota is ``int(group_size * n_subsample / nnz)`` with a
    floor of 100, capped by the group's size. Because of the floor, the
    result can hold more than ``n_subsample`` indices when
    ``n_subsample`` is small.
    RH 2025

    Args:
        n_subsample (int):
            Target number of pairs to sample.
        intra_mask (torch.Tensor):
            Boolean mask, shape ``(nnz,)``. ``True`` for intra-session
            pairs. Numpy arrays are accepted and converted.
        seed (Optional[int]):
            Seed for the torch generator. ``None`` uses ``42``.

    Returns:
        (torch.Tensor):
            sample_idx (torch.Tensor):
                1D int64 tensor of indices into the full nnz-length
                arrays. All sampled intra-session indices come first,
                followed by the sampled inter-session indices.
    """
    ## Accept numpy or torch; work in torch from here on
    mask = torch.as_tensor(intra_mask)

    n_total = mask.shape[0]
    n_intra = int(mask.sum().item())
    n_inter = n_total - n_intra

    ## Same sampling fraction for both groups, with a floor of 100 each
    keep_fraction = n_subsample / n_total
    quota_intra = max(int(n_intra * keep_fraction), 100)
    quota_inter = max(int(n_inter * keep_fraction), 100)

    intra_positions = torch.where(mask)[0]
    inter_positions = torch.where(~mask)[0]

    generator = torch.Generator()
    generator.manual_seed(42 if seed is None else seed)

    ## The intra permutation is drawn before the inter permutation; the
    ## slice silently caps each quota at the group's size.
    pick_intra = torch.randperm(n_intra, generator=generator)[:quota_intra]
    pick_inter = torch.randperm(n_inter, generator=generator)[:quota_inter]

    chosen_intra = intra_positions[pick_intra]
    chosen_inter = inter_positions[pick_inter]
    return torch.cat([chosen_intra, chosen_inter])
def _find_optimal_parameters_DE(
    self,
    bounds_findParameters: Optional[Dict[str, List[float]]] = None,
    de_kwargs: Optional[Dict[str, Any]] = None,
    n_bins: Optional[int] = None,
    smoothing_window_bins: Optional[int] = None,
    subsample_pairs: Optional[int] = None,
    seed: Optional[int] = None,
    freeze_sigmoid: bool = True,
) -> Dict:
    """
    Find optimal mixing parameters using scipy differential evolution.

    When ``freeze_sigmoid=True`` (default), sigmoid parameters ``(mu, b)``
    are taken from ``self._estimate_sigmoid_params()`` (after running the
    naive Bayes calibration if needed) and held fixed. The search is then
    over ``power_<name>`` for each metric with ``optimize_power=True``,
    plus ``p_norm``. When ``False``, ``(mu, b)`` of every metric with
    ``optimize_sigmoid=True`` are optimized as well.

    The objective operates on precomputed torch tensors of the matrices'
    ``.data`` arrays; no sparse operations run per evaluation. When
    subsampling is active, the subsample is redrawn in the DE callback,
    i.e. between generations.
    RH 2025

    Args:
        bounds_findParameters (Optional[Dict[str, List[float]]]):
            ``[low, high]`` bounds per optimized parameter. Required
            keys: ``power_<name>`` for each metric with
            ``optimize_power=True`` and ``p_norm``; when
            ``freeze_sigmoid`` is ``False`` also
            ``sig_<name>_kwargs_mu`` and ``sig_<name>_kwargs_b`` for
            each metric with ``optimize_sigmoid=True``. Extra keys are
            ignored. If ``None``, built from the metric configs.
        de_kwargs (Optional[Dict[str, Any]]):
            Keyword arguments for
            ``scipy.optimize.differential_evolution``. If ``None``, a
            fresh dict with the following values is used: \n
                * ``maxiter`` (int): ``100``. Maximum number of DE
                  generations.
                * ``tol`` (float): ``1e-6``. Convergence tolerance.
                * ``popsize`` (int): ``15``. Population size multiplier.
                * ``mutation`` (Tuple[float, float]): ``(0.5, 1.5)``.
                  Differential weight range for dithering.
                * ``recombination`` (float): ``0.7``. Crossover
                  probability.
                * ``polish`` (bool): ``True``. Polish the best DE
                  solution with a local optimizer.
            A ``callback`` entry, if present, is called after the
            internal resampling callback as ``callback(xk, convergence)``.
        n_bins (Optional[int]):
            Overwrites ``self.n_bins``.
        smoothing_window_bins (Optional[int]):
            Overwrites ``self.smooth_window`` and is used as the
            smoothing window of the loss histogram. If ``None``, the
            window is ``helpers.make_odd(n_bins // 10, mode='up')``.
        subsample_pairs (Optional[int]):
            Number of pairs to subsample for the loss (see
            :meth:`_subsample_pairs`). If ``None``: set to 1,100,000
            when there are at least 100,000 intra-session and 1,000,000
            inter-session pairs, otherwise all pairs are used.
        seed (Optional[int]):
            Random seed for DE and for the subsampling.
        freeze_sigmoid (bool):
            If ``True``, fix sigmoid params from the calibration so DE
            only searches the powers and ``p_norm``. If ``False``,
            optimize the sigmoid params jointly.

    Returns:
        (Dict):
            kwargs_makeConjunctiveDistanceMatrix_best (Dict):
                Optimal parameters for
                :meth:`make_conjunctive_distance_matrix`. Keys:
                ``power_<name>`` (``None`` if not optimized),
                ``p_norm``, ``sig_<name>_kwargs`` (``None`` if the
                metric has no sigmoid).

    Raises:
        ValueError:
            If ``bounds_findParameters`` lacks a key required by the
            parameter layout.
    """
    ## Build the default DE settings per call. A dict literal in the
    ## signature would be one shared object that is also stored in
    ## self.params, so editing one call's settings would leak into the
    ## defaults of later calls.
    if de_kwargs is None:
        de_kwargs = {
            'maxiter': 100,
            'tol': 1e-6,
            'popsize': 15,
            'mutation': (0.5, 1.5),
            'recombination': 0.7,
            'polish': True,
        }

    ## Store parameter (but not data) args as attributes
    self.params['_find_optimal_parameters_DE'] = self._locals_to_params(
        locals_dict=locals(),
        keys=[
            'bounds_findParameters', 'de_kwargs', 'n_bins',
            'smoothing_window_bins', 'subsample_pairs', 'seed',
            'freeze_sigmoid',
        ],
    )

    self.n_bins = self.n_bins if n_bins is None else n_bins
    self.smooth_window = self.smooth_window if smoothing_window_bins is None else smoothing_window_bins

    ## Auto-construct bounds from metric configs if not provided
    if bounds_findParameters is None:
        bounds_findParameters = {}
        for name, cfg in self._metric_configs.items():
            if cfg.optimize_power:
                bounds_findParameters[f'power_{name}'] = list(cfg.power_bounds)
        bounds_findParameters['p_norm'] = [-5, -0.1]
        ## Add sigmoid bounds for unfrozen case
        for name, cfg in self._metric_configs.items():
            if cfg.optimize_sigmoid:
                bounds_findParameters[f'sig_{name}_kwargs_mu'] = [0., 1.0]
                bounds_findParameters[f'sig_{name}_kwargs_b'] = [0.1, 1.5]
    self.bounds_findParameters = bounds_findParameters
    self._seed = seed

    ################################################################
    ## Auto-compute subsample size if not specified
    ################################################################
    if getattr(self, '_intra_mask', None) is None:
        self._precompute_intra_mask()
    if subsample_pairs is None:
        n_intra = int(self._intra_mask.sum())
        n_inter = self._s_sparsity.nnz - n_intra
        ## Subsample only if we have enough pairs to meet minimums
        min_intra = 100_000
        min_inter = 1_000_000
        if n_intra >= min_intra and n_inter >= min_inter:
            subsample_pairs = min_intra + min_inter
        ## Otherwise use all pairs

    ################################################################
    ## Frozen sigmoid parameters (if requested)
    ################################################################
    _frozen_sig = None
    if freeze_sigmoid:
        if getattr(self, 'calibrations_naive_bayes', None) is None:
            self.make_naive_bayes_distance_matrix()
        _frozen_sig = self._estimate_sigmoid_params()  ## Dict[metric_name, {'mu': float, 'b': float}]
        if self._verbose:
            parts = [f'{n}(mu={p["mu"]:.3f}, b={p["b"]:.1f})' for n, p in _frozen_sig.items()]
            print(f' Freezing sigmoid: {", ".join(parts)}')

    ## Cache metric configs once. The property may rebuild the
    ## SimilarityMetric objects on every access; the objective runs
    ## thousands of times.
    _cached_configs = self._metric_configs  ## Dict[str, SimilarityMetric]

    ################################################################
    ## Parameter layout: [power_<m1>, ..., p_norm, (sig_mu, sig_b)...]
    ################################################################
    self._de_param_layout = []
    for name, cfg in _cached_configs.items():
        if cfg.optimize_power:
            self._de_param_layout.append(('power', name))
    self._de_param_layout.append(('p_norm', None))
    ## If not freezing sigmoid, add sigmoid params too
    if not freeze_sigmoid:
        for name, cfg in _cached_configs.items():
            if cfg.optimize_sigmoid:
                self._de_param_layout.append(('sig_mu', name))
                self._de_param_layout.append(('sig_b', name))

    ## Positions of every parameter inside the DE vector, computed once
    ## so the objective and the result extraction cannot drift apart.
    power_pos = {}    ## metric name -> index of its power
    sig_mu_pos = {}   ## metric name -> index of its sigmoid mu
    sig_b_pos = {}    ## metric name -> index of its sigmoid b
    p_norm_pos = None
    param_keys = []   ## bounds key per DE dimension, in layout order
    for position, (ptype, pname) in enumerate(self._de_param_layout):
        if ptype == 'power':
            power_pos[pname] = position
            param_keys.append(f'power_{pname}')
        elif ptype == 'p_norm':
            p_norm_pos = position
            param_keys.append('p_norm')
        elif ptype == 'sig_mu':
            sig_mu_pos[pname] = position
            param_keys.append(f'sig_{pname}_kwargs_mu')
        elif ptype == 'sig_b':
            sig_b_pos[pname] = position
            param_keys.append(f'sig_{pname}_kwargs_b')

    ## Every DE dimension needs bounds. Silently skipping a missing key
    ## would shorten the vector and misalign it with the layout.
    missing_keys = [k for k in param_keys if k not in bounds_findParameters]
    if missing_keys:
        raise ValueError(
            f"bounds_findParameters is missing bounds for: {missing_keys}. "
            f"Required keys: {param_keys}"
        )
    scipy_bounds = [tuple(bounds_findParameters[k]) for k in param_keys]  ## (lo, hi) per dimension

    ################################################################
    ## Precompute tensors — eliminates all sparse operations from
    ## the inner loop.
    ################################################################
    ## One float32 tensor of shape (nnz,) per metric, from the .data array
    tensors_full = {
        name: torch.as_tensor(
            np.ascontiguousarray(sim.data), dtype=torch.float32,
        ).clone()
        for name, sim in self.similarities.items()
    }  ## Dict[str, Tensor(nnz,)]
    intra_mask_full = torch.as_tensor(self._intra_mask)  ## shape (nnz,), bool

    nnz_full = next(iter(tensors_full.values())).shape[0]
    use_subsampling = subsample_pairs is not None and subsample_pairs < nnz_full

    ################################################################
    ## Helper: build working tensors (with optional subsampling)
    ################################################################
    def _build_working_tensors(resample_seed: Optional[int]):
        """Return (tensors_dict, intra_mask) after optional subsampling."""
        if not use_subsampling:
            return dict(tensors_full), intra_mask_full
        sidx = self._subsample_pairs(
            n_subsample=subsample_pairs,
            intra_mask=intra_mask_full,
            seed=resample_seed if resample_seed is not None else 77777,
        )
        ## Intra pairs come first in sidx; recompute how many were drawn
        ## (mirrors the quota logic of _subsample_pairs).
        n_intra_full = int(intra_mask_full.sum().item())
        frac = subsample_pairs / nnz_full
        n_si = min(max(int(n_intra_full * frac), 100), n_intra_full)
        im = torch.zeros(sidx.shape[0], dtype=torch.bool)
        im[:n_si] = True
        return {name: t[sidx] for name, t in tensors_full.items()}, im

    ## Initial working set
    working_tensors, intra_mask = _build_working_tensors(
        resample_seed=(seed + 77777) if seed is not None else 77777,
    )
    first_t = next(iter(working_tensors.values()))
    if self._verbose and subsample_pairs is not None:
        n_intra_working = int(intra_mask.sum().item())
        print(
            f' Working set: {first_t.shape[0]} pairs '
            f'({n_intra_working} intra, '
            f'{first_t.shape[0] - n_intra_working} inter)'
        )

    ################################################################
    ## Shared histogram infrastructure
    ################################################################
    n_bins_val = self.n_bins
    edges = torch.linspace(0, 1, n_bins_val + 1, dtype=torch.float32)
    ## Honour an explicitly requested smoothing window; otherwise derive
    ## it from the bin count as before.
    ## NOTE(review): a window passed to __init__ is still not used here —
    ## confirm whether self.smooth_window should be the fallback.
    if smoothing_window_bins is not None:
        smooth_window = smoothing_window_bins
    else:
        smooth_window = helpers.make_odd(n_bins_val // 10, mode='up')
    smoother = helpers.Convolver_1d(
        kernel=torch.ones(smooth_window),
        length_x=n_bins_val,
        pad_mode='same',
        correct_edge_effects=True,
        device='cpu',
    )

    ## Mutable container so the resample callback can swap the working
    ## set that the objective reads.
    _state = {
        'tensors': working_tensors,
        'intra_mask': intra_mask,
    }
    _state['intra_indices'] = torch.where(_state['intra_mask'])[0]
    _state['n_all'] = first_t.shape[0]
    _state['n_intra'] = int(_state['intra_mask'].sum().item())
    _state['scale_factor'] = _state['n_all'] / max(_state['n_intra'], 1)

    ################################################################
    ## Resample callback — redraws the subsample between generations
    ################################################################
    _generation_counter = [0]

    def _resample_callback(xk, convergence=None):
        """Redraw the subsample with a generation-dependent seed."""
        gen = _generation_counter[0]
        _generation_counter[0] += 1
        new_seed = (seed + gen * 1000 + 99999) if seed is not None else (gen * 1000 + 99999)
        new_tensors, im_new = _build_working_tensors(resample_seed=new_seed)
        _state['tensors'] = new_tensors
        _state['intra_mask'] = im_new
        _state['intra_indices'] = torch.where(im_new)[0]
        first_new = next(iter(new_tensors.values()))
        _state['n_all'] = first_new.shape[0]
        _state['n_intra'] = int(im_new.sum().item())
        _state['scale_factor'] = _state['n_all'] / max(_state['n_intra'], 1)

    ################################################################
    ## Scalar objective — evaluates one parameter vector at a time
    ################################################################
    def objective_scalar(x):
        ii = _state['intra_indices']
        sc = _state['scale_factor']

        ## Transform every metric: optional sigmoid, clamp, optional power
        activated = []
        for name, cfg in _cached_configs.items():
            s_w = _state['tensors'][name]
            if cfg.optimize_sigmoid:
                if _frozen_sig is not None and name in _frozen_sig:
                    mu = _frozen_sig[name]['mu']
                    b = _frozen_sig[name]['b']
                elif name in sig_mu_pos:
                    mu = float(x[sig_mu_pos[name]])
                    b = float(x[sig_b_pos[name]])
                else:
                    mu, b = 0.0, 1.0
                s_w = torch.sigmoid(b * (s_w - mu))
            ## Clamp away from zero so negative exponents stay finite
            s_w = torch.clamp(s_w, min=1e-8)
            if cfg.optimize_power:
                s_w = s_w.pow(float(x[power_pos[name]]))
            activated.append(s_w)

        ## Keep p away from exactly zero (1/p is used below)
        p = float(x[p_norm_pos])
        p = p if abs(p) > 1e-9 else 1e-9

        ## p-norm mixing: distance = 1 - (mean(s_k^p))^(1/p)
        N = len(activated)
        running_sum = torch.zeros_like(activated[0])
        for a in activated:
            running_sum += a.pow(p)
        dist = 1.0 - (running_sum / N).pow(1.0 / p)

        ## Histogram overlap loss via shared helper
        loss, _, _ = self._compute_histogram_overlap(
            distances=dist,
            intra_indices=ii,
            edges=edges,
            smoother=smoother,
            scale_factor=sc,
        )
        return loss

    ################################################################
    ## Configure and run differential evolution
    ################################################################
    if self._verbose:
        print('Finding mixing parameters using differential evolution...')
    de_kwargs_use = dict(de_kwargs)  ## copy so the caller's dict is not modified

    ## Always resample each generation when subsampling
    if use_subsampling:
        existing_cb = de_kwargs_use.pop('callback', None)

        def _combined_callback(xk, convergence=None):
            _resample_callback(xk, convergence)
            if existing_cb is not None:
                return existing_cb(xk, convergence)  ## propagate stop signal

        de_kwargs_use['callback'] = _combined_callback

    ## Coerce seed to int for scipy DE; leave None for random behavior
    de_seed = int(seed) if seed is not None else None

    self._de_result = scipy.optimize.differential_evolution(
        func=objective_scalar,
        bounds=scipy_bounds,
        seed=de_seed,
        **de_kwargs_use,
    )

    ################################################################
    ## Extract best parameters from DE result
    ################################################################
    x_best = self._de_result.x
    self.best_params = {}
    for name, cfg in _cached_configs.items():
        ## Non-optimized metrics get None (no power transform)
        self.best_params[f'power_{name}'] = (
            float(x_best[power_pos[name]]) if cfg.optimize_power else None
        )
    self.best_params['p_norm'] = float(x_best[p_norm_pos])

    ## Sigmoid params: frozen estimate, DE result, or the same identity
    ## fallback (mu=0, b=1) that the objective used.
    for name, cfg in _cached_configs.items():
        if not cfg.optimize_sigmoid:
            ## Non-sigmoid metrics get None (no sigmoid applied)
            self.best_params[f'sig_{name}_kwargs'] = None
        elif _frozen_sig is not None and name in _frozen_sig:
            self.best_params[f'sig_{name}_kwargs'] = {
                'mu': _frozen_sig[name]['mu'],
                'b': _frozen_sig[name]['b'],
            }
        elif name in sig_mu_pos:
            self.best_params[f'sig_{name}_kwargs'] = {
                'mu': float(x_best[sig_mu_pos[name]]),
                'b': float(x_best[sig_b_pos[name]]),
            }
        else:
            self.best_params[f'sig_{name}_kwargs'] = {'mu': 0.0, 'b': 1.0}

    self.kwargs_makeConjunctiveDistanceMatrix_best = dict(self.best_params)
    if self._verbose:
        print(
            f'Completed DE parameter search. '
            f'Best value: {self._de_result.fun:.2f}, '
            f'evaluations: {self._de_result.nfev}, '
            f'params: {self.best_params}'
        )
    return self.kwargs_makeConjunctiveDistanceMatrix_best
####################################################################
## Naive Bayes calibration and sigmoid estimation
####################################################################
def _calibrate_feature_1d(
    self,
    s_data: torch.Tensor,
    intra_mask: torch.Tensor,
    n_bins: int,
    smoother: 'helpers.Convolver_1d',
    prob_clip: Tuple[float, float] = (1e-4, 1 - 1e-4),
) -> Dict[str, np.ndarray]:
    """
    Estimate P(same | s_k) for a single similarity feature using
    histogram subtraction.

    The histogram of intra-session (known-different) values is scaled by
    ``n_all / n_intra`` to estimate the "different" distribution over all
    pairs. The "same" distribution is the clamped residual of the overall
    histogram minus that estimate. The per-bin probability is then made
    monotonically increasing with isotonic regression.
    RH 2025

    Args:
        s_data (torch.Tensor):
            1D tensor of raw similarity values, shape ``(nnz,)``.
            Must be float32, since the bin edges passed to
            ``torch.histogram`` are float32.
        intra_mask (torch.Tensor):
            Boolean mask, shape ``(nnz,)``. ``True`` for intra-session
            (known-different) pairs. Numpy arrays are accepted and
            converted.
        n_bins (int):
            Number of histogram bins.
        smoother (helpers.Convolver_1d):
            Object whose ``convolve`` method smooths the histogram
            counts.
        prob_clip (Tuple[float, float]):
            ``(lo, hi)`` limits for the fitted P(same); passed to the
            isotonic regression as ``y_min`` / ``y_max``.

    Returns:
        (Dict[str, np.ndarray]):
            calibration (Dict[str, np.ndarray]):
                Numpy arrays under the keys ``'edges'``,
                ``'counts_all'``, ``'counts_diff'``,
                ``'counts_diff_smooth'``, ``'counts_same'``,
                ``'p_same_bins'``.

    Warns:
        RuntimeWarning:
            If ``intra_mask`` selects no pairs. The "different"
            distribution is then empty and the calibration carries no
            information.
    """
    ## Accept numpy or torch masks (self._intra_mask is stored as numpy)
    intra_mask = torch.as_tensor(intra_mask)

    n_all = s_data.shape[0]
    n_intra = int(intra_mask.sum().item())
    if n_intra == 0:
        warnings.warn(
            "No intra-session pairs found: the 'different' distribution "
            "is empty and the calibration is uninformative.",
            RuntimeWarning,
            stacklevel=2,
        )
    ## Guard against division by zero, like the scale factor used in the
    ## differential-evolution objective.
    scale = n_all / max(n_intra, 1)

    ## Bin edges spanning the data range with small margin
    lo = float(s_data.min()) - 1e-6
    hi = float(s_data.max()) + 1e-6
    edges = torch.linspace(lo, hi, n_bins + 1, dtype=torch.float32)

    ## Histogram all values and intra-session values
    counts_all, _ = torch.histogram(s_data, edges)
    counts_intra, _ = torch.histogram(s_data[intra_mask], edges)

    ## Scale intra counts to estimate the full "different" distribution
    counts_diff = counts_intra * scale

    ## "Same" distribution = residual, clamped non-negative.
    ## Do NOT smooth counts_same — the smoothing kernel bleeds nonzero
    ## mass into the left tail where the true same-count is zero, creating
    ## an artificial P(same) floor. Raw residual preserves zeros.
    counts_same = torch.clamp(counts_all - counts_diff, min=0)

    ## Smooth only the "different" distribution
    counts_diff_smooth = smoother.convolve(counts_diff)

    ## P(same | bin); the epsilon avoids 0/0 in empty bins
    p_same_bins = counts_same / (counts_same + counts_diff_smooth + 1e-10)

    ## Enforce monotonic increase via isotonic regression, weighted by
    ## the evidence per bin so that sparse bins have little influence.
    ir = sklearn.isotonic.IsotonicRegression(
        increasing=True,
        y_min=float(prob_clip[0]),
        y_max=float(prob_clip[1]),
    )
    evidence_weights = (counts_all + counts_diff_smooth).numpy()
    evidence_weights = np.maximum(evidence_weights, 1e-10)
    p_same_np = ir.fit_transform(
        X=np.arange(n_bins, dtype=np.float64),
        y=p_same_bins.numpy().astype(np.float64),
        sample_weight=evidence_weights.astype(np.float64),
    )

    ## Everything is returned as numpy (per the original note, torch
    ## tensors are not preserved by serializable_dict). Call sites that
    ## need torch tensors convert with torch.as_tensor().
    return {
        'edges': edges.numpy(),
        'counts_all': counts_all.numpy(),
        'counts_diff': counts_diff.numpy(),
        'counts_diff_smooth': counts_diff_smooth.numpy(),
        'counts_same': counts_same.numpy(),
        'p_same_bins': np.asarray(p_same_np, dtype=np.float32),
    }
[docs]
def make_naive_bayes_distance_matrix(
    self,
    n_bins: Optional[int] = None,
    smoothing_window_bins: Optional[int] = None,
    prob_clip: Tuple[float, float] = (1e-4, 1 - 1e-4),
) -> Tuple[scipy.sparse.csr_array, scipy.sparse.csr_array, Dict[str, Any]]:
    """
    Compute pairwise distance matrix using independent per-feature
    calibration combined via naive Bayes.

    For every similarity feature k in ``self.similarities``, estimates the
    posterior ``P(same | s_k)`` per histogram bin via
    ``self._calibrate_feature_1d`` (which uses the intra-session,
    known-different pairs as reference). Each pair then looks up the
    posterior of the bin its similarity value falls in, and the
    per-feature posteriors are combined under conditional independence:

    .. math::
        \\text{logit}(P(\\text{same} | \\mathbf{s})) = \\sum_k \\text{logit}(P(\\text{same} | s_k)) - (K-1) \\cdot \\text{logit}(\\pi)

    where :math:`\\pi` is the estimated prior P(same) and K is the
    number of features.

    **No iterative optimization** — just histogram + lookup.

    Side effects: stores the results in ``self.dConj``, ``self.sConj`` and
    ``self.calibrations_naive_bayes``, records the call arguments in
    ``self.params``, and calls ``self._precompute_intra_mask()`` when
    ``self._intra_mask`` is missing or ``None``.

    RH 2025

    Args:
        n_bins (Optional[int]):
            Number of histogram bins per feature. If ``None``, uses
            ``self.n_bins``.
        smoothing_window_bins (Optional[int]):
            Smoothing window. If ``None``, uses ``self.smooth_window``.
        prob_clip (Tuple[float, float]):
            Clamp P(same|s_k) to ``[lo, hi]`` before logit. The same
            bounds are used to clip the estimated prior.

    Returns:
        (Tuple[scipy.sparse.csr_array, scipy.sparse.csr_array, Dict]):
            dConj (scipy.sparse.csr_array):
                Distance matrix ``d = 1 - P(same|all)``.
            sConj (scipy.sparse.csr_array):
                Similarity matrix ``s = P(same|all)``.
            calibrations (Dict[str, Any]):
                Diagnostic dict with per-feature calibrations
                (``'features'``), the prior (``'prior'``), and the
                combined P(same) per pair (``'p_same_combined'``).
    """
    ## Must stay the first statement so locals() only holds the arguments
    self.params['make_naive_bayes_distance_matrix'] = self._locals_to_params(
        locals_dict=locals(),
        keys=['n_bins', 'smoothing_window_bins', 'prob_clip'],
    )

    n_bins = self.n_bins if n_bins is None else n_bins
    smooth_window = self.smooth_window if smoothing_window_bins is None else smoothing_window_bins

    if self._verbose:
        print('Computing naive Bayes distance matrix...')

    ## Precompute intra-session mask
    if not hasattr(self, '_intra_mask') or self._intra_mask is None:
        self._precompute_intra_mask()
    intra_mask = torch.as_tensor(self._intra_mask)

    ## Build smoother (shared across features)
    smoother = helpers.Convolver_1d(
        kernel=torch.ones(smooth_window),
        length_x=n_bins,
        pad_mode='same',
        correct_edge_effects=True,
        device='cpu',
    )

    ## Features to calibrate: iterate over all similarity metrics
    features_raw = {
        name: torch.as_tensor(sim.data, dtype=torch.float32)
        for name, sim in self.similarities.items()
    }

    calibrations = {'features': {}}
    nnz = self._s_sparsity.nnz
    logit_sum = torch.zeros(nnz, dtype=torch.float32)

    ## Calibrate each feature independently
    for name, s_data in features_raw.items():
        cal = self._calibrate_feature_1d(
            s_data=s_data,
            intra_mask=intra_mask,
            n_bins=n_bins,
            smoother=smoother,
            prob_clip=prob_clip,
        )

        ## Look up P(same) for each pair from its histogram bin.
        ## cal values are numpy arrays; convert to torch for the lookup
        ## then store result as numpy for serialization safety.
        edges_t = torch.as_tensor(cal['edges'])
        p_same_bins_t = torch.as_tensor(cal['p_same_bins'])
        ## right=True makes the lookup left-closed, [e_j, e_{j+1}), which
        ## is the binning convention torch.histogram used to build the
        ## calibration. With right=False a value lying exactly on an
        ## interior edge would be looked up in the bin *below* the one
        ## it was counted in.
        bin_idx = torch.searchsorted(
            edges_t[1:-1].contiguous(), s_data, right=True,
        )
        ## Values outside the edge range fall into the first / last bin
        bin_idx = torch.clamp(bin_idx, 0, n_bins - 1)
        p_same_per_pair = p_same_bins_t[bin_idx]  ## torch, shape (nnz,)

        ## Accumulate logit for naive Bayes combination
        logit_p = torch.log(p_same_per_pair / (1.0 - p_same_per_pair))
        logit_sum += logit_p

        ## Store as numpy so serializable_dict preserves it
        cal['p_same_per_pair'] = p_same_per_pair.numpy()
        calibrations['features'][name] = cal

        if self._verbose:
            print(
                f' {name}: P(same) range '
                f'[{p_same_per_pair.min():.4f}, {p_same_per_pair.max():.4f}], '
                f'mean={p_same_per_pair.mean():.4f}'
            )

    ## Estimate prior P(same): per-feature fraction of "same" mass,
    ## averaged over features (0.5 if no feature has any mass)
    K = len(features_raw)
    prior_estimates = []
    for cal in calibrations['features'].values():
        total_same = cal['counts_same'].sum().item()
        total = total_same + cal['counts_diff_smooth'].sum().item()
        if total > 0:
            prior_estimates.append(total_same / total)
    prior = float(np.mean(prior_estimates)) if prior_estimates else 0.5
    ## Clip so that logit(prior) below stays finite
    prior = float(np.clip(prior, prob_clip[0], prob_clip[1]))
    calibrations['prior'] = prior

    ## Naive Bayes log-odds combination: the prior is contained in each of
    ## the K per-feature posteriors, so (K - 1) copies are removed
    logit_prior = float(np.log(prior / (1.0 - prior)))
    logit_combined = logit_sum - (K - 1) * logit_prior

    ## Convert back to probability
    p_same_combined = torch.sigmoid(logit_combined).numpy()
    calibrations['p_same_combined'] = p_same_combined

    ## Build sparse similarity and distance matrices on the shared
    ## sparsity pattern
    sConj = self._s_sparsity.copy()
    sConj.data = p_same_combined.astype(np.float64)
    dConj = sConj.copy()
    dConj.data = 1.0 - dConj.data

    ## Store for downstream use
    self.dConj = dConj
    self.sConj = sConj
    self.calibrations_naive_bayes = calibrations

    if self._verbose:
        print(
            f' Combined P(same): mean={p_same_combined.mean():.4f}, '
            f'prior={prior:.4f}, '
            f'P(same)>0.5: {(p_same_combined > 0.5).sum()}/{nnz} '
            f'({(p_same_combined > 0.5).mean() * 100:.1f}%)'
        )

    return dConj, sConj, calibrations
def _estimate_sigmoid_params(self) -> Dict[str, Dict[str, float]]:
    """
    Grid-search, per feature, the sigmoid ``sigma(b * (s - mu))`` that best
    separates the "same" and "different" calibration histograms.

    Only metrics in ``self._metric_configs`` whose ``optimize_sigmoid``
    flag is truthy are processed. For each one, the stored naive Bayes
    calibration (``edges``, ``counts_same``, ``counts_diff_smooth``) is
    turned into two normalized weight vectors over the bin centers, and
    a 50 x 30 grid of ``(mu, b)`` candidates is scored with the Fisher
    discriminant ratio computed in sigmoid-transformed space. ``mu``
    candidates span the range of bin centers; ``b`` candidates span
    ``[0.5, 10.0]``.

    Requires :meth:`make_naive_bayes_distance_matrix` to have been
    called first.

    RH 2025

    Returns:
        (Dict[str, Dict[str, float]]):
            sigmoid_params (Dict[str, Dict[str, float]]):
                Mapping from feature name to ``{'mu': float, 'b': float}``.
    """
    assert hasattr(self, 'calibrations_naive_bayes') and self.calibrations_naive_bayes is not None, (
        "make_naive_bayes_distance_matrix() must be called before "
        "_estimate_sigmoid_params()."
    )

    def _weighted_mean_var(weights, values):
        ## weights: (n_bins,), values: (M, B, n_bins) → two (M, B) arrays
        w = weights[None, None, :]
        mean = (w * values).sum(axis=2)
        var = (w * (values - mean[:, :, None]) ** 2).sum(axis=2)
        return mean, var

    ## Slope candidates do not depend on the feature, shape (B,)
    slope_candidates = np.linspace(0.5, 10.0, 30)

    sigmoid_params = {}
    for name, cfg in self._metric_configs.items():
        if not cfg.optimize_sigmoid:
            continue

        ## Calibration entries are numpy arrays (kept that way for
        ## serialization)
        cal = self.calibrations_naive_bayes['features'][name]
        bin_edges = np.asarray(cal['edges'])
        hist_same = np.asarray(cal['counts_same'])
        hist_diff = np.asarray(cal['counts_diff_smooth'])

        ## Bin centers in similarity space, shape (n_bins,)
        bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0

        ## Histograms normalized to sum to ~1, shape (n_bins,)
        weights_same = hist_same / (hist_same.sum() + 1e-10)
        weights_diff = hist_diff / (hist_diff.sum() + 1e-10)

        ## Midpoint candidates cover the observed bin centers, shape (M,)
        midpoint_candidates = np.linspace(
            float(bin_centers.min()), float(bin_centers.max()), 50,
        )

        ## Sigmoid of every bin center under every (mu, b) candidate.
        ## Broadcasting: (1,B,1) * ((1,1,n_bins) - (M,1,1)) → (M, B, n_bins)
        offsets = bin_centers[None, None, :] - midpoint_candidates[:, None, None]
        activated = 1.0 / (1.0 + np.exp(-slope_candidates[None, :, None] * offsets))

        ## Weighted first and second moments of both classes, each (M, B)
        mean_same, var_same = _weighted_mean_var(weights_same, activated)
        mean_diff, var_diff = _weighted_mean_var(weights_diff, activated)

        ## Fisher discriminant ratio, shape (M, B); epsilon avoids 0 / 0
        fisher = (mean_same - mean_diff) ** 2 / (var_same + var_diff + 1e-12)

        ## Pick the best-scoring grid cell
        i_mu, i_b = np.unravel_index(np.argmax(fisher), fisher.shape)
        best_mu = float(midpoint_candidates[i_mu])
        best_b = float(slope_candidates[i_b])
        best_fisher = float(fisher[i_mu, i_b])

        sigmoid_params[name] = {'mu': best_mu, 'b': best_b}

        if self._verbose:
            print(
                f' Sigmoid estimate for {name}: '
                f'mu={best_mu:.4f}, b={best_b:.2f} '
                f'(Fisher={best_fisher:.4f})'
            )

    return sigmoid_params
[docs]
def make_pruned_similarity_graphs(
    self,
    convert_to_probability: bool = False,
    stringency: float = 1.0,
    mixing_params: Optional[Union[Dict, str]] = None,
    d_cutoff: Optional[float] = None,
) -> None:
    """
    Constructs pruned similarity graphs.

    Obtains a conjunctive distance matrix (computed from mixing
    parameters, or a previously stored ``self.dConj``), optionally
    converts it to probabilities, derives a distance cutoff, and prunes
    every similarity matrix to the edges whose distance is below that
    cutoff.

    Sets ``self.distributions_mixing``, ``self.d_cutoff``,
    ``self.graph_pruned``, ``self.similarities_pruned``,
    ``self.s_sesh_pruned``, ``self.dConj_pruned`` and
    ``self.sConj_pruned``. Unless ``mixing_params='precomputed'``, also
    overwrites ``self.dConj``, ``self.sConj`` and
    ``self._activated_data``. With ``convert_to_probability=True`` the
    ``.data`` of ``self.sConj`` and ``self.dConj`` is replaced in place.

    RH 2023

    Args:
        convert_to_probability (bool):
            Whether to convert the distance and similarity graphs to
            probability, *p(different)* and *p(same)*, respectively.
            (Default is ``False``)
        stringency (float):
            Modifies the threshold for pruning the distance matrix. A higher
            value results in less pruning, a lower value leads to more
            pruning. The inferred range ``d_crossover - min(d)`` is
            multiplied by this value. Ignored when ``d_cutoff`` is
            given. (Default is *1.0*)
        mixing_params (Optional[Union[Dict, str]]):
            Mixing parameters for
            ``self.make_conjunctive_distance_matrix``. If ``None``,
            ``self.kwargs_makeConjunctiveDistanceMatrix_best`` is used
            when present; otherwise defaults are used and a warning is
            emitted. Use ``'precomputed'`` to use a previously stored
            ``self.dConj``. (Default is ``None``)
        d_cutoff (Optional[float]):
            The cutoff distance for pruning the distance matrix. If
            ``None``, then the optimal cutoff distance is inferred. (Default
            is ``None``)
    """
    ## Store parameter (but not data) args as attributes
    self.params['make_pruned_similarity_graphs'] = self._locals_to_params(
        locals_dict=locals(),
        keys=[
            'convert_to_probability',
            'stringency',
            'mixing_params',
        ],
    )

    ## If 'precomputed', use self.dConj/sConj set by a prior call
    ## (e.g. make_naive_bayes_distance_matrix). Otherwise, compute
    ## the conjunctive distance matrix from mixing parameters.
    if mixing_params == 'precomputed':
        assert hasattr(self, 'dConj') and self.dConj is not None, (
            "mixing_params='precomputed' requires self.dConj to be set "
            "(call make_naive_bayes_distance_matrix first)."
        )
    elif mixing_params is None:
        if hasattr(self, 'kwargs_makeConjunctiveDistanceMatrix_best'):
            mixing_params = self.kwargs_makeConjunctiveDistanceMatrix_best
        else:
            mixing_params = {'p_norm': -4.0}
            for name in self.similarities:
                mixing_params[f'power_{name}'] = 1.0  ## identity (no transform)
                mixing_params[f'sig_{name}_kwargs'] = {'mu': 0.5, 'b': 0.5}
            warnings.warn(f'No mixing_params provided. Using defaults: {mixing_params}')

    if mixing_params != 'precomputed':
        self.dConj, self.sConj, self._activated_data = self.make_conjunctive_distance_matrix(
            similarities=self.similarities,
            mixing_params=mixing_params,
        )

    dens_same_crop, dens_same, dens_diff, dens_all, edges, d_crossover = self._separate_diffSame_distributions(self.dConj)

    if convert_to_probability:
        ## convert into probabilities
        ### first smooth dens_diff. (dens_same is already smoothed)
        dens_diff_smooth = self._fn_smooth(dens_diff)
        ### second, compute the probability of each bin
        prob_same = (dens_same / (dens_same + dens_diff_smooth)).numpy()
        ### Bins where both densities are zero evaluate to 0 / 0 = NaN.
        ### np.maximum propagates NaN, so a single such bin would turn
        ### every bin to its left into NaN in the cumulative max below.
        ### Zero is neutral for a max over probabilities, so an empty bin
        ### simply inherits the value of the bins to its right.
        prob_same = np.where(np.isnan(prob_same), 0.0, prob_same)
        ### force to be monotonic decreasing
        prob_same = np.maximum.accumulate(prob_same[::-1])[::-1]
        ### third, append 0 to the end so there is one value per edge
        prob_same = np.append(prob_same, 0)
        ### fourth, convert self.dConj to probabilities using interpolation
        import scipy.interpolate
        fn_interp = scipy.interpolate.interp1d(edges, prob_same, kind='linear', fill_value='extrapolate')
        self.sConj.data = fn_interp(self.dConj.data)
        self.dConj.data = 1 - self.sConj.data
        ## Map the crossover through the same transform as the distances
        d_crossover = 1 - fn_interp(d_crossover)

    self.distributions_mixing = {
        'mixing_params': mixing_params,
        'dens_same_crop': dens_same_crop,
        'dens_same': dens_same,
        'dens_diff': dens_diff,
        'dens_all': dens_all,
        'edges': edges,
        'd_crossover': d_crossover,
    }

    ## Copy all similarity matrices for pruning
    sims_copy = {name: s.copy() for name, s in self.similarities.items()}
    ssesh = self.s_sesh.copy()

    min_d = np.nanmin(self.dConj.data)
    if d_cutoff is None:
        ## Cutoff sits `stringency` of the way from the smallest distance
        ## to the same/different crossover
        range_d = d_crossover - min_d
        self.d_cutoff = min_d + range_d * stringency
    else:
        self.d_cutoff = d_cutoff

    print(f'Pruning similarity graphs with d_cutoff = {self.d_cutoff}...') if self._verbose else None

    ## Boolean graph of the edges that survive pruning
    self.graph_pruned = self.dConj.copy()
    self.graph_pruned.data = self.graph_pruned.data < self.d_cutoff
    self.graph_pruned.eliminate_zeros()

    def prune(s, graph_pruned):
        ## Keep only the entries of `s` that lie on surviving edges
        if s is None:
            return None
        s_pruned = scipy.sparse.csr_array(graph_pruned.shape, dtype=np.float32)
        s_pruned[graph_pruned] = s[graph_pruned]
        s_pruned = s_pruned.tocsr()
        return s_pruned

    ## Prune all similarity matrices
    self.similarities_pruned = {
        name: prune(s, self.graph_pruned) for name, s in sims_copy.items()
    }
    self.s_sesh_pruned = prune(ssesh, self.graph_pruned)
    self.dConj_pruned = prune(self.dConj, self.graph_pruned)
    self.sConj_pruned = prune(self.sConj, self.graph_pruned)
[docs]
def apply_weighted_jaccard(
    self,
    s_conj: Optional[scipy.sparse.csr_array] = None,
    d_conj: Optional[scipy.sparse.csr_array] = None,
    alpha: float = 1.0,
) -> Tuple[scipy.sparse.csr_array, scipy.sparse.csr_array]:
    """
    Refine a similarity graph with the weighted Jaccard transform and
    return the new similarity / distance pair. ``self`` is not modified.

    The input similarity is resolved in this order:

    1. ``s_conj`` if given.
    2. ``1 - d_conj`` (applied to the stored values) if ``d_conj`` is
       given.
    3. ``self.sConj_pruned`` otherwise, as produced by
       ``make_pruned_similarity_graphs()``.

    Passing both ``s_conj`` and ``d_conj`` is an error.

    The weighted Jaccard replaces each pairwise similarity with a
    measure of shared neighborhood structure:

    .. math::
        J_w(i,j) = \\frac{\\sum_k \\min(s_{ik}, s_{jk})}
        {\\sum_k \\max(s_{ik}, s_{jk})}

    See :func:`weighted_jaccard_similarity` for details.

    Args:
        s_conj (Optional[scipy.sparse.csr_array]):
            Similarity matrix to transform. Mutually exclusive with
            ``d_conj``. (Default is ``None``)
        d_conj (Optional[scipy.sparse.csr_array]):
            Distance matrix to transform (converted to similarity via
            ``1 - d``). Mutually exclusive with ``s_conj``.
            (Default is ``None``)
        alpha (float):
            Blending weight in ``[0, 1]`` between Jaccard and original
            similarity: ``s_final = alpha * s_jaccard + (1-alpha) *
            s_original``. ``alpha=1.0`` is pure Jaccard; ``alpha=0.0``
            skips the transform and returns copies of the input
            similarity and its complement. (Default is *1.0*)

    Returns:
        (Tuple[scipy.sparse.csr_array, scipy.sparse.csr_array]):
            sConj_jaccard (scipy.sparse.csr_array):
                Jaccard-refined (or blended) similarity matrix.
            dConj_jaccard (scipy.sparse.csr_array):
                Corresponding distance matrix (``1 - sConj_jaccard``
                on the stored values).
    """
    def _complement(mat):
        ## Copy of `mat` whose stored values are replaced by 1 - value
        flipped = mat.copy()
        flipped.data = 1.0 - flipped.data
        return flipped

    assert s_conj is None or d_conj is None, (
        "Provide s_conj or d_conj, not both."
    )
    assert 0.0 <= alpha <= 1.0, f"alpha must be in [0, 1], got {alpha}"

    ## Pick the similarity matrix to work on
    if s_conj is not None:
        sim_in = s_conj
    elif d_conj is not None:
        sim_in = _complement(d_conj)
    else:
        assert getattr(self, 'sConj_pruned', None) is not None, (
            "No input provided and self.sConj_pruned is not set. "
            "Either pass s_conj/d_conj or call make_pruned_similarity_graphs() first."
        )
        sim_in = self.sConj_pruned

    ## alpha == 0 means "no Jaccard at all": hand back untouched copies
    if alpha == 0.0:
        return sim_in.copy(), _complement(sim_in)

    sim_out = weighted_jaccard_similarity(sim_in)
    if alpha < 1.0:
        ## Linear blend of the stored values of Jaccard and original
        sim_out.data = alpha * sim_out.data + (1.0 - alpha) * sim_in.data

    return sim_out, _complement(sim_out)
[docs]
def fit(
    self,
    d_conj: scipy.sparse.csr_array,
    session_bool: np.ndarray,
    min_cluster_size: int = 2,
    max_cluster_size: Optional[int] = None,
    min_samples: Optional[int] = None,
    n_iter_violationCorrection: int = 5,
    cluster_selection_method: str = 'leaf',
    cluster_selection_persistence: float = 0.0,
    d_clusterMerge: Optional[float] = None,
    alpha: float = 0.999,
    split_intraSession_clusters: bool = True,
    discard_failed_pruning: bool = True,
    n_steps_clusterSplit: int = 100,
    backend: str = 'fast_hdbscan',
    algorithm: str = 'kruskal',
    rescue_noise: bool = True,
) -> np.ndarray:
    """
    Cluster ROIs with HDBSCAN while preventing two ROIs from the same
    session from sharing a cluster.

    This method is a dispatcher. It records the call arguments in
    ``self.params['fit']``, resolves the default for ``d_clusterMerge``,
    and forwards to one of two backends:

    * ``backend='fast_hdbscan'`` → ``self._fit_fast_hdbscan``, which uses
      ``fast_hdbscan.HDBSCAN`` with each ROI's session index as a
      cannot-link group label.
    * ``backend='legacy'`` → ``self._fit_legacy_hdbscan``, which uses
      ``hdbscan.HDBSCAN`` with iterative violation correction.

    RH 2023 / 2025

    Args:
        d_conj (scipy.sparse.csr_array):
            Conjunctive distance matrix.
        session_bool (np.ndarray):
            Boolean array indicating which ROIs belong to which session.
            Shape: *(n_rois, n_sessions)*
        min_cluster_size (int):
            Minimum cluster size to be considered a cluster. Can be
            ``'all'``. (Default is *2*)
        max_cluster_size (Optional[int]):
            Maximum cluster size. If ``None``, the backends default it to
            ``n_sessions`` (one ROI per session). (Default is ``None``)
        min_samples (Optional[int]):
            Number of neighbors a point needs to be considered "core"
            by HDBSCAN. If ``None``, HDBSCAN's own default applies.
            (Default is ``None``)
        n_iter_violationCorrection (int):
            Number of violation-correction iterations. Only forwarded
            to ``backend='legacy'``. (Default is *5*)
        cluster_selection_method (str):
            Either ``'leaf'`` (leans towards smaller clusters) or
            ``'eom'`` (leans towards larger clusters).
            (Default is ``'leaf'``)
        cluster_selection_persistence (float):
            Minimum stability a cluster must have to survive selection.
            Only forwarded to ``backend='fast_hdbscan'``.
            (Default is *0.0*)
        d_clusterMerge (Optional[float]):
            Distance threshold for merging clusters
            (``cluster_selection_epsilon`` in HDBSCAN). If ``None`` and
            ``self.d_cutoff`` is set, ``float(self.d_cutoff)`` is used;
            otherwise ``None`` is forwarded and the backend falls back
            to ``mean + 1*std`` of the distance data.
            (Default is ``None``)
        alpha (float):
            Alpha value. Only forwarded to ``backend='legacy'``.
            (Default is *0.999*)
        split_intraSession_clusters (bool):
            Whether clusters with several ROIs from one session are
            split. Only forwarded to ``backend='legacy'``.
            (Default is ``True``)
        discard_failed_pruning (bool):
            Whether clusters failing to prune are set to -1. Only
            forwarded to ``backend='legacy'``. (Default is ``True``)
        n_steps_clusterSplit (int):
            Number of steps used when splitting clusters. Only
            forwarded to ``backend='legacy'``. (Default is *100*)
        backend (str):
            ``'fast_hdbscan'`` or ``'legacy'``.
            (Default is ``'fast_hdbscan'``)
        algorithm (str):
            MST construction algorithm for fast_hdbscan, ``'kruskal'``
            or ``'boruvka'``. Only forwarded to
            ``backend='fast_hdbscan'``. (Default is ``'kruskal'``)
        rescue_noise (bool):
            Whether to run the post-HDBSCAN noise rescue pass. Only
            forwarded to ``backend='fast_hdbscan'``.
            (Default is ``True``)

    Returns:
        (np.ndarray):
            labels (np.ndarray):
                Cluster labels for each ROI, shape: *(n_rois_total)*

    Raises:
        ValueError: If ``backend`` is neither ``'fast_hdbscan'`` nor
            ``'legacy'``.
    """
    ## Record parameter (not data) arguments. Kept as the first statement
    ## so that locals() contains nothing but the call arguments.
    self.params['fit'] = self._locals_to_params(
        locals_dict=locals(),
        keys=[
            'min_cluster_size',
            'max_cluster_size',
            'min_samples',
            'n_iter_violationCorrection',
            'cluster_selection_method',
            'cluster_selection_persistence',
            'd_clusterMerge',
            'alpha',
            'split_intraSession_clusters',
            'discard_failed_pruning',
            'n_steps_clusterSplit',
            'backend',
            'algorithm',
            'rescue_noise',
        ],
    )

    ## Default merge distance: reuse the pruning cutoff when one exists.
    ## If it does not, None is passed on and the backend applies its own
    ## mean + 1*std heuristic.
    if d_clusterMerge is None:
        stored_cutoff = getattr(self, 'd_cutoff', None)
        if stored_cutoff is not None:
            d_clusterMerge = float(stored_cutoff)

    ## Arguments understood by both backends
    kwargs_shared = dict(
        d_conj=d_conj,
        session_bool=session_bool,
        min_cluster_size=min_cluster_size,
        max_cluster_size=max_cluster_size,
        min_samples=min_samples,
        cluster_selection_method=cluster_selection_method,
        d_clusterMerge=d_clusterMerge,
    )

    if backend == 'fast_hdbscan':
        return self._fit_fast_hdbscan(
            cluster_selection_persistence=cluster_selection_persistence,
            algorithm=algorithm,
            rescue_noise=rescue_noise,
            **kwargs_shared,
        )

    if backend == 'legacy':
        return self._fit_legacy_hdbscan(
            n_iter_violationCorrection=n_iter_violationCorrection,
            alpha=alpha,
            split_intraSession_clusters=split_intraSession_clusters,
            discard_failed_pruning=discard_failed_pruning,
            n_steps_clusterSplit=n_steps_clusterSplit,
            **kwargs_shared,
        )

    raise ValueError(
        f"backend must be 'fast_hdbscan' or 'legacy'. Got: {backend!r}"
    )
def _fit_fast_hdbscan(
    self,
    d_conj: scipy.sparse.csr_array,
    session_bool: np.ndarray,
    min_cluster_size: int = 2,
    max_cluster_size: Optional[int] = None,
    min_samples: Optional[int] = None,
    cluster_selection_method: str = 'leaf',
    cluster_selection_persistence: float = 0.0,
    d_clusterMerge: Optional[float] = None,
    algorithm: str = 'kruskal',
    rescue_noise: bool = True,
) -> np.ndarray:
    """
    Cluster with ``fast_hdbscan.HDBSCAN`` using session indices as
    cannot-link group labels, optionally followed by a noise rescue pass.

    Steps:

    1. Mask ``d_conj`` with ``self.s_sesh`` so only inter-session pairs
       remain. If no edges are left, every ROI is labeled ``-1``.
    2. Resolve defaults (``min_cluster_size='all'`` → ``n_sessions``,
       ``max_cluster_size=None`` → ``n_sessions``, ``d_clusterMerge=None``
       → mean + 1*std of the masked distances).
    3. Fit HDBSCAN on the masked matrix (``metric='precomputed'``) with
       one group label per ROI, so ROIs from the same session cannot
       co-cluster.
    4. Record which clusters still contain duplicate sessions in
       ``self.violations_labels``.
    5. If requested and noise exists, run ``self.rescue_noise``.
    6. Relabel to consecutive integers and demote clusters smaller than
       ``min_cluster_size`` to noise.

    Sets ``self.labels``; on the non-empty path also sets ``self.hdbs``,
    ``self.violations_labels`` and
    ``self._fit_used_fully_connected_node`` (always ``False`` here).

    RH 2025

    Args:
        d_conj (scipy.sparse.csr_array):
            Conjunctive distance matrix. Shape: *(n_rois, n_rois)*.
        session_bool (np.ndarray):
            Boolean array indicating which ROIs belong to which session.
            Shape: *(n_rois, n_sessions)*. Each row must contain
            exactly one ``True`` (asserted).
        min_cluster_size (int):
            Minimum cluster size, or ``'all'`` for ``n_sessions``.
            (Default is *2*)
        max_cluster_size (Optional[int]):
            Maximum cluster size. If ``None``, defaults to ``n_sessions``.
            (Default is ``None``)
        min_samples (Optional[int]):
            Passed through to ``fast_hdbscan.HDBSCAN``.
            (Default is ``None``)
        cluster_selection_method (str):
            ``'leaf'`` or ``'eom'``. (Default is ``'leaf'``)
        cluster_selection_persistence (float):
            Minimum cluster stability threshold. (Default is *0.0*)
        d_clusterMerge (Optional[float]):
            Distance threshold for merging clusters; also used as the
            edge cutoff of the noise rescue pass. If ``None``, computed
            as mean + 1*std of the distance data. (Default is ``None``)
        algorithm (str):
            MST construction algorithm. ``'kruskal'`` or ``'boruvka'``.
            (Default is ``'kruskal'``)
        rescue_noise (bool):
            If ``True``, run a post-HDBSCAN noise rescue pass.
            (Default is ``True``)

    Returns:
        (np.ndarray):
            labels (np.ndarray):
                Cluster labels for each ROI, shape: *(n_rois_total)*.
    """
    ## Keep inter-session pairs only (same masking as the legacy backend)
    d_masked = d_conj.copy()
    d_masked = d_masked.multiply(self.s_sesh)

    ## Nothing to cluster: everything is noise
    if d_masked.nnz == 0:
        if self._verbose:
            print('No edges in graph. Returning all -1 labels.')
        self.labels = np.full(d_masked.shape[0], -1, dtype=int)
        return self.labels

    n_sessions = session_bool.shape[1]

    if min_cluster_size == 'all':
        min_cluster_size = n_sessions
        if self._verbose:
            print(f'Setting min_cluster_size to {min_cluster_size} (all ROIs in a session)')

    ## Default cap: at most one ROI per session in a cluster
    if max_cluster_size is None:
        max_cluster_size = n_sessions

    ## Default merge distance from the statistics of the masked distances
    if d_clusterMerge is None:
        d_clusterMerge = float(np.mean(d_masked.data) + 1 * np.std(d_masked.data))
    else:
        d_clusterMerge = float(d_clusterMerge)

    ## One group label (session index) per ROI. Using argmax on the
    ## boolean table avoids any assumption that ROIs are stored in
    ## contiguous session blocks.
    sessions_per_roi = np.asarray(session_bool.sum(axis=1)).ravel()
    assert np.all(sessions_per_roi == 1), (
        "session_bool must contain exactly one True per ROI when using "
        "fast_hdbscan cannot_link_groups"
    )
    session_idx = np.asarray(np.argmax(session_bool, axis=1), dtype=np.int32)

    if self._verbose:
        print('Fitting with fast_hdbscan (group-label cannot-link constraints)')

    ## Imported lazily so the dependency is only needed for this backend
    import fast_hdbscan
    self.hdbs = fast_hdbscan.HDBSCAN(
        min_cluster_size=min_cluster_size,
        min_samples=min_samples,
        cluster_selection_epsilon=d_clusterMerge,
        cluster_selection_persistence=cluster_selection_persistence,
        max_cluster_size=max_cluster_size,
        metric='precomputed',
        algorithm=algorithm,
        cannot_link_groups=session_idx,
        cluster_selection_method=cluster_selection_method,
    )
    ## Unlike the legacy backend, no attach_fully_connected_node call is
    ## made before fitting.
    self.hdbs.fit(d_masked)
    labels = self.hdbs.labels_.copy()
    self._fit_used_fully_connected_node = False

    ## Report violations (expected to be zero with cannot-link groups).
    ## A cluster violates if it holds fewer distinct sessions than ROIs,
    ## i.e. at least one session appears twice.
    cluster_ids = np.unique(labels)
    cluster_ids = cluster_ids[cluster_ids > -1]
    has_duplicate_session = np.zeros(len(cluster_ids), dtype=bool)
    for i_cluster, cluster_id in enumerate(cluster_ids):
        member_sessions = session_idx[labels == cluster_id]
        has_duplicate_session[i_cluster] = np.unique(member_sessions).size < member_sessions.size
    violations_labels = cluster_ids[has_duplicate_session]
    self.violations_labels = violations_labels

    if self._verbose:
        n_violations = len(violations_labels)
        print(f'Session violations after fast_hdbscan: {n_violations} clusters, d_clusterMerge={d_clusterMerge:.2f}')

    ## Phase 2: noise rescue. Works on the same inter-session-masked
    ## distances and uses d_clusterMerge as the edge distance cutoff.
    if rescue_noise and np.any(labels == -1):
        labels = self.rescue_noise(
            d_conj=d_masked,
            labels=labels,
            session_bool=session_bool,
            d_cutoff=d_clusterMerge,
        )

    ## Post-processing: make labels consecutive, then demote clusters
    ## below min_cluster_size (singletons and small clusters nucleated by
    ## the rescue pass) to noise and re-squeeze.
    labels = helpers.squeeze_integers(labels)
    label_values, label_counts = np.unique(labels, return_counts=True)
    too_small = label_values[label_counts < min_cluster_size]
    labels[np.isin(labels, too_small)] = -1
    labels = helpers.squeeze_integers(labels)

    self.labels = labels
    return self.labels
[docs]
def rescue_noise(
    self,
    d_conj: scipy.sparse.csr_array,
    labels: np.ndarray,
    session_bool: np.ndarray,
    d_cutoff: float,
) -> np.ndarray:
    """
    Try to give noise ROIs (``label == -1``) a cluster, subject to the
    same-session cannot-link constraint.

    Thin wrapper around :func:`noise_rescue_kruskal` (Phase 2 of the
    two-phase clustering): derives one session index per ROI from
    ``session_bool``, delegates the actual work, and reports how many
    noise ROIs were rescued when ``self._verbose`` is set. See
    :func:`noise_rescue_kruskal` for algorithm details.

    This method itself does not write to ``self.labels`` or any other
    attribute of ``self``.

    Args:
        d_conj (scipy.sparse.csr_array):
            Inter-session masked distance matrix. Shape: *(n_rois, n_rois)*.
        labels (np.ndarray):
            Phase 1 cluster labels from HDBSCAN. Shape: *(n_rois,)*.
            ``-1`` = noise.
        session_bool (np.ndarray):
            Boolean array, shape *(n_rois, n_sessions)*. Each row has
            exactly one ``True``.
        d_cutoff (float):
            Maximum edge distance to accept for noise rescue.

    Returns:
        (np.ndarray):
            new_labels (np.ndarray):
                Updated cluster labels after noise rescue. Shape:
                *(n_rois,)*.
    """
    ## Session index of every ROI, used as its cannot-link group
    session_idx = np.asarray(np.argmax(session_bool, axis=1), dtype=np.int32)

    ## Count noise up front, before the labels are handed to the rescue
    ## routine
    n_noise_before = np.sum(labels == -1)

    new_labels = noise_rescue_kruskal(
        d_conj=d_conj,
        labels=labels,
        group_labels=session_idx,
        n_groups=session_bool.shape[1],
        d_cutoff=d_cutoff,
    )

    n_noise_after = np.sum(new_labels == -1)
    n_rescued = n_noise_before - n_noise_after
    if self._verbose:
        print(f'Noise rescue: {n_rescued}/{n_noise_before} noise ROIs rescued, {n_noise_after} remain noise')

    return new_labels
def _fit_legacy_hdbscan(
self,
d_conj: scipy.sparse.csr_array,
session_bool: np.ndarray,
min_cluster_size: int = 2,
max_cluster_size: Optional[int] = None,
min_samples: Optional[int] = None,
n_iter_violationCorrection: int = 5,
cluster_selection_method: str = 'leaf',
d_clusterMerge: Optional[float] = None,
alpha: float = 0.999,
split_intraSession_clusters: bool = True,
discard_failed_pruning: bool = True,
n_steps_clusterSplit: int = 100,
) -> np.ndarray:
"""
Legacy clustering using hdbscan==0.8.41 with iterative violation
correction. Preserved for backward compatibility.
Fits clustering using a modified HDBSCAN clustering algorithm.
The approach is to use HDBSCAN but avoid having clusters with multiple
ROIs from the same session. This is achieved by repeating three
steps: \n
1. Fit HDBSCAN to the data.
2. Identify clusters that have multiple ROIs from the same session
and walk back down the dendrogram until those clusters are split up
into non-violating clusters.
3. Disconnect graph edges between ROIs within each new cluster and all
other ROIs outside the cluster that are from the same session. \n
RH 2023
Args:
d_conj (scipy.sparse.csr_array):
Conjunctive distance matrix.
session_bool (np.ndarray):
Boolean array indicating which ROIs belong to which session.
Shape: *(n_rois, n_sessions)*
min_cluster_size (int):
Minimum cluster size to be considered a cluster. Can be 'all'.
(Default is *2*)
max_cluster_size (Optional[int]):
Maximum cluster size. If ``None``, defaults to ``n_sessions``.
(Default is ``None``)
min_samples (Optional[int]):
Number of neighbors for core-point density estimation. If
``None``, defaults to ``min_cluster_size``.
(Default is ``None``)
n_iter_violationCorrection (int):
Number of iterations to correct for clusters with multiple ROIs
per session. (Default is *5*)
cluster_selection_method (str):
Cluster selection method. Either ``'leaf'`` or ``'eom'``.
(Default is ``'leaf'``)
d_clusterMerge (Optional[float]):
Distance threshold for merging clusters. If ``None``, the
distance is calculated as the mean + 1*std of the conjunctive
distances. (Default is ``None``)
alpha (float):
Alpha value. Smaller values result in more clusters. (Default is
*0.999*)
split_intraSession_clusters (bool):
If ``True``, clusters containing ROIs from multiple sessions
will be split. (Default is ``True``)
discard_failed_pruning (bool):
If ``True``, clusters failing to prune are set to -1. (Default
is ``True``)
n_steps_clusterSplit (int):
Number of steps for splitting clusters. (Default is *100*)
Returns:
(np.ndarray):
labels (np.ndarray):
Cluster labels for each ROI, shape: *(n_rois_total)*
"""
import hdbscan
d = d_conj.copy().multiply(self.s_sesh)
if d.nnz == 0:
print('No edges in graph. Returning all -1 labels.') if self._verbose else None
self.labels = np.ones(d.shape[0], dtype=int) * -1
return self.labels
n_sessions = session_bool.shape[1]
if min_cluster_size == 'all':
min_cluster_size = n_sessions
print(f'Setting min_cluster_size to {min_cluster_size} (all ROIs in a session)') if self._verbose else None
## Resolve max_cluster_size default: one ROI per session
if max_cluster_size is None:
max_cluster_size = n_sessions
print('Fitting with HDBSCAN and splitting clusters with multiple ROIs per session') if self._verbose else None
for ii in tqdm(range(n_iter_violationCorrection)):
## Prep parameters for splitting clusters
d_clusterMerge = float(np.mean(d.data) + 1*np.std(d.data)) if d_clusterMerge is None else float(d_clusterMerge)
n_steps_clusterSplit = int(n_steps_clusterSplit)
max_dist=(d.max() - d.min()) * 1000
self.hdbs = hdbscan.HDBSCAN(
min_cluster_size=min_cluster_size,
min_samples=min_samples,
cluster_selection_epsilon=d_clusterMerge,
max_cluster_size=max_cluster_size,
metric='precomputed',
alpha=alpha,
algorithm='generic',
cluster_selection_method=cluster_selection_method,
max_dist=max_dist,
)
self.hdbs.fit(attach_fully_connected_node(
d,
dist_fullyConnectedNode=max_dist,
n_nodes=1,
))
labels = self.hdbs.labels_[:-1]
self.labels = labels
self._fit_used_fully_connected_node = True
print(f'Initial number of violating clusters: {len(np.unique(labels)[np.array([(session_bool[labels==u].sum(0)>1).sum().item() for u in np.unique(labels)]) > 0])}, d_clusterMerge={d_clusterMerge:.2f}') if self._verbose else None
## Split up labels with multiple ROIs per session
## The below code is a bit of a mess, but it works.
## It works by iteratively reducing the cutoff distance
## and splitting up violating clusters until there are
## no more violations.
if split_intraSession_clusters:
labels = labels.copy()
sb_t = torch.as_tensor(session_bool, dtype=torch.float32) ## (n_rois, n_sessions)
n = len(d_conj.data)
dcd = np.sort(d_conj.data)
cuts_all = np.sort(np.unique([dcd[0]/2] + [dcd[int(n*ii)] for ii in np.linspace(0., 1., num=n_steps_clusterSplit, endpoint=False)[::-1]] + [dcd[-1]]))[::-1]
for d_cut in cuts_all:
labels_t = torch.as_tensor(labels, dtype=torch.int64)
lab_u_t, lab_u_idx_t = torch.unique(labels_t, return_inverse=True) # (n_clusters,), (n_rois,)
lab_oneHot_t = helpers.idx_to_oneHot(lab_u_idx_t, dtype=torch.float32)
violations_labels = lab_u_t[((sb_t.T @ lab_oneHot_t) > 1.5).sum(0) > 0]
violations_labels = violations_labels[violations_labels > -1]
if len(violations_labels) == 0:
break
for l in violations_labels:
idx = np.where(labels==l)[0]
if d[idx][:,idx].nnz == 0:
labels[idx] = -1
labels_new = self.hdbs.single_linkage_tree_.get_clusters(
cut_distance=d_cut,
min_cluster_size=min_cluster_size,
)[:-1]
idx_toUpdate = np.isin(labels, violations_labels)
labels[idx_toUpdate] = labels_new[idx_toUpdate] + labels.max() + 5
labels[(labels_new == -1) * idx_toUpdate] = -1
if discard_failed_pruning:
labels[idx_toUpdate] = -1
if ii < n_iter_violationCorrection - 1:
## Find sessions represented in each cluster and set distances to ROIs in those sessions to 1.
d = d.tocsr()
for ii, l in enumerate(np.unique(labels)):
if l == -1:
continue
idx = np.where(labels==l)[0]
d_sub = d[idx][:,idx]
idx_grid = np.meshgrid(idx, idx)
## set distances of ROIs from same session to 0
sesh_to_exclude = 1 - (session_bool @ (session_bool[idx].max(0))) ## make a mask of sessions that are not represented in the cluster
d[idx,:] = d[idx,:].multiply(sesh_to_exclude[None,:]) ## set distances to ROIs from sessions represented in the cluster to 1
d[:,idx] = d[:,idx].multiply(sesh_to_exclude[:,None]) ## set distances to ROIs from sessions represented in the cluster to 1
d[idx_grid[0], idx_grid[1]] = d_sub ## undo the above for ROIs in the cluster
d = d.tocsr()
d.eliminate_zeros() ## remove zeros
labels = helpers.squeeze_integers(labels)
violations_labels = np.unique(labels)[np.array([(session_bool[labels==u].sum(0)>1).sum().item() for u in np.unique(labels)]) > 0]
violations_labels = violations_labels[violations_labels > -1]
self.violations_labels = violations_labels
## Set clusters with too few ROIs to -1
u, c = np.unique(labels, return_counts=True)
labels[np.isin(labels, u[c<2])] = -1
labels = helpers.squeeze_integers(labels)
self.labels = labels
return self.labels
[docs]
def fit_sequentialHungarian(
    self,
    d_conj: scipy.sparse.csr_array,
    session_bool: np.ndarray,
    thresh_cost: float = 0.95,
) -> np.ndarray:
    """
    Clusters ROIs by sequentially matching sessions with the Hungarian
    algorithm (CaImAn's method).

    Sessions are processed in order. The first session seeds a running
    'union' of ROIs. For every following session, a dense cost matrix
    between that session's ROIs (rows) and the current union (columns) is
    built from ``d_conj`` (pairs with no stored edge get a cost of 1), an
    optimal assignment is found with
    ``scipy.optimize.linear_sum_assignment``, and assigned pairs with a
    cost below ``thresh_cost`` are accepted as matches. Matched union
    entries are replaced by the newly matched ROI; unmatched ROIs of the
    session are appended to the union.

    For further details, please refer to:
        * [CaImAn's paper](https://elifesciences.org/articles/38173#s4)
        * [CaImAn's repository](https://github.com/flatironinstitute/CaImAn)
        * [Relevant script in CaImAn's repository](https://github.com/flatironinstitute/CaImAn/blob/master/caiman/base/rois.py)

    Side effects: stores ``thresh_cost`` in
    ``self.params['fit_sequentialHungarian']``, sets ``self.labels``, and
    sets ``self.seqHung_performance`` to the match statistics of the last
    processed session (left untouched when there is only one session).

    Args:
        d_conj (scipy.sparse.csr_array):
            Distance matrix.
            Shape: *(n_rois, n_rois)*
        session_bool (np.ndarray):
            Boolean array indicating which ROIs are in which sessions.
            Shape: *(n_rois, n_sessions)*. Session ``i`` is indexed as the
            contiguous block of rows following the ROIs of sessions
            ``0..i-1``, so ROIs must be ordered by session.
        thresh_cost (float):
            Threshold below which ROI pairs are considered potential
            matches. (Default is *0.95*)

    Returns:
        (np.ndarray):
            labels (np.ndarray):
                Cluster labels. Shape: *(n_rois,)*. ROIs that end up alone
                in their cluster are labeled -1.

    Raises:
        Exception: If a cost matrix contains NaN.
    """
    ## Store parameter (but not data) args as attributes
    self.params['fit_sequentialHungarian'] = self._locals_to_params(
        locals_dict=locals(),
        keys=['thresh_cost',],)

    if self._verbose:
        print(f"Clustering with CaImAn's sequential Hungarian algorithm method...")

    def find_matches(D_s):
        """Solve the assignment problem for each cost matrix in ``D_s``."""
        matches = []
        costs = []
        for D in D_s:
            ## Check the values themselves. Summing the indices returned by
            ## np.where would miss a NaN located only at position (0, 0).
            if np.isnan(D).any():
                raise Exception('Distance Matrix contains invalid value NaN')
            rows, cols = scipy.optimize.linear_sum_assignment(D)
            matches.append((rows, cols))
            costs.append(D[rows, cols].tolist())
        return matches, costs

    def ratio(numerator, denominator):
        """Division that yields NaN instead of raising on an empty session/union."""
        return numerator / denominator if denominator > 0 else float('nan')

    n_roi = session_bool.sum(0)  ## number of ROIs per session
    n_roi_cum = np.concatenate(([0], np.cumsum(n_roi)))

    ## matchings[i][k] is the union index (= provisional label) of ROI k in session i
    matchings = [list(range(n_roi[0]))]
    idx_union = np.arange(n_roi[0])  ## global ROI index currently representing each union entry

    for i_sesh in tqdm(range(1, len(n_roi))):
        idx_sess = np.arange(n_roi_cum[i_sesh], n_roi_cum[i_sesh+1])

        ## Dense cost matrix: stored distances where an edge exists, 1 elsewhere
        d_sub = d_conj[idx_sess][:, idx_union].toarray()
        D = np.where(d_sub != 0, d_sub, 1.0).astype(np.float64)

        matches, costs = find_matches([D])
        rows, cols = matches[0]
        is_match = np.asarray(costs[0]) < thresh_cost

        mat_sess = rows[is_match]  ## matched ROIs, indexed within the session
        mat_un = cols[is_match]  ## their partners, indexed within the union
        nm_sess = np.setdiff1d(np.arange(D.shape[0]), mat_sess)  ## unmatched session ROIs

        ## Match statistics, with the session as 'ground truth' and the union as 'prediction'
        TP = float(is_match.sum())
        FN = D.shape[0] - TP
        FP = D.shape[1] - TP
        TN = 0
        performance = {
            'recall': ratio(TP, TP + FN),
            'precision': ratio(TP, TP + FP),
            'accuracy': ratio(TP + TN, TP + FP + FN + TN),
            'f1_score': ratio(2 * TP, 2 * TP + FP + FN),
        }

        ## Update the union: matched entries now point at the new ROI,
        ## unmatched ROIs are appended as new entries.
        n_union_prev = len(idx_union)
        idx_union[mat_un] = idx_sess[mat_sess]
        idx_union = np.concatenate((idx_union, idx_sess[nm_sess]))

        new_match = np.zeros(n_roi[i_sesh], dtype=int)
        new_match[mat_sess] = mat_un
        new_match[nm_sess] = np.arange(n_union_prev, len(idx_union))
        matchings.append(new_match.tolist())

        ## Assigned inside the loop so that a single-session input does not
        ## hit an unbound ``performance``.
        self.seqHung_performance = performance

    ## Explicit int dtype keeps labels integral even if a session is empty
    labels = np.concatenate([np.asarray(m, dtype=int) for m in matchings])

    ## Clusters with a single member are not matches: label them as noise (-1)
    u, c = np.unique(labels, return_counts=True)
    labels[np.isin(labels, u[c == 1])] = -1
    labels = helpers.squeeze_integers(labels)

    self.labels = labels
    return self.labels
[docs]
def make_conjunctive_distance_matrix(
    self,
    similarities: Dict[str, scipy.sparse.csr_array],
    mixing_params: Dict[str, Any],
) -> Tuple[scipy.sparse.csr_array, scipy.sparse.csr_array, Dict[str, torch.Tensor]]:
    """
    Combines several similarity matrices into one conjunctive similarity
    matrix and its complementary distance matrix.

    Each metric's stored values are passed through
    ``self._activation_function`` (with that metric's sigmoid kwargs and
    power), and the activated values are merged element-wise with
    ``self._pNorm``. The result is placed into a copy of the first
    similarity matrix, so all matrices are expected to share its sparsity
    pattern.
    RH 2023 / 2025

    Args:
        similarities (Dict[str, scipy.sparse.csr_array]):
            Dict mapping metric name to sparse similarity matrix.
        mixing_params (Dict[str, Any]):
            Mixing parameters dict. Recognized keys: \n
                * ``power_<name>`` (float): Power for each metric.
                  Defaults to 1.0 if not present.
                * ``sig_<name>_kwargs`` (Dict[str, float]): Sigmoid
                  parameters per metric. ``None`` or absent means no
                  sigmoid.
                * ``p_norm`` (float): p-norm exponent for combining
                  activated similarities. Defaults to 1.0 if not present;
                  a value of 0 is replaced by 1e-9.

    Returns:
        (Tuple): Tuple containing:
            dConj (scipy.sparse.csr_array):
                Conjunctive distance matrix (1 - sConj on stored entries).
            sConj (scipy.sparse.csr_array):
                Conjunctive similarity matrix.
            activated_data (Dict[str, torch.Tensor]):
                Per-metric activated similarity data arrays.
    """
    assert len(similarities) > 0, 'At least one similarity matrix must be provided.'

    ## Record the parameter (not data) arguments on the object
    self.params['make_conjunctive_distance_matrix'] = self._locals_to_params(
        locals_dict=locals(),
        keys=['mixing_params'],
    )

    ## An exponent of exactly 0 would make 1/p undefined
    p_norm = mixing_params.get('p_norm', 1.0)
    if p_norm == 0:
        p_norm = 1e-9

    ## Activate the stored values (1D, length nnz) of every metric
    activated_data = {}
    for name, sim in similarities.items():
        activated_data[name] = self._activation_function(
            sim.data,
            mixing_params.get(f'sig_{name}_kwargs', None),
            mixing_params.get(f'power_{name}', 1.0),
        )

    ## Merge the activated metrics into a single value per stored entry
    merged = self._pNorm(s_list=list(activated_data.values()), p=p_norm)

    ## Re-use the sparsity pattern of the first similarity matrix
    sConj = next(iter(similarities.values())).copy()
    sConj.data = merged.numpy()

    ## Distance is the complement of similarity
    dConj = sConj.copy()
    dConj.data = 1 - dConj.data

    return dConj, sConj, activated_data
def _activation_function(
self,
s: Optional[torch.Tensor] = None,
sig_kwargs: Optional[Dict[str, float]] = {'mu':0.0, 'b':1.0},
power: Optional[float] = 1
) -> Optional[torch.Tensor]:
"""
Applies an activation function to a similarity matrix.
Args:
s (Optional[torch.Tensor]):
The input similarity matrix. If ``None``, the function returns
``None``. (Default is ``None``)
sig_kwargs (Dict[str, float]):
Keyword arguments for the sigmoid function applied to the
similarity matrix. See helpers.generalised_logistic_function for
details. (Default is {'mu':0.0, 'b':1.0})
power (Optional[float]):
Power to which to raise the similarity. If ``None``, the power
operation is not applied. (Default is *1*)
Returns:
(Optional[torch.Tensor]):
Activated similarity matrix. Returns ``None`` if the input
similarity matrix is ``None``.
"""
if s is None:
return None
s = torch.as_tensor(s, dtype=torch.float32)
## make functions such that if the param is None, then no operation is applied
fn_sigmoid = lambda x, params: helpers.generalised_logistic_function(x, **params) if params is not None else x
fn_power = lambda x, p: x ** p if p is not None else x
return fn_power(torch.clamp(fn_sigmoid(s, sig_kwargs), min=0), power)
def _pNorm(
self,
s_list: List[Optional[torch.Tensor]],
p: float
) -> torch.Tensor:
"""
Calculate the p-norm of a list of similarity matrices.
Args:
s_list (List[Optional[torch.Tensor]]):
List of similarity matrices.
p (float):
p-norm to use.
Returns:
(torch.Tensor):
p-norm of the list of similarity matrices.
"""
s_list_noNones = [s for s in s_list if s is not None]
return (torch.mean(torch.stack(s_list_noNones, axis=0)**p, dim=0))**(1/p)
[docs]
def plot_similarity_relationships(
    self,
    max_samples: int = 1000000,
    kwargs_scatter: Dict[str, Union[int, float]] = {'s': 1, 'alpha': 0.1},
    mixing_params: Optional[Dict[str, Any]] = None,
) -> Tuple[plt.figure, plt.axes]:
    """
    Scatter-plots the activated similarities of every pair of metrics
    against each other, one subplot per pair, with points colored by their
    conjunctive distance.

    Args:
        max_samples (int):
            Maximum number of samples to plot. Samples are drawn at random
            (with replacement) from the stored pairs.
        kwargs_scatter (Dict[str, Union[int, float]]):
            Keyword arguments for ``matplotlib.pyplot.scatter``.
        mixing_params (Optional[Dict[str, Any]]):
            Mixing parameters for ``make_conjunctive_distance_matrix``.
            If ``None``, uses ``self.best_params`` if available, else
            ``{'p_norm': -4.0}``.

    Returns:
        (Tuple[matplotlib.pyplot.figure, matplotlib.pyplot.axes]):
            fig, axs: Figure and axes objects. ``axs`` is a one-element
            list when there is a single subplot.
    """
    import itertools

    if mixing_params is None:
        mixing_params = getattr(self, 'best_params', {'p_norm': -4.0})

    dConj, sConj, activated_data = self.make_conjunctive_distance_matrix(
        similarities=self.similarities,
        mixing_params=mixing_params,
    )

    ## All unordered pairs of metrics; keep at least one (empty) subplot
    pairs = list(itertools.combinations(list(self.similarities.keys()), 2))
    n_pairs = max(len(pairs), 1)

    ## Random subsample of stored entries, shared by all subplots
    n_stored = len(dConj.data)
    n_draw = min(max_samples, n_stored)
    idx_rand = np.floor(np.random.rand(n_draw) * n_stored).astype(int)
    d_conj_sub = dConj.data[idx_rand]

    fig, axs = plt.subplots(nrows=1, ncols=n_pairs, figsize=(7 * n_pairs, 4))
    if n_pairs == 1:
        ## plt.subplots returns a bare Axes for a single subplot
        axs = [axs]
    fig.suptitle('Similarity relationships', fontsize=16)

    for ax, (name_x, name_y) in zip(axs, pairs):
        ax.scatter(
            activated_data[name_x][idx_rand],
            activated_data[name_y][idx_rand],
            c=d_conj_sub,
            **kwargs_scatter,
        )
        ax.set_xlabel(f'sim {name_x}')
        ax.set_ylabel(f'sim {name_y}')

    return fig, axs
[docs]
def plot_distSame(self, mixing_params: Optional[dict] = None) -> Optional[plt.Figure]:
    """
    Plot the estimated distributions of pairwise distances for 'same' and
    'different' ROI pairs, together with the crossover distance.

    Args:
        mixing_params (Optional[dict]):
            Mixing parameters for ``make_conjunctive_distance_matrix``.
            If ``None``, the function uses ``self.best_params``.
            (Default is ``None``)

    Returns:
        (Optional[matplotlib.figure.Figure]):
            fig:
                The created figure, or ``None`` if the distributions
                could not be separated (nothing is plotted in that case).
    """
    params = self.best_params if mixing_params is None else mixing_params
    dConj, sConj, activated_data = self.make_conjunctive_distance_matrix(
        similarities=self.similarities,
        mixing_params=params,
    )
    dens_same_crop, dens_same, dens_diff, dens_all, edges, d_crossover = self._separate_diffSame_distributions(dConj)
    if edges is None:
        print('No crossover found, not plotting')
        return None

    fig = plt.figure()
    ## Artists are drawn in the same order as the legend entries below
    plt.stairs(dens_same, edges, linewidth=5)
    plt.stairs(dens_same_crop, edges, linewidth=3)
    plt.stairs(dens_diff, edges)
    plt.stairs(dens_all, edges)
    plt.axvline(d_crossover, color='k', linestyle='--')
    ## Scale the y-axis to the 'same' distribution; 'diff'/'all' may be clipped
    plt.ylim([dens_same.max()*-0.5, dens_same.max()*1.5])
    plt.title('Pairwise similarities')
    plt.xlabel('distance or prob(different)')
    plt.ylabel('counts or density')
    plt.legend(['same', 'same (cropped)', 'diff', 'all', 'crossover'])
    return fig
def _fn_smooth(
    self,
    x: torch.Tensor,
) -> torch.Tensor:
    """
    Smooths a 1D tensor with a boxcar kernel of ``self.smooth_window``
    ones, using ``helpers.Convolver_1d``.

    Args:
        x (torch.Tensor):
            1D tensor to be smoothed. The convolver is built for inputs of
            length ``self.n_bins``.

    Returns:
        (torch.Tensor):
            Smoothed tensor.
    """
    boxcar = torch.ones(self.smooth_window)
    ## A new convolver is built on every call
    smoother = helpers.Convolver_1d(
        kernel=boxcar,
        length_x=self.n_bins,
        pad_mode='same',
        correct_edge_effects=True,
        device='cpu',
    )
    return smoother.convolve(x)
####################################################################
## Shared histogram overlap computation
####################################################################
def _compute_histogram_overlap(
self,
distances: torch.Tensor,
intra_indices: torch.Tensor,
edges: torch.Tensor,
smoother: 'helpers.Convolver_1d',
scale_factor: float,
) -> Tuple[float, torch.Tensor, torch.Tensor]:
"""
Compute the hard histogram overlap loss between the estimated
'same' and 'different' distance distributions.
Shared core used by :meth:`_separate_diffSame_distributions` and
:meth:`_find_optimal_parameters_DE` (scalar objective).
The 'different' distribution is estimated by scaling the intra-session
(known-different) histogram. The 'same' distribution is the residual
after subtracting the scaled intra counts from all counts, clamped
non-negative and then smoothed. The loss is the dot product of the
two estimated distributions (overlap area).
If no valid crossover point exists (i.e. the two distributions never
separate), ``loss`` is set to ``1e6`` as a penalty.
RH 2025
Args:
distances (torch.Tensor):
1D float tensor of distance values, shape ``(n,)``.
intra_indices (torch.Tensor):
1D int64 tensor of indices into ``distances`` corresponding
to intra-session (known-different) pairs.
edges (torch.Tensor):
1D float tensor of bin edges, shape ``(n_bins + 1,)``.
smoother (helpers.Convolver_1d):
Pre-built 1D convolver for smoothing the 'same' distribution.
scale_factor (float):
``n_all / n_intra`` — multiplier that scales the intra counts
up to the full population size.
Returns:
(Tuple[float, torch.Tensor, torch.Tensor]):
loss (float):
Overlap area (dot product of dens_same and dens_diff),
or ``1e6`` if no valid crossover exists.
dens_same (torch.Tensor):
Smoothed estimated 'same' distribution, shape ``(n_bins,)``.
dens_diff (torch.Tensor):
Scaled intra-session distribution (estimated 'different'),
shape ``(n_bins,)``.
"""
## Histogram all distances and intra-session distances
counts_all, _ = torch.histogram(distances, edges)
counts_intra, _ = torch.histogram(distances[intra_indices], edges)
## Scale intra counts to estimate the full 'different' distribution
dens_diff = counts_intra * scale_factor
## 'Same' distribution = residual, clamped non-negative, then smoothed
dens_same = smoother.convolve(torch.clamp(counts_all - dens_diff, min=0))
## Penalize if no valid crossover exists between the two distributions
dens_deriv = dens_diff - dens_same
dens_deriv[int(dens_diff.argmax().item()):] = 0
if not (dens_deriv < 0).any():
return 1e6, dens_same, dens_diff
return float((dens_same * dens_diff).sum().item()), dens_same, dens_diff
def _separate_diffSame_distributions(
self,
d_conj: scipy.sparse.csr_array,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, float]:
"""
Estimate the 'same' and 'different' distance distributions from a
conjunctive distance matrix.
The 'same' distribution is estimated as the residual after subtracting
the scaled intra-session (known-different) distribution from the overall
distribution. Delegates core histogram overlap computation to
:meth:`_compute_histogram_overlap`.
Args:
d_conj (scipy.sparse.csr_array):
Conjunctive distance matrix.
Returns:
(tuple): tuple containing:
dens_same_crop (np.ndarray):
'Same' distribution with values below the crossover point
zeroed out.
dens_same (np.ndarray):
Un-cropped smoothed 'same' distribution.
dens_diff (np.ndarray):
Scaled intra-session 'different' distribution.
dens_all (np.ndarray):
Raw histogram counts over all pairs.
edges (np.ndarray):
Bin edges used for all histograms.
d_crossover (float):
Distance at which the 'same' and 'different' distributions
cross over.
"""
## Bin edges covering the full [0, 1] distance range
edges = torch.linspace(start=0, end=1, steps=self.n_bins + 1, dtype=torch.float32)
## Extract all-pairs distances and compute raw histogram
dist_all = torch.as_tensor(d_conj.data, dtype=torch.float32)
dens_all, _ = torch.histogram(dist_all, edges)
## Extract intra-session (known-different) distances via s_sesh_inv mask
d_intra = d_conj.multiply(self.s_sesh_inv)
d_intra.eliminate_zeros()
if len(d_intra.data) == 0:
return None, None, None, None, None, None
dist_intra = torch.as_tensor(d_intra.data, dtype=torch.float32)
## Scale factor: ratio of all pairs to intra-session pairs
n_all = int(dens_all.sum().item())
n_intra = dist_intra.shape[0]
scale_factor = float(n_all) / max(n_intra, 1)
## Compute scaled intra (different) distribution
counts_intra, _ = torch.histogram(dist_intra, edges)
dens_diff = counts_intra * scale_factor
## Smoothed residual (same) — uses cached smoother in _fn_smooth
dens_same = self._fn_smooth(torch.clamp(dens_all - dens_diff, min=0))
## Locate crossover: last bin before dens_diff peak where diff > same
dens_deriv = dens_diff - dens_same
dens_deriv[int(dens_diff.argmax().item()):] = 0
crossover_candidates = torch.where(dens_deriv < 0)[0]
if crossover_candidates.shape[0] == 0:
return None, None, None, None, None, None
idx_crossover = int(crossover_candidates[-1].item()) + 1
d_crossover = float(edges[idx_crossover].item())
## Zero out 'same' distribution below the crossover point
dens_same_crop = dens_same.clone()
dens_same_crop[idx_crossover:] = 0
return dens_same_crop, dens_same, dens_diff, dens_all, edges, d_crossover
def _extract_hdbscan_quality_metrics(self) -> Dict:
"""
Extract HDBSCAN quality metrics, handling differences between legacy
hdbscan and fast_hdbscan backends.
Legacy hdbscan (with attach_fully_connected_node) stores an extra
trailing element for the fully connected node that must be stripped.
fast_hdbscan does not add extra nodes but exposes additional
attributes: ``_core_distances`` (per-point core distance) and
``_min_spanning_tree`` (MST edge array).
Returns:
(Dict):
Dictionary with keys: \n
* ``sample_probabilities``: list of floats, per-ROI membership
strength.
* ``sample_outlierScores``: list of floats (legacy) or ``None``
(fast_hdbscan).
* ``sample_coreDistances``: list of floats (fast_hdbscan) or
``None`` (legacy). Per-point core distance used for mutual
reachability.
* ``mst_edge_weights``: list of floats (fast_hdbscan) or
``None`` (legacy). Sorted MST edge weights; useful for
diagnosing cluster separation.
"""
def to_list_of_floats(x):
return [float(i) for i in x] if x is not None else None
used_fcn = getattr(self, '_fit_used_fully_connected_node', True)
## Probabilities
probs = self.hdbs.probabilities_
if used_fcn:
probs = probs[:-1]
result = {
'sample_probabilities': to_list_of_floats(probs),
}
## Outlier scores (only available in legacy hdbscan, not fast_hdbscan)
if hasattr(self.hdbs, 'outlier_scores_'):
scores = self.hdbs.outlier_scores_
if used_fcn:
scores = scores[:-1]
result['sample_outlierScores'] = to_list_of_floats(scores)
else:
result['sample_outlierScores'] = None
## Core distances (fast_hdbscan only)
core_dists = getattr(self.hdbs, '_core_distances', None)
if core_dists is not None:
if used_fcn:
core_dists = core_dists[:-1]
result['sample_coreDistances'] = to_list_of_floats(core_dists)
else:
result['sample_coreDistances'] = None
## MST edge weights (fast_hdbscan only)
mst = getattr(self.hdbs, '_min_spanning_tree', None)
if mst is not None:
## MST is (n-1, 3) array with columns [src, dst, weight]
result['mst_edge_weights'] = to_list_of_floats(np.sort(mst[:, 2]))
else:
result['mst_edge_weights'] = None
return result
[docs]
def compute_quality_metrics(
    self,
    sim_mat: Optional[object] = None,
    dist_mat: Optional[object] = None,
    labels: Optional[np.ndarray] = None,
) -> Dict:
    """
    Computes quality metrics of the clustering and stores them in
    ``self.quality_metrics``.
    RH 2023

    Args:
        sim_mat (Optional[object]):
            Sparse similarity matrix of shape *(n_samples, n_samples)*.
            If ``None`` then self.sConj must exist. (Default is ``None``)
        dist_mat (Optional[object]):
            Sparse distance matrix of shape *(n_samples, n_samples)*.
            If ``None`` then self.dConj must exist. (Default is ``None``)
        labels (Optional[np.ndarray]):
            Cluster labels of shape *(n_samples,)*.
            If ``None``, then self.labels must exist. (Default is ``None``)

    Returns:
        (Dict):
            quality_metrics (Dict):
                Quality metrics dictionary with keys:
                'cluster_labels_unique', 'cluster_intra_means',
                'cluster_intra_mins', 'cluster_intra_maxs',
                'cluster_silhouette', 'sample_silhouette' (``None`` when
                fewer than 2 unique labels exist),
                'sample_probabilities' and 'hdbscan' (``None`` unless
                ``self.hdbs`` exists), and 'sequentialHungarian'
                (``None`` unless ``self.seqHung_performance`` exists).
    """
    ## Import the submodule explicitly: a bare ``import sklearn`` does not
    ## guarantee that ``sklearn.metrics`` is loaded.
    import sklearn.metrics

    if sim_mat is None:
        assert hasattr(self, 'sConj'), "self.sConj does not exist. Run self.find_optimal_parameters_for_pruning() first or specify sim_mat."
        sim_mat = self.sConj
    if dist_mat is None:
        assert hasattr(self, 'dConj'), "self.dConj does not exist. Run self.find_optimal_parameters_for_pruning() first or specify dist_mat."
        dist_mat = self.dConj
    if labels is None:
        ## Labels are produced by the fit methods, not by the parameter search
        assert hasattr(self, 'labels'), "self.labels does not exist. Run self.fit() or self.fit_sequentialHungarian() first or specify labels."
        labels = self.labels

    assert scipy.sparse.issparse(sim_mat), "sim_mat must be a scipy.sparse.csr_array."
    assert scipy.sparse.issparse(dist_mat), "dist_mat must be a scipy.sparse.csr_array."

    labels_unique, cs_intra_means, cs_intra_mins, cs_intra_maxs, cs_sil = cluster_quality_metrics(
        sim=sim_mat,
        labels=labels,
    )

    ## Silhouette requires at least 2 labels. Check first so the dense
    ## (n_samples, n_samples) matrix is only built when it is needed.
    n_labels = len(np.unique(labels))
    if n_labels < 2:
        warnings.warn(f"Silhouette samples calculation requires at least 2 labels. Returning None. Found {n_labels} labels.")
        rs_sil = None
    else:
        import sparse
        ## Densify in float16; unstored pairs are filled with a distance
        ## 10x the range of the stored distances.
        d_dense = sparse.COO(dist_mat.copy().tocsr()).astype(np.float16)
        d_dense.fill_value = (dist_mat.data.max() - dist_mat.data.min()).astype(np.float16) * 10
        d_dense = d_dense.todense()
        np.fill_diagonal(d_dense, 0)
        rs_sil = sklearn.metrics.silhouette_samples(X=d_dense, labels=labels, metric='precomputed')

    def to_list_of_floats(x):
        return [float(i) for i in x] if x is not None else None

    ## Extract HDBSCAN-specific metrics
    hdbscan_metrics = self._extract_hdbscan_quality_metrics() if hasattr(self, 'hdbs') else None

    ## Match statistics are only available after fit_sequentialHungarian
    if hasattr(self, 'seqHung_performance'):
        seq_hung_metrics = {
            'performance_recall': float(self.seqHung_performance['recall']),
            'performance_precision': float(self.seqHung_performance['precision']),
            'performance_f1': float(self.seqHung_performance['f1_score']),
            'performance_accuracy': float(self.seqHung_performance['accuracy']),
        }
    else:
        seq_hung_metrics = None

    self.quality_metrics = util.JSON_Dict({
        'cluster_labels_unique': to_list_of_floats(labels_unique),
        'cluster_intra_means': to_list_of_floats(cs_intra_means),
        'cluster_intra_mins': to_list_of_floats(cs_intra_mins),
        'cluster_intra_maxs': to_list_of_floats(cs_intra_maxs),
        'cluster_silhouette': to_list_of_floats(cs_sil),
        'sample_silhouette': to_list_of_floats(rs_sil),
        'sample_probabilities': hdbscan_metrics['sample_probabilities'] if hdbscan_metrics else None,
        'hdbscan': hdbscan_metrics,
        'sequentialHungarian': seq_hung_metrics,
    })
    return self.quality_metrics
####################################################################
####################################################################
##
## LEGACY / COMPARISON METHODS
##
## The methods below are retained for backward compatibility and
## benchmarking. They are NOT used by the default pipeline.
##
## The recommended approach is find_optimal_parameters_for_pruning()
## which calls _find_optimal_parameters_DE(freeze_sigmoid=True).
##
####################################################################
####################################################################
def _find_optimal_parameters_for_pruning_optuna(
    self,
    kwargs_findParameters: Dict[str, Union[int, float, bool]] = {
        'n_patience': 100,
        'tol_frac': 0.05,
        'max_trials': 350,
        'max_duration': 60*10,
        'value_stop': 0.0,
    },
    bounds_findParameters: Dict[str, List[float]] = {
        'power_nn': [0.0, 2.],
        'power_swt': [0.0, 2.],
        'p_norm': [-5, -0.1],
        'sig_nn_kwargs_mu': [0., 1.0],
        'sig_nn_kwargs_b': [0.1, 1.5],
        'sig_swt_kwargs_mu': [0., 1.0],
        'sig_swt_kwargs_b': [0.1, 1.5],
    },
    n_jobs_findParameters: int = -1,
    n_bins: Optional[int] = None,
    smoothing_window_bins: Optional[int] = None,
    seed=None,
) -> Dict:
    """
    **LEGACY** — Optuna TPE search over the 7 mixing parameters. Superseded
    by :meth:`find_optimal_parameters_for_pruning`.
    Requires ``optuna`` to be installed.

    Runs an Optuna study that minimizes
    :meth:`_objectiveFn_distSameMagnitude`, then stores the result in
    ``self.best_params`` (with the flat ``sig_*_kwargs_mu`` / ``_b``
    entries regrouped into ``sig_nn_kwargs`` / ``sig_swt_kwargs`` dicts)
    and ``self.kwargs_makeConjunctiveDistanceMatrix_best``. Also sets
    ``self.n_bins``, ``self.smooth_window``, ``self.bounds_findParameters``,
    ``self._seed``, ``self.checker`` and ``self.study``, and seeds numpy's
    global random state.
    RH 2023

    Args:
        kwargs_findParameters (Dict[str, Union[int, float, bool]]):
            Keyword arguments for ``helpers.Convergence_checker_optuna``.
            ``'max_trials'`` is also used as the trial budget and
            ``'n_patience'`` sets the number of sampler startup trials.
        bounds_findParameters (Dict[str, List[float]]):
            Bounds for the 7 parameters to be optimized.
        n_jobs_findParameters (int):
            Number of parallel Optuna jobs. ``-1`` = all cores.
        n_bins (Optional[int]):
            Overwrites ``self.n_bins`` if not ``None``.
        smoothing_window_bins (Optional[int]):
            Overwrites ``self.smooth_window`` if not ``None``.
        seed (Optional[int]):
            Random seed.

    Returns:
        (Dict):
            kwargs_makeConjunctiveDistanceMatrix_best (Dict):
                Optimal parameters for
                :meth:`make_conjunctive_distance_matrix`. Parameters that
                were not optimized are present with value ``None``.
    """
    import optuna

    ## Record the parameter (not data) arguments on the object
    self.params['find_optimal_parameters_for_pruning'] = self._locals_to_params(
        locals_dict=locals(),
        keys=[
            'kwargs_findParameters', 'bounds_findParameters',
            'n_jobs_findParameters', 'n_bins', 'smoothing_window_bins', 'seed',
        ],
    )

    ## Explicit arguments take precedence over the values already on the object
    self.n_bins = n_bins if n_bins is not None else self.n_bins
    self.smooth_window = smoothing_window_bins if smoothing_window_bins is not None else self.smooth_window
    self.bounds_findParameters = bounds_findParameters
    self._seed = seed
    np.random.seed(self._seed)

    if self._verbose:
        print('Finding mixing parameters using automated hyperparameter tuning...')
    optuna.logging.set_verbosity(optuna.logging.WARNING)

    max_trials = kwargs_findParameters['max_trials']
    self.checker = helpers.Convergence_checker_optuna(verbose=self._verbose >= 2, **kwargs_findParameters)
    prog_bar = helpers.OptunaProgressBar(
        n_trials=max_trials,
        mininterval=5.0,
    )

    sampler = optuna.samplers.TPESampler(
        n_startup_trials=kwargs_findParameters['n_patience'] // 2,
        seed=self._seed,
    )
    self.study = optuna.create_study(direction='minimize', sampler=sampler)
    self.study.optimize(
        func=self._objectiveFn_distSameMagnitude,
        n_jobs=n_jobs_findParameters,
        callbacks=[self.checker.check, prog_bar],
        n_trials=max_trials,
        show_progress_bar=False,
    )

    ## Regroup the flat sigmoid parameters into one kwargs dict per metric
    best_flat = self.study.best_params
    self.best_params = best_flat.copy()
    for metric in ('nn', 'swt'):
        for suffix in ('mu', 'b'):
            self.best_params.pop(f'sig_{metric}_kwargs_{suffix}', None)
    self.best_params['sig_nn_kwargs'] = {
        'mu': best_flat['sig_nn_kwargs_mu'],
        'b': best_flat['sig_nn_kwargs_b'],
    }
    self.best_params['sig_swt_kwargs'] = {
        'mu': best_flat['sig_swt_kwargs_mu'],
        'b': best_flat['sig_swt_kwargs_b'],
    }

    ## Full kwargs template; entries that were not optimized stay None
    best_kwargs = dict.fromkeys([
        'power_sf', 'power_nn', 'power_swt',
        'p_norm', 'sig_sf_kwargs',
        'sig_nn_kwargs', 'sig_swt_kwargs',
    ])
    best_kwargs.update(self.best_params)
    self.kwargs_makeConjunctiveDistanceMatrix_best = best_kwargs

    if self._verbose:
        print(f'Completed automatic parameter search. Best value found: {self.study.best_value} with parameters {self.best_params}')
    return self.kwargs_makeConjunctiveDistanceMatrix_best
def _objectiveFn_distSameMagnitude(
self,
trial: object,
) -> float:
"""
**LEGACY** — Optuna objective function for histogram overlap loss.
Used by :meth:`_find_optimal_parameters_for_pruning_optuna`.
RH 2023
"""
power_NN = trial.suggest_float('power_nn', *self.bounds_findParameters['power_nn'], log=False)
power_SWT = trial.suggest_float('power_swt', *self.bounds_findParameters['power_swt'], log=False)
p_norm = trial.suggest_float('p_norm', *self.bounds_findParameters['p_norm'], log=False)
sig_NN_kwargs = {
'mu': trial.suggest_float('sig_nn_kwargs_mu', *self.bounds_findParameters['sig_nn_kwargs_mu'], log=False),
'b': trial.suggest_float('sig_nn_kwargs_b', *self.bounds_findParameters['sig_nn_kwargs_b'], log=False),
}
sig_SWT_kwargs = {
'mu': trial.suggest_float('sig_swt_kwargs_mu', *self.bounds_findParameters['sig_swt_kwargs_mu'], log=False),
'b': trial.suggest_float('sig_swt_kwargs_b', *self.bounds_findParameters['sig_swt_kwargs_b'], log=False),
}
mixing_params = {
'power_sf': 1.0, 'power_nn': power_NN, 'power_swt': power_SWT,
'p_norm': p_norm,
'sig_nn_kwargs': sig_NN_kwargs, 'sig_swt_kwargs': sig_SWT_kwargs,
}
dConj, sConj, activated_data = self.make_conjunctive_distance_matrix(
similarities=self.similarities,
mixing_params=mixing_params,
)
dens_same_crop, dens_same, dens_diff, dens_all, edges, d_crossover = self._separate_diffSame_distributions(dConj)
if dens_same_crop is None:
return 0
return (dens_same * dens_diff).sum().item()
## Cache for the numba kernel that computes weighted Jaccard. The kernel is
## compiled on first use rather than at import time, to avoid numba import
## overhead (subprocess spam on macOS) at module load.
_weighted_jaccard_csr_kernel = None


def _get_weighted_jaccard_kernel():
    """
    Returns the numba-compiled weighted Jaccard kernel, compiling and
    caching it in the module-level ``_weighted_jaccard_csr_kernel`` on the
    first call.
    """
    global _weighted_jaccard_csr_kernel
    if _weighted_jaccard_csr_kernel is None:
        import numba

        @numba.njit(cache=True)
        def kernel(indptr, indices, data, out_data):
            """
            Numba kernel: weighted Jaccard (Ruzicka) similarity for each
            edge in a symmetric CSR similarity matrix.

            For each stored edge (i, j), writes to ``out_data`` at the
            edge's position:
                J_w(i,j) = sum_k min(s_ik, s_jk) / sum_k max(s_ik, s_jk)
            where k runs over the union of the neighbors of i and j,
            excluding k = i and k = j. The result is 0 when the
            denominator is 0.

            Assumptions:
                - Symmetric matrix (s_ij = s_ji for all stored pairs).
                - Zero diagonal (no self-loops stored).
                - Indices within each row are sorted ascending.
                - All data values are non-negative.
            """
            n_rows = len(indptr) - 1
            for i in range(n_rows):
                i_start = indptr[i]
                i_stop = indptr[i + 1]
                for ptr_ij in range(i_start, i_stop):
                    j = indices[ptr_ij]
                    j_stop = indptr[j + 1]

                    ## Two cursors walk the sorted neighbor lists of i and j
                    pi = i_start
                    pj = indptr[j]
                    sum_min = 0.0
                    sum_max = 0.0
                    while pi < i_stop and pj < j_stop:
                        ki = indices[pi]
                        kj = indices[pj]
                        if ki < kj:
                            ## Neighbor of i only (min is 0); the edge endpoint j is excluded
                            if ki != j:
                                sum_max += data[pi]
                            pi += 1
                        elif kj < ki:
                            ## Neighbor of j only (min is 0); the edge endpoint i is excluded
                            if kj != i:
                                sum_max += data[pj]
                            pj += 1
                        else:
                            ## Shared neighbor; the edge endpoints are excluded
                            if ki != i and ki != j:
                                vi = data[pi]
                                vj = data[pj]
                                if vi < vj:
                                    sum_min += vi
                                    sum_max += vj
                                else:
                                    sum_min += vj
                                    sum_max += vi
                            pi += 1
                            pj += 1

                    ## Leftover neighbors of i are not shared with j
                    while pi < i_stop:
                        if indices[pi] != j:
                            sum_max += data[pi]
                        pi += 1

                    ## Leftover neighbors of j are not shared with i
                    while pj < j_stop:
                        if indices[pj] != i:
                            sum_max += data[pj]
                        pj += 1

                    out_data[ptr_ij] = sum_min / sum_max if sum_max > 0.0 else 0.0

        _weighted_jaccard_csr_kernel = kernel
    return _weighted_jaccard_csr_kernel
[docs]
def weighted_jaccard_similarity(
    s: scipy.sparse.csr_array,
) -> scipy.sparse.csr_array:
    """
    Replace each stored similarity ``s_ij`` of a sparse graph with the
    weighted Jaccard (Ruzicka) similarity of the neighborhoods of ``i`` and
    ``j``:

    .. math::
        J_w(i, j) = \\frac{\\sum_{k \\neq i,j} \\min(s_{ik}, s_{jk})}
                         {\\sum_{k \\neq i,j} \\max(s_{ik}, s_{jk})}

    The sums run over every ``k`` that is a stored neighbor of ``i`` or of
    ``j``, leaving out ``k = i`` and ``k = j``. This is a second-order
    similarity: a pair scores high when both nodes connect strongly to the
    same neighbors.

    The per-edge work is done by a numba kernel (obtained from
    ``_get_weighted_jaccard_kernel``) that merge-scans the sorted CSR rows.
    Complexity: O(nnz * avg_degree).

    Args:
        s (scipy.sparse.csr_array):
            Sparse symmetric similarity matrix. Shape: *(n, n)*. Must have
            non-negative values and a zero (unstored) diagonal. If its
            indices are not sorted, a sorted copy is made; ``s`` itself is
            not modified.

    Returns:
        (scipy.sparse.csr_array):
            s_jaccard (scipy.sparse.csr_array):
                Weighted Jaccard similarity matrix with the same stored
                entries as the (index-sorted) input. Values are computed in
                float64 and then cast back to the dtype of ``s.data``.
    """
    assert isinstance(s, scipy.sparse.csr_array), (
        f"Expected scipy.sparse.csr_array, got {type(s)}"
    )
    assert s.shape[0] == s.shape[1], "Matrix must be square."

    ## The kernel's merge-scan needs ascending column indices in every row.
    if s.has_sorted_indices:
        graph = s
    else:
        graph = s.copy()
        graph.sort_indices()

    ## Compute in float64 regardless of the input precision.
    weights = graph.data.astype(np.float64)
    scores = np.empty(len(weights), dtype=np.float64)

    compute_scores = _get_weighted_jaccard_kernel()
    compute_scores(
        indptr=graph.indptr,
        indices=graph.indices,
        data=weights,
        out_data=scores,
    )

    ## Reuse the sparsity structure; swap in the new values in the input dtype.
    result = graph.copy()
    result.data = scores.astype(s.data.dtype)
    return result
## Lazy-compiled numba kernel for noise rescue Kruskal. Same pattern as
## weighted Jaccard: compiled on first call to avoid numba import overhead.
## This module-level slot holds the compiled kernel. It stays ``None`` until
## ``_get_noise_rescue_kernel`` is called for the first time.
_noise_rescue_kruskal_kernel = None


def _get_noise_rescue_kernel():
    """
    Lazily compile and cache the numba kernel for noise rescue Kruskal.

    On the first call this imports numba, defines the jitted kernel, stores it
    in the module-level ``_noise_rescue_kruskal_kernel`` and returns it. Later
    calls return the stored function directly.

    Returns:
        (callable):
            kernel (callable):
                The compiled ``_noise_rescue_kruskal_kernel`` function, called
                as ``kernel(indptr, indices, data, labels, group_labels,
                n_groups, d_cutoff)`` and returning the new label array.
    """
    global _noise_rescue_kruskal_kernel
    ## Fast path: the kernel was already built by an earlier call.
    if _noise_rescue_kruskal_kernel is not None:
        return _noise_rescue_kruskal_kernel
    ## Deferred import: numba is only needed once a kernel is requested.
    import numba
    @numba.njit(cache=True)
    def kernel(
        indptr, ## int32[:], CSR row pointers
        indices, ## int32[:], CSR column indices
        data, ## float64[:], CSR edge distances
        labels, ## int64[:], Phase 1 labels (-1 = noise)
        group_labels, ## int32[:], session index per ROI
        n_groups, ## int, number of sessions
        d_cutoff, ## float64, maximum distance for edge acceptance
    ):
        """
        Kruskal-style noise rescue with DSU + bitmask cannot-link constraints.

        Given Phase 1 cluster labels, processes edges from the distance graph
        where at least one endpoint is noise (label == -1), sorted by distance
        ascending. Merges are accepted if:
            1. d <= d_cutoff
            2. endpoints are in different DSU components
            3. no session-conflict (bitmask AND == 0)

        The DSU is pre-initialized with Phase 1 clusters: all members of the
        same cluster are pre-merged and their component bitmask is the OR of
        their session bits.

        After processing, labels are extracted:
            - Components containing Phase 1 cluster members inherit that label
            - Components of only ex-noise points with size >= 2 get new labels
              (starting from max_existing_label + 1)
            - Remaining singletons stay -1

        Preconditions (not checked here):
            - Only entries with column index > row index are read, so each
              undirected edge must be stored in the upper triangle.
            - Every non-negative ``group_labels`` value must be smaller than
              ``n_groups``; the value selects a word/bit in a mask that has
              ``(n_groups + 63) // 64`` words per ROI.

        Returns new_labels: int64[:] of length n_rois.
        """
        n = len(indptr) - 1 ## n_rois
        ## -- DSU arrays --
        ## parent[i] == i marks i as the root of its component.
        parent = np.arange(n, dtype=np.int64)
        rank = np.zeros(n, dtype=np.int32)
        ## -- Bitmask arrays: comp_mask[root, word] --
        ## One bit per session; 64 sessions per uint64 word.
        n_words = (n_groups + 63) // 64
        comp_mask = np.zeros((n, n_words), dtype=np.uint64)
        ## Initialize bitmasks from group_labels
        ## (ROIs with a negative group label keep an all-zero mask)
        for i in range(n):
            g = group_labels[i]
            if g >= 0:
                w = g // 64
                b = np.uint64(g % 64)
                comp_mask[i, w] = comp_mask[i, w] | (np.uint64(1) << b)
        ## -- Pre-merge Phase 1 clusters --
        ## For each cluster label > -1, union all its members.
        ## First pass: find max label to size the bookkeeping.
        max_label = -1
        for i in range(n):
            if labels[i] > max_label:
                max_label = labels[i]
        if max_label >= 0:
            ## For each cluster, track first member seen as the anchor
            cluster_anchor = np.full(max_label + 1, -1, dtype=np.int64)
            for i in range(n):
                lbl = labels[i]
                if lbl < 0:
                    continue
                if cluster_anchor[lbl] < 0:
                    cluster_anchor[lbl] = i
                else:
                    ## Union i with the anchor
                    ## (no session-conflict test here: Phase 1 clusters are
                    ## merged unconditionally)
                    anchor = cluster_anchor[lbl]
                    ## Find root of anchor
                    root_a = anchor
                    while parent[root_a] != root_a:
                        root_a = parent[root_a]
                    ## Path compression: point every node on the walked path
                    ## straight at the root.
                    curr = anchor
                    while curr != root_a:
                        nxt = parent[curr]
                        parent[curr] = root_a
                        curr = nxt
                    ## Find root of i
                    root_i = i
                    while parent[root_i] != root_i:
                        root_i = parent[root_i]
                    curr = i
                    while curr != root_i:
                        nxt = parent[curr]
                        parent[curr] = root_i
                        curr = nxt
                    if root_a != root_i:
                        ## Union by rank
                        if rank[root_a] > rank[root_i]:
                            new_root = root_a
                            old_root = root_i
                        elif rank[root_a] < rank[root_i]:
                            new_root = root_i
                            old_root = root_a
                        else:
                            ## Equal ranks: pick root_a and bump its rank.
                            new_root = root_a
                            old_root = root_i
                            rank[new_root] += 1
                        parent[old_root] = new_root
                        ## Merge bitmasks
                        for ww in range(n_words):
                            comp_mask[new_root, ww] = (
                                comp_mask[new_root, ww] | comp_mask[old_root, ww]
                            )
        ## -- Collect edges where at least one endpoint is noise --
        ## Two passes over the same predicate: count, then fill, so the edge
        ## arrays can be allocated at their exact size.
        ## Count first
        n_edges = 0
        for i in range(n):
            for ptr in range(indptr[i], indptr[i + 1]):
                j = indices[ptr]
                if j <= i:
                    continue ## upper triangle only (avoid duplicates)
                if labels[i] == -1 or labels[j] == -1:
                    n_edges += 1
        ## Allocate and fill
        edge_u = np.empty(n_edges, dtype=np.int64)
        edge_v = np.empty(n_edges, dtype=np.int64)
        edge_d = np.empty(n_edges, dtype=np.float64)
        idx = 0
        for i in range(n):
            for ptr in range(indptr[i], indptr[i + 1]):
                j = indices[ptr]
                if j <= i:
                    continue
                if labels[i] == -1 or labels[j] == -1:
                    edge_u[idx] = i
                    edge_v[idx] = j
                    edge_d[idx] = data[ptr]
                    idx += 1
        ## -- Sort edges by distance ascending --
        sort_order = np.argsort(edge_d)
        ## -- Kruskal traversal --
        for idx in range(len(sort_order)):
            eidx = sort_order[idx]
            d_val = edge_d[eidx]
            ## Stop if beyond cutoff
            ## (edges are visited in ascending distance, so every remaining
            ## edge is beyond the cutoff as well)
            if d_val > d_cutoff:
                break
            u = edge_u[eidx]
            v = edge_v[eidx]
            ## Find root of u with path compression
            root_u = u
            while parent[root_u] != root_u:
                root_u = parent[root_u]
            curr = u
            while curr != root_u:
                nxt = parent[curr]
                parent[curr] = root_u
                curr = nxt
            ## Find root of v with path compression
            root_v = v
            while parent[root_v] != root_v:
                root_v = parent[root_v]
            curr = v
            while curr != root_v:
                nxt = parent[curr]
                parent[curr] = root_v
                curr = nxt
            ## Already same component
            if root_u == root_v:
                continue
            ## Conflict check: bitmask AND
            ## (any shared bit means both components already hold an ROI from
            ## the same session)
            conflict = False
            for ww in range(n_words):
                if (comp_mask[root_u, ww] & comp_mask[root_v, ww]) != np.uint64(0):
                    conflict = True
                    break
            if conflict:
                continue
            ## Union by rank
            if rank[root_u] > rank[root_v]:
                new_root = root_u
                old_root = root_v
            elif rank[root_u] < rank[root_v]:
                new_root = root_v
                old_root = root_u
            else:
                new_root = root_u
                old_root = root_v
                rank[new_root] += 1
            parent[old_root] = new_root
            ## Merge bitmasks
            for ww in range(n_words):
                comp_mask[new_root, ww] = (
                    comp_mask[new_root, ww] | comp_mask[old_root, ww]
                )
        ## -- Extract labels from DSU --
        ## Find root for every node (with path compression)
        ## After this loop parent[i] is the root of i for every i.
        for i in range(n):
            root_i = i
            while parent[root_i] != root_i:
                root_i = parent[root_i]
            curr = i
            while curr != root_i:
                nxt = parent[curr]
                parent[curr] = root_i
                curr = nxt
        ## Map: root → Phase 1 label (if component has one)
        ## Also count component sizes
        root_to_label = np.full(n, -1, dtype=np.int64)
        comp_size = np.zeros(n, dtype=np.int64)
        for i in range(n):
            root_i = parent[i]
            comp_size[root_i] += 1
            ## NOTE(review): if a component holds members with different
            ## Phase 1 labels, the member with the highest index wins because
            ## each labeled member overwrites the entry.
            if labels[i] >= 0:
                root_to_label[root_i] = labels[i]
        ## Assign new labels for noise-only components of size >= 2
        next_label = max_label + 1 if max_label >= 0 else 0
        for i in range(n):
            if parent[i] == i and root_to_label[i] < 0 and comp_size[i] >= 2:
                root_to_label[i] = next_label
                next_label += 1
        ## Build output labels
        ## (unmerged noise singletons keep the -1 that root_to_label started with)
        new_labels = np.empty(n, dtype=np.int64)
        for i in range(n):
            new_labels[i] = root_to_label[parent[i]]
        return new_labels
    ## Store the jitted function so later calls take the fast path above.
    _noise_rescue_kruskal_kernel = kernel
    return _noise_rescue_kruskal_kernel
[docs]
def noise_rescue_kruskal(
    d_conj: scipy.sparse.csr_array,
    labels: np.ndarray,
    group_labels: np.ndarray,
    n_groups: int,
    d_cutoff: float,
) -> np.ndarray:
    """
    Assign HDBSCAN noise points to nearby clusters (or nucleate new small
    clusters) using a Kruskal-style sorted-edge traversal with DSU and
    bitmask cannot-link constraints.

    This is Phase 2 of a two-phase clustering strategy:

    * **Phase 1**: HDBSCAN with ``min_samples > 1`` produces robust core
      clusters but marks many ROIs as noise (``label == -1``).
    * **Phase 2** (this function): Processes edges from the distance graph
      where at least one endpoint is noise, sorted by distance. Merges
      are accepted only if no session conflict arises (checked via
      ``uint64`` bitmask per DSU component). This can either assign noise
      points to existing Phase 1 clusters or nucleate new clusters when
      2+ noise points are mutual neighbors within ``d_cutoff``.

    The algorithm uses the same DSU + bitmask pattern as
    ``fast_hdbscan``'s ``_kruskal_core_group_constrained``, but with the DSU
    **pre-initialized** from Phase 1's pre-formed clusters.

    Args:
        d_conj (scipy.sparse.csr_array):
            Sparse distance matrix (inter-session masked). Shape:
            *(n_rois, n_rois)*. Must be symmetric. Indices are sorted on a
            copy if they are not already sorted.
        labels (np.ndarray):
            Phase 1 cluster labels. Shape: *(n_rois,)*. ``-1`` = noise.
        group_labels (np.ndarray):
            Session index per ROI (``int32``). Shape: *(n_rois,)*. Every
            value must be smaller than ``n_groups``.
        n_groups (int):
            Number of distinct sessions (groups).
        d_cutoff (float):
            Maximum edge distance to accept. Edges with ``d > d_cutoff``
            are ignored.

    Returns:
        (np.ndarray):
            new_labels (np.ndarray):
                Updated cluster labels. Shape: *(n_rois,)*. Noise points
                that were rescued get their assigned cluster label; new
                clusters of 2+ ex-noise points get fresh label IDs;
                remaining singletons stay ``-1``.

    Raises:
        AssertionError:
            If ``d_conj`` is not a square ``csr_array`` or the label arrays
            do not have shape *(n_rois,)*.
        ValueError:
            If any ``group_labels`` value is ``>= n_groups``.
    """
    assert isinstance(d_conj, scipy.sparse.csr_array), (
        f"Expected scipy.sparse.csr_array, got {type(d_conj)}"
    )
    assert d_conj.shape[0] == d_conj.shape[1], "Distance matrix must be square."
    n = d_conj.shape[0]
    assert labels.shape == (n,), f"labels shape {labels.shape} != ({n},)"
    assert group_labels.shape == (n,), f"group_labels shape {group_labels.shape} != ({n},)"

    ## The kernel uses each group label to pick a word/bit in a mask that is
    ## sized from n_groups. Numba does not bounds-check array writes by
    ## default, so a label >= n_groups would write outside that mask
    ## silently. Reject it here. A real exception (not an assert) is used so
    ## the guard survives `python -O`.
    n_groups = int(n_groups)
    if n > 0:
        max_group = int(np.max(group_labels))
        if max_group >= n_groups:
            raise ValueError(
                f"group_labels contains value {max_group}, which must be "
                f"smaller than n_groups={n_groups}."
            )

    ## Ensure sorted indices for consistent edge iteration
    d_sorted = d_conj
    if not d_conj.has_sorted_indices:
        d_sorted = d_conj.copy()
        d_sorted.sort_indices()

    ## Cast to the dtypes the kernel was written for.
    kernel = _get_noise_rescue_kernel()
    new_labels = kernel(
        indptr=d_sorted.indptr.astype(np.int32),
        indices=d_sorted.indices.astype(np.int32),
        data=d_sorted.data.astype(np.float64),
        labels=labels.astype(np.int64),
        group_labels=group_labels.astype(np.int32),
        n_groups=n_groups,
        d_cutoff=float(d_cutoff),
    )
    return new_labels
[docs]
def attach_fully_connected_node(
    d: object,
    dist_fullyConnectedNode: Optional[float] = None,
    n_nodes: int = 1,
) -> object:
    """
    Appends ``n_nodes`` extra node(s) to a sparse distance graph, each
    connected to every node (including the other appended nodes and itself)
    with the same distance value.

    Args:
        d (object):
            Sparse graph with multiple components.
            Refer to scipy.sparse.csgraph.connected_components for details.
        dist_fullyConnectedNode (Optional[float]):
            Distance written into every element of the appended rows and
            columns. If ``None``, it is set to 1000 times the difference
            between the maximum and minimum values in ``d``. (Default is
            ``None``)
        n_nodes (int):
            Number of nodes to append to the graph. (Default is *1*)

    Returns:
        (object):
            d2 (object):
                Graph in CSR format with ``n_nodes`` additional rows and
                columns. ``d`` itself is left unchanged.
    """
    link_dist = dist_fullyConnectedNode
    if link_dist is None:
        ## Far larger than any distance already present in the graph.
        link_dist = (d.max() - d.min()) * 1000

    ## New rows first: one dense row per appended node.
    bottom = np.ones((n_nodes, d.shape[1]), dtype=d.dtype) * link_dist
    tall = scipy.sparse.vstack((d, bottom))

    ## Then new columns, sized to the already-extended row count so that the
    ## bottom-right corner is filled as well.
    right = np.ones((tall.shape[0], n_nodes), dtype=d.dtype) * link_dist
    extended = scipy.sparse.hstack((tall, right))
    return extended.tocsr()
[docs]
def score_labels(
    labels_test: np.ndarray,
    labels_true: np.ndarray,
    ignore_negOne: bool = False,
    thresh_perfect: float = 0.9999999999,
) -> Dict[str, Union[float, Tuple[int, int]]]:
    """
    Computes the score of the clustering by finding the best match using the
    linear sum assignment problem. The score is bounded between 0 and 1. Note:
    The score is not symmetric if the number of true and test labels are not the
    same. I.e., switching ``labels_test`` and ``labels_true`` can lead to
    different scores. This is because we are scoring how well each true set is
    matched by an optimally assigned test set.
    RH 2022

    Args:
        labels_test (np.ndarray):
            Labels of the test clusters/sets. (shape: *(n,)*)
        labels_true (np.ndarray):
            Labels of the true clusters/sets. (shape: *(n,)*)
        ignore_negOne (bool):
            Whether to ignore ``-1`` values in the labels for the Hungarian
            matching scores. If set to ``True``, the ``-1`` set is zeroed out
            in both label arrays before matching. The scikit-learn based
            scores are not affected by this flag. (Default is ``False``)
        thresh_perfect (float):
            Threshold for perfect match. Mostly used for numerical stability.
            (Default is *0.9999999999*)

    Returns:
        (dict): dictionary containing:
            hungarian_score_weighted_partial (float):
                Average overlap fraction between the best matched sets of
                true and test labels, weighted by the number of elements in
                each true set.
            hungarian_score_weighted_perfect (float):
                Fraction of perfect matches, weighted by the number of
                elements in each true set.
            hungarian_score_unweighted_partial (float):
                Average overlap fraction between the best matched sets of
                true and test labels.
            hungarian_score_unweighted_perfect (float):
                Fraction of perfect matches.
            adj_rand_score (float):
                Adjusted Rand score of the labels.
            fowlkes_mallows_score (float):
                Fowlkes-Mallows score of the labels.
            adj_mutual_info_score (float):
                Adjusted mutual info score of the labels.
            homogeneity_score, completeness_score, v_measure_score (float):
                Homogeneity, completeness and V-measure (``beta=1.0``).
            pair_confusion_matrix (List[List[int]]):
                2x2 pair confusion matrix as nested lists.
            pair_confusion_accuracy_score (float):
                ``(TP + TN) / (TP + TN + FP + FN)`` over pairs.
            pair_confusion_accuracy_norm_score (float):
                Class-normalized pair accuracy. ``nan`` when there are no
                positive pairs or no negative pairs.
            pair_confusion_precision_score, pair_confusion_recall_score,
            pair_confusion_f1_score (float):
                Pair precision / recall / F1. ``0.0`` when the denominator is
                zero.
            labels_test (List[int]):
                The test labels that were scored.
            ignore_negOne (bool):
                Whether ``-1`` values were ignored in the labels.
            idx_hungarian (Tuple[np.ndarray, np.ndarray]):
                'Hungarian Indices'. Row (true) and column (test) indices of
                the best matched sets.
    """
    ## Imported explicitly: `import sklearn` alone does not guarantee that the
    ## `metrics` subpackage is loaded.
    import sklearn.metrics

    assert len(labels_test) == len(labels_true), 'RH ERROR: labels_test and labels_true must be the same length.'

    labels_test = np.array(labels_test, dtype=int)
    labels_true = np.array(labels_true, dtype=int)

    ## convert labels to boolean (one row per unique label)
    uniques_test, uniques_true = np.unique(labels_test), np.unique(labels_true)
    bool_test = np.stack([labels_test==l for l in uniques_test], axis=0).astype(np.float32)
    bool_true = np.stack([labels_true==l for l in uniques_true], axis=0).astype(np.float32)

    ## Hungarian matching score
    if ignore_negOne:
        bool_test[uniques_test == -1, :] = 0.0
        bool_true[uniques_true == -1, :] = 0.0

    ## Pad the test sets with empty sets so every true set can be assigned
    if bool_test.shape[0] < bool_true.shape[0]:
        bool_test = np.concatenate((bool_test, np.zeros((bool_true.shape[0] - bool_test.shape[0], bool_true.shape[1]))))

    ## compute confusion / correlation matrix
    ## (empty true sets give 0/0; those entries are mapped to 0)
    with np.errstate(divide='ignore', invalid='ignore'):
        cc = np.nan_to_num((bool_true @ bool_test.T) / (bool_true.sum(axis=1)[:, None]), nan=0.0, posinf=0.0, neginf=0.0)  ## normalize by the number of elements in each set

    ## find hungarian assignment matching indices
    hi = scipy.optimize.linear_sum_assignment(cost_matrix=cc, maximize=True)

    ## extract correlation scores of matches
    cc_matched = cc[hi[0], hi[1]]
    label_weights = bool_true.sum(axis=1)[hi[0]]  ## reweighting vector is the number of elements in each true set
    label_weights_norm = label_weights / label_weights.sum()  ## normalize the weights

    is_perfect = cc_matched > thresh_perfect
    hungarian_match_score_weighted_partial = cc_matched @ label_weights_norm
    hungarian_match_score_unweighted_partial = cc_matched.mean()
    hungarian_match_score_weighted_perfect = is_perfect.astype(float) @ label_weights_norm
    hungarian_match_score_unweighted_perfect = is_perfect.mean()

    ## SKLEARN METRICS
    ### Give every -1 element its own label, above all existing labels, so
    ### that -1 elements are scored as singletons and not as one big cluster.
    uniques = np.unique(np.concatenate((labels_true, labels_test)))
    n_minusOne_true = np.sum(labels_true == -1)
    n_minusOne_test = np.sum(labels_test == -1)
    labels_true_sk, labels_test_sk = labels_true.copy(), labels_test.copy()
    labels_true_sk[labels_true == -1] = uniques.max() + np.arange(1, n_minusOne_true + 1)
    labels_test_sk[labels_test == -1] = uniques.max() + np.arange(n_minusOne_true + 1, n_minusOne_true + n_minusOne_test + 1)

    ## compute adjusted rand score
    score_rand = sklearn.metrics.adjusted_rand_score(labels_true=labels_true_sk, labels_pred=labels_test_sk)
    ## compute fowlkes mallows score
    score_fowlkes_mallows = sklearn.metrics.fowlkes_mallows_score(labels_true=labels_true_sk, labels_pred=labels_test_sk)
    ## compute adjusted mutual info score
    score_mutual_info = sklearn.metrics.adjusted_mutual_info_score(labels_true=labels_true_sk, labels_pred=labels_test_sk)
    ## compute homogeneity, completeness, and v-measure
    homogeneity, completeness, v_measure = sklearn.metrics.homogeneity_completeness_v_measure(labels_true=labels_true_sk, labels_pred=labels_test_sk, beta=1.0)

    ## compute pair confusion matrix
    ## (`.tolist()` always yields a list, so no None-handling is needed below)
    pair_confusion = sklearn.metrics.cluster.pair_confusion_matrix(labels_true=labels_true_sk, labels_pred=labels_test_sk).tolist()
    p = np.array(pair_confusion)
    TP, TN, FP, FN = p[1, 1], p[0, 0], p[0, 1], p[1, 0]
    pc_accuracy = (TP + TN) / (TP + TN + FP + FN)
    N_all = TN + FP
    P_all = TP + FN
    pc_precision = TP / (TP + FP) if (TP + FP) > 0 else 0
    pc_recall = TP / (TP + FN) if (TP + FN) > 0 else 0
    pc_f1 = (2 * pc_precision * pc_recall) / (pc_precision + pc_recall) if (pc_precision + pc_recall) > 0 else 0
    pc_accuracy_norm = (TP/P_all + TN/N_all) / (TP/P_all + TN/N_all + FP/N_all + FN/P_all)

    out = {
        'hungarian_score_weighted_partial': float(hungarian_match_score_weighted_partial),
        'hungarian_score_weighted_perfect': float(hungarian_match_score_weighted_perfect),
        'hungarian_score_unweighted_partial': float(hungarian_match_score_unweighted_partial),
        'hungarian_score_unweighted_perfect': float(hungarian_match_score_unweighted_perfect),
        'adj_rand_score': float(score_rand),
        'fowlkes_mallows_score': float(score_fowlkes_mallows),
        'adj_mutual_info_score': float(score_mutual_info),
        'homogeneity_score': float(homogeneity),
        'completeness_score': float(completeness),
        'v_measure_score': float(v_measure),
        'pair_confusion_matrix': pair_confusion,
        'pair_confusion_accuracy_score': float(pc_accuracy),
        'pair_confusion_accuracy_norm_score': float(pc_accuracy_norm),
        'pair_confusion_precision_score': float(pc_precision),
        'pair_confusion_recall_score': float(pc_recall),
        'pair_confusion_f1_score': float(pc_f1),
        'labels_test': labels_test.tolist(),
        'ignore_negOne': ignore_negOne,
        'idx_hungarian': hi,
    }
    return out
[docs]
def cluster_quality_metrics(
    sim: Union[np.ndarray, scipy.sparse.csr_array],
    labels: np.ndarray,
) -> Tuple:
    """
    Computes per-cluster quality metrics for a clustering solution:
    intra-cluster mean / minimum / maximum similarity and a cluster
    silhouette score.
    RH 2023

    Args:
        sim (Union[np.ndarray, scipy.sparse.csr_array]):
            Similarity matrix. (shape: *(n_roi, n_roi)*)
        labels (np.ndarray):
            Cluster labels. (shape: *(n_roi,)*)

    Returns:
        (tuple): tuple containing:
            labels_unique (np.ndarray):
                Unique labels as returned by
                ``helpers.compute_cluster_similarity_matrices``.
            cs_intra_means (np.ndarray):
                Intra-cluster mean similarity: diagonal of the cluster-mean
                matrix. (shape: *(n_clusters,)*)
            cs_intra_mins (np.ndarray):
                Intra-cluster minimum similarity: diagonal of the cluster-min
                matrix. (shape: *(n_clusters,)*)
            cs_intra_maxs (np.ndarray):
                Intra-cluster maximum similarity: diagonal of the cluster-max
                matrix. (shape: *(n_clusters,)*)
            cs_sil (np.ndarray):
                Cluster silhouette score:
                ``(intra - inter) / max(intra, inter)`` where ``intra`` is the
                intra-cluster mean and ``inter`` is the largest off-diagonal
                entry of the cluster-max matrix. (shape: *(n_clusters,)*)
    """
    import sparse

    ## Cluster-by-cluster summary matrices (mean / max / min similarity).
    labels_unique, cs_mean, cs_max, cs_min = helpers.compute_cluster_similarity_matrices(
        sim,
        labels,
        verbose=True,
    )

    ## Mask that is 0 on the diagonal and 1 elsewhere: removes each cluster's
    ## similarity to itself before taking the inter-cluster maximum.
    off_diagonal = 1 - sparse.eye(cs_max.shape[0])

    ## Within-cluster statistics live on the diagonals.
    cs_intra_means = cs_mean.diagonal()
    cs_intra_mins = cs_min.diagonal()
    cs_intra_maxs = cs_max.diagonal()

    ## Strongest similarity to any *other* cluster.
    inter_best = (off_diagonal * cs_max).max(0)

    cs_sil = (cs_intra_means - inter_best) / np.maximum(cs_intra_means, inter_best)

    return labels_unique, cs_intra_means, cs_intra_mins, cs_intra_maxs, cs_sil
[docs]
def make_label_variants(
    labels: np.ndarray,
    n_roi_bySession: np.ndarray,
) -> Tuple:
    """
    Creates convenient variants of label arrays.
    RH 2023

    Args:
        labels (np.ndarray):
            Cluster integer labels. (shape: *(n_roi,)*)
        n_roi_bySession (np.ndarray):
            Number of ROIs in each session. Must sum to ``n_roi``.

    Returns:
        (tuple): tuple containing:
            labels_squeezed (util.JSON_List):
                Cluster labels after ``helpers.squeeze_integers`` (unused
                label IDs removed), as native Python ints.
            labels_bySession (util.JSON_List):
                List of squeezed label lists split by session.
            labels_bool (scipy.sparse.csr_array):
                Sparse boolean one-hot representation of the squeezed labels.
                Shape: *(n_roi, n_unique_labels)*, with columns ordered by
                ascending unique label.
            labels_bool_bySession (List[scipy.sparse.csr_array]):
                List of sparse boolean matrix representations of labels split by
                session.
            labels_dict (util.JSON_Dict):
                Dictionary mapping each unique squeezed label (as a string)
                to the list of positions where it occurs.

    Raises:
        TypeError:
            If ``labels`` or ``n_roi_bySession`` is neither a list nor a
            numpy array.
        AssertionError:
            If an input is not 1D, the counts do not add up to the number of
            labels, or an internal consistency check fails.
    """
    import scipy.sparse

    ## assert that labels is a 1D np.array or list of numbers
    if isinstance(labels, list):
        labels = np.array(labels, dtype=np.int64)
    elif isinstance(labels, np.ndarray):
        labels = labels.astype(np.int64)
    else:
        raise TypeError('RH ERROR: labels must be a 1D np.array or list of numbers.')
    assert labels.ndim == 1, 'RH ERROR: labels must be a 1D np.array or list of numbers.'

    ## assert that n_roi_bySession is a 1D np.array or list of numbers
    if isinstance(n_roi_bySession, list):
        n_roi_bySession = np.array(n_roi_bySession, dtype=np.int64)
    elif isinstance(n_roi_bySession, np.ndarray):
        n_roi_bySession = n_roi_bySession.astype(np.int64)
    else:
        raise TypeError('RH ERROR: n_roi_bySession must be a 1D np.array or list of numbers.')
    assert n_roi_bySession.ndim == 1, 'RH ERROR: n_roi_bySession must be a 1D np.array or list of numbers.'

    ## assert that the number of labels adds up to the number of ROIs
    assert labels.shape[0] == n_roi_bySession.sum(), 'RH ERROR: the number of labels must add up to the number of ROIs.'

    ## make session_bool
    session_bool = util.make_session_bool(n_roi_bySession)

    ## make variants
    labels_squeezed = helpers.squeeze_integers(labels)
    labels_bySession = [labels_squeezed[idx] for idx in session_bool.T]

    ## One-hot matrix built directly from (row, column) coordinates: row i has
    ## a single True in the column of its label. `labels_unique` is sorted
    ## ascending and `idx_col[i]` is the position of label i within it. This
    ## costs O(n_roi) instead of one dense comparison per unique label.
    labels_unique, idx_col = np.unique(labels_squeezed, return_inverse=True)
    n_roi = labels_squeezed.shape[0]
    labels_bool = scipy.sparse.csr_array(
        (np.ones(n_roi, dtype=bool), (np.arange(n_roi), idx_col)),
        shape=(n_roi, labels_unique.shape[0]),
    )
    labels_bool_bySession = [labels_bool[idx] for idx in session_bool.T]
    labels_dict = {u: np.where(labels_squeezed==u)[0] for u in labels_unique}

    ## testing
    assert np.allclose(np.concatenate(labels_bySession), labels_squeezed)
    ## Map each row's column index back through the sorted unique labels.
    ## (Subtracting a fixed offset of 1 from the column index is only valid
    ## when the smallest label is -1, and fails for inputs without noise.)
    assert np.array_equal(labels_unique[labels_bool.nonzero()[1]], labels_squeezed)
    assert all(np.array_equal(np.where(labels_squeezed==u)[0], ldu) for u, ldu in labels_dict.items())

    ## Convert everything to native python types for JSON compatibility
    labels_squeezed = util.JSON_List([int(u) for u in labels_squeezed])
    labels_bySession = util.JSON_List([[int(u) for u in l] for l in labels_bySession])
    labels_dict = util.JSON_Dict({str(k): [int(v_i) for v_i in v] for k, v in labels_dict.items()})  ## Make keys strings for JSON compatibility

    return labels_squeezed, labels_bySession, labels_bool, labels_bool_bySession, labels_dict
[docs]
def plot_quality_metrics(quality_metrics: dict, labels: Union[np.ndarray, list], n_sessions: int) -> Tuple[Any, np.ndarray]:
    """
    Plots a 2x2 summary figure of clustering quality metrics.

    Panels:
        * top-left: histogram of ``quality_metrics['cluster_silhouette']``
        * top-right: histogram of ``quality_metrics['cluster_intra_means']``
        * bottom-left: histogram of ``quality_metrics['sample_silhouette']``
        * bottom-right: bar plot of how many clusters have each size (number
          of ROIs per cluster), excluding the ``-1`` label

    Args:
        quality_metrics (dict):
            Dictionary with the keys ``'cluster_silhouette'``,
            ``'cluster_intra_means'`` and ``'sample_silhouette'``.
        labels (Union[np.ndarray, list]):
            Cluster label per ROI. ``-1`` marks excluded ROIs.
        n_sessions (int):
            Number of sessions. Only used in the figure title.

    Returns:
        (tuple): tuple containing:
            fig (matplotlib.figure.Figure):
                The created figure.
            axs (np.ndarray):
                2x2 array of the figure's axes.
    """
    ## Convert once so that list input is handled the same way as array input.
    ## Comparing a plain list with -1 yields a single bool rather than a mask.
    labels_arr = np.array(labels)
    is_included = labels_arr != -1

    fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15,7))

    axs[0,0].hist(quality_metrics['cluster_silhouette'], 50)
    axs[0,0].set_xlabel('cluster_silhouette')
    axs[0,0].set_ylabel('cluster counts')

    axs[0,1].hist(quality_metrics['cluster_intra_means'], 50)
    axs[0,1].set_xlabel('cluster_intra_means')
    axs[0,1].set_ylabel('cluster counts')

    axs[1,0].hist(quality_metrics['sample_silhouette'], 50)
    axs[1,0].set_xlabel('sample_silhouette score')
    axs[1,0].set_ylabel('roi sample counts')

    ## Cluster sizes -> number of clusters having each size
    u, c = np.unique(labels_arr[is_included], return_counts=True)
    n_sesh = np.bincount(c)
    axs[1,1].bar(np.arange(len(n_sesh)), n_sesh)
    axs[1,1].set_xlabel('n_sessions')
    axs[1,1].set_ylabel('cluster counts')

    # Make the title include the number of excluded (label==-1) ROIs
    fig.suptitle(f'Quality metrics n_excluded: {np.sum(~is_included)}, n_included: {np.sum(is_included)}, n_total: {len(labels_arr)}, n_clusters: {len(u)}, n_sessions: {n_sessions}')
    return fig, axs