from pathlib import Path
import warnings
import copy
from typing import Dict, Any, Optional, Union, List, Tuple, Callable, Iterable, Iterator, Type
import datetime
import collections
import importlib
import numpy as np
import scipy.sparse
from tqdm.auto import tqdm
import torch
import richfile as rf
from . import helpers
## Lower bound applied to denominators (row max, row sum, ...) when
## normalizing sparse arrays. A denominator smaller than this value is treated
## as zero for normalization purposes. This both:
##   (1) avoids division by zero (and the inf values it would produce), and
##   (2) keeps tiny numerical artifacts (e.g. values around 1e-15) from being
##       blown up to full scale when spatial footprints are max-normalized.
SPARSE_NORMALIZATION_FLOOR: float = 1.0e-12
[docs]
def get_roicat_version() -> str:
    """
    Retrieves the version of the roicat package.

    Returns:
        (str):
            version (str):
                The version of the roicat package, as recorded in the
                installed package metadata.

    Raises:
        importlib.metadata.PackageNotFoundError:
            If roicat is not installed as a distribution.
    """
    ## `import importlib` alone does not guarantee that the `metadata`
    ## submodule has been loaded, so import it explicitly here.
    import importlib.metadata
    return importlib.metadata.version('roicat')
[docs]
def get_default_parameters(
    pipeline='tracking',
    path_defaults=None
):
    """
    This function returns a dictionary of parameters that can be used to run
    different pipelines. RH 2023

    Args:
        pipeline (str):
            The name of the pipeline to use. Options: \n
                * 'tracking': Tracking pipeline. \n
                * 'classification_inference': Classification inference pipeline
                  (TODO). \n
                * 'classification_training': Classification training pipeline
                  (TODO). \n
                * 'model_training': Model training pipeline (TODO; not yet
                  implemented, raises ``NotImplementedError``). \n
        path_defaults (str):
            A path to a yaml file containing a parameters dictionary. The
            parameters from the file will be loaded as is, except that
            ``['ROInet']['network']`` is always replaced with the
            pipeline-specific model download settings. The file must contain
            every top-level section used by ``pipeline``. If None, the default
            parameters will be used.

    Returns:
        (dict):
            params (dict):
                A dictionary containing the default parameters. Only the
                top-level sections used by ``pipeline`` are included. The
                result is a deep copy and is safe to mutate.

    Raises:
        NotImplementedError:
            If ``pipeline`` is not one of the implemented pipelines.
    """
    ## Top-level parameter sections used by each pipeline (in output order).
    keys_pipeline = {
        'tracking': [
            'general',
            'data_loading',
            'alignment',
            'blurring',
            'ROInet',
            'SWT',
            'similarity_graph',
            'clustering',
            'results_saving',
        ],
        'classification_inference': [
            'general',
            'data_loading',
            'ROInet',
            'results_saving',
        ],
        'classification_training': [
            'general',
            'data_loading',
            'ROInet',
            'results_saving',
        ],
    }
    ## Validate before doing any work (e.g. reading a yaml file).
    if pipeline not in keys_pipeline:
        raise NotImplementedError(f'pipeline={pipeline}, which is not implemented or not recognized. Should be one of: {list(keys_pipeline.keys())}')

    def _network(download_url, download_hash, forward_pass_version):
        """Build a fresh ROInet 'network' parameter dict."""
        return {
            'download_method': 'check_local_first',  ## Check to see if a model has already been downloaded to the location (will skip if hash matches)
            'download_url': download_url,  ## URL of the model
            'download_hash': download_hash,  ## Hash of the model file
            'forward_pass_version': forward_pass_version,  ## How the data is passed through the network
        }

    ## Pipeline-specific ROInet model settings. These always override whatever
    ## 'network' settings are present in `defaults` (including yaml-loaded ones).
    network_pipeline = {
        'tracking': _network('https://osf.io/x3fd2/download', '7a5fb8ad94b110037785a46b9463ea94', 'latent'),
        'classification_inference': _network('https://osf.io/c8m3b/download', '357a8d9b630ec79f3e015d0056a4c2d5', 'head'),
        'classification_training': _network('https://osf.io/c8m3b/download', '357a8d9b630ec79f3e015d0056a4c2d5', 'head'),
    }

    if path_defaults is not None:
        defaults = helpers.yaml_load(path_defaults)
    else:
        defaults = {
            'general': {
                'use_GPU': True,
                'verbose': True,
                'random_seed': None,
            },
            'data_loading': {
                'data_kind': 'data_suite2p',  ## Kind of input data. Should name one of the 'data_*' sub-dictionaries below (e.g. 'data_suite2p' or 'data_roicat'). See documentation and/or notebook on custom data loading for other options (e.g. roiextractors).
                'dir_outer': None,  ## Outer directory that is searched for the input data files.
                'reMatch_in_path': None,  ## regular expression string that must be found within the parent path of any discovered files. See helpers.find_paths
                'common': {
                    'um_per_pixel': 1.0,  ## Number of microns per pixel for the imaging dataset. Doesn't need to be exact. Used for resizing the ROIs. Check the images of the resized ROIs to tweak.
                    'centroid_method': 'centerOfMass',  ## Can be 'centerOfMass' or 'median'.
                    'out_height_width': [36, 36],  ## Height and width of the small ROI_images. Should generally be tuned slightly bigger than the largest ROIs. Leave if uncertain or if ROIs are small enough to fit in the default size.
                },
                'data_suite2p': {
                    'new_or_old_suite2p': 'new',  ## Can be 'new' or 'old'. 'new' is for the Python version of Suite2p, 'old' is for the MATLAB version.
                    'type_meanImg': 'meanImgE',  ## Can be 'meanImg' or 'meanImgE'. 'meanImg' is the mean image of the dataset, 'meanImgE' is the mean image of the dataset after contrast enhancement.
                },
                'data_roicat': {
                    'filename_search': r'data_roicat.richfile',  ## Name stem of the single file (as a regex search string) in 'dir_outer' to look for. The files should be saved Data_roicat objects.
                },
            },
            'alignment': {
                'initialization': {
                    'use_match_search': True,  ## Whether or not to use our match search algorithm to initialize the alignment.
                    'all_to_all': False,  ## Force the use of our match search algorithm for all-pairs matching. Much slower (False: O(N) vs. True: O(N^2)), but more accurate.
                    'radius_in': 4.0,  ## Value in micrometers used to define the maximum shift/offset between two images that are considered to be aligned. Larger means more lenient alignment.
                    'radius_out': 20.0,  ## Value in micrometers used to define the minimum shift/offset between two images that are considered to be misaligned.
                    'z_threshold': 4.0,  ## Z-score required to define two images as aligned. Larger values results in more stringent alignment requirements.
                },
                'augment': {
                    'normalize_FOV_intensities': True,  ## Whether or not to normalize the FOV_images to the max value across all FOV images.
                    'roi_FOV_mixing_factor': 0.5,  ## default: 0.5. Fraction of the max intensity projection of ROIs that is added to the FOV image. 0.0 means only the FOV_images, larger values mean more of the ROIs are added.
                    'use_CLAHE': True,  ## Whether or not to use 'Contrast Limited Adaptive Histogram Equalization'. Useful if params['data_loading']['data_suite2p']['type_meanImg'] is not a contrast enhanced image (like 'meanImgE' in Suite2p)
                    'CLAHE_grid_block_size': 10,  ## Size of the block size for the grid for CLAHE. Smaller values means more local contrast enhancement.
                    'CLAHE_clipLimit': 1.0,  ## Clipping limit for CLAHE. Higher values mean more contrast.
                    'CLAHE_normalize': True,  ## Whether or not to normalize the CLAHE image.
                },
                'fit_geometric': {
                    'template': 0.5,  ## Which session to use as a registration template. If input is float (ie 0.0, 0.5, 1.0, etc.), then it is the fractional position of the session to use; if input is int (ie 1, 2, 3), then it is the index of the session to use (0-indexed)
                    'template_method': 'image',  ## Can be 'sequential' or 'image'. If 'sequential', then the template is the FOV_image of the previous session. If 'image', then the template is the FOV_image of the session specified by 'template'.
                    'mask_borders': [0, 0, 0, 0],  ## Number of pixels to mask from the borders of the FOV_image. (top, bottom, left, right). Useful for removing artifacts from the edges of the FOV_image.
                    'method': 'RoMa',  ## Accuracy order (best to worst): RoMa (by far, but slow without a GPU), LoFTR, DISK_LightGlue, ECC_cv2, (the following are not recommended) SIFT, ORB
                    'kwargs_method': {
                        'RoMa': {
                            'model_type': 'outdoor',
                            'n_points': 10000,  ## Higher values mean more points are used for the registration. Useful for larger FOV_images. Larger means slower.
                            'batch_size': 1000,
                        },
                        'DISK_LightGlue': {
                            'num_features': 3000,  ## Number of features to extract and match. I've seen best results around 2048 despite higher values typically being better.
                            'threshold_confidence': 0.2,  ## Higher values means fewer but better matches.
                        },
                        'LoFTR': {
                            'model_type': 'indoor_new',
                            'threshold_confidence': 0.2,  ## Higher values means fewer but better matches.
                        },
                        'ECC_cv2': {
                            'mode_transform': 'euclidean',  ## Must be one of {'translation', 'affine', 'euclidean', 'homography'}. See cv2 documentation on findTransformECC for more details.
                            'n_iter': 200,
                            'termination_eps': 1e-09,  ## Termination criteria for the registration algorithm. See documentation for more details.
                            'gaussFiltSize': 1,  ## Size of the gaussian filter used to smooth the FOV_image before registration. Larger values mean more smoothing.
                            'auto_fix_gaussFilt_step': 10,  ## If the registration fails, then the gaussian filter size is reduced by this amount and the registration is tried again.
                        },
                        'PhaseCorrelation': {
                            'bandpass_freqs': [1, 30],
                            'order': 5,
                        },
                        'SIFT': {
                            'nfeatures': 10000,
                            'contrastThreshold': 0.04,
                            'edgeThreshold': 10,
                            'sigma': 1.6,
                        },
                        'ORB': {
                            'nfeatures': 1000,
                            'scaleFactor': 1.2,
                            'nlevels': 8,
                            'edgeThreshold': 31,
                            'firstLevel': 0,
                            'WTA_K': 2,
                            'scoreType': 0,
                            'patchSize': 31,
                            'fastThreshold': 20,
                        },
                        'NullRegistration': {},  ## No registration, no warping.
                    },
                    'constraint': 'affine',  ## Must be one of {'rigid', 'euclidean', 'similarity', 'affine', 'homography'}. Choose constraint based on expected changes in images; use the simplest constraint that is applicable.
                    'kwargs_RANSAC': {  ## Parameters related to the RANSAC algorithm used for point/descriptor based registration methods.
                        'inl_thresh': 3.0,  ## Threshold for the inliers. Larger values mean more points are considered inliers.
                        'max_iter': 100,  ## Maximum number of iterations for the RANSAC algorithm.
                        'confidence': 0.99,  ## Confidence level (0-1) for the RANSAC estimate.
                    },
                },
                'fit_nonrigid': {
                    'template': 0.5,  ## Which session to use as a registration template. If input is float (ie 0.0, 0.5, 1.0, etc.), then it is the fractional position of the session to use; if input is int (ie 1, 2, 3), then it is the index of the session to use (0-indexed)
                    'template_method': 'image',  ## Can be 'sequential' or 'image'. If 'sequential', then the template is the FOV_image of the previous session. If 'image', then the template is the FOV_image of the session specified by 'template'.
                    'method': 'DeepFlow',
                    'kwargs_method': {
                        'RoMa': {
                            'model_type': 'outdoor',
                        },
                        'DeepFlow': {},
                        'OpticalFlowFarneback': {
                            'pyr_scale': 0.7,
                            'levels': 5,
                            'winsize': 128,
                            'iterations': 15,
                            'poly_n': 5,
                            'poly_sigma': 1.5,
                        },
                        'NullRegistration': {},
                    },
                },
                'transform_ROIs': {
                    'normalize': True,  ## If True, normalize the spatial footprints to have a sum of 1.
                },
            },
            'blurring': {
                'kernel_halfWidth': 2.0,  ## Half-width of the cosine kernel used for blurring. Set value based on how much you think the ROIs move from session to session.
            },
            'ROInet': {
                'network': _network('https://osf.io/x3fd2/download', '7a5fb8ad94b110037785a46b9463ea94', 'latent'),
                'dataloader': {
                    'jit_script_transforms': False,  ## (advanced) Whether or not to use torch.jit.script to speed things up
                    'batchSize_dataloader': 8,  ## (advanced) PyTorch dataloader batch_size
                    'pinMemory_dataloader': True,  ## (advanced) PyTorch dataloader pin_memory
                    'numWorkers_dataloader': -1,  ## (advanced) PyTorch dataloader num_workers. -1 is all cores.
                    'persistentWorkers_dataloader': True,  ## (advanced) PyTorch dataloader persistent_workers
                    'prefetchFactor_dataloader': 2,  ## (advanced) PyTorch dataloader prefetch_factor
                },
            },
            'SWT': {
                'kwargs_Scattering2D': {'J': 2, 'L': 12},  ## 'J' is the number of convolutional layers. 'L' is the number of wavelet angles.
                'batch_size': 100,  ## Batch size for each iteration (smaller is less memory but slower)
            },
            'similarity_graph': {
                'sparsification': {
                    'n_workers': -1,  ## Number of CPU cores to use. -1 for all.
                    'block_height': 128,  ## size of a block
                    'block_width': 128,  ## size of a block
                    'algorithm_nearestNeigbors_spatialFootprints': 'brute',  ## algorithm used to find the pairwise similarity for s_sf. ('brute' is slow but exact. See docs for others.)
                },
                'compute_similarity': {
                    'spatialFootprint_maskPower': 1.0,  ## An exponent to raise the spatial footprints to to care more or less about bright pixels
                },
                'normalization': {
                    'k_max': 100,  ## Maximum number of nearest neighbors * n_sessions to consider for the normalizing distribution
                    'k_min': 10,  ## Minimum number of nearest neighbors * n_sessions to consider for the normalizing distribution
                    'algo_NN': 'kd_tree',  ## Nearest neighbors algorithm to use
                },
            },
            'clustering': {
                'mixing_method': 'automatic',  ## 'automatic' (NB calibration + freeze-sigmoid DE) or 'manual'
                'parameters_automatic_mixing': {
                    'n_bins': None,  ## Number of bins for histograms. None = heuristic.
                    'smoothing_window_bins': None,  ## Smoothing window for distributions. None = heuristic.
                    'subsample_pairs': None,  ## Subsample this many pairs for speedup. None = use all.
                    'bounds_findParameters': {
                        'power_nn': [0.0, 2.],  ## Bounds for the exponent applied to s_nn
                        'power_swt': [0.0, 2.],  ## Bounds for the exponent applied to s_swt
                        'p_norm': [-5, -0.1],  ## Bounds for the p-norm (Minkowski) mixing parameter
                    },
                    'de_kwargs': {
                        'maxiter': 100,  ## Max DE generations
                        'tol': 1e-6,  ## Convergence tolerance
                        'popsize': 15,  ## Population size multiplier
                        'mutation': [0.5, 1.5],  ## DE mutation range
                        'recombination': 0.7,  ## DE crossover probability
                        'polish': True,  ## L-BFGS-B polish after DE
                    },
                },
                'parameters_manual_mixing': {
                    'power_sf': 1.0,  ## s_sf**power_sf (Higher values means clustering is more sensitive to spatial overlap of ROIs)
                    'power_nn': 0.5,  ## s_nn**power_nn (Higher values means clustering is more sensitive to visual similarity of ROIs)
                    'power_swt': 0.5,  ## s_swt**power_swt (Higher values means clustering is more sensitive to visual similarity of ROIs)
                    'p_norm': -1.0,  ## norm([s_sf, s_nn, s_swt], p=p_norm) (Higher values means clustering requires all similarity metrics to be high)
                    'sig_sf_kwargs': None,  ## Sigmoid parameters for s_sf (mu is the center, b is the slope)
                    'sig_nn_kwargs': {'mu': 0.5, 'b': 1.0},  ## Sigmoid parameters for s_nn (mu is the center, b is the slope)
                    'sig_swt_kwargs': {'mu': 0.5, 'b': 1.0},  ## Sigmoid parameters for s_swt (mu is the center, b is the slope)
                },
                'pruning': {
                    'd_cutoff': None,  ## Optionally manually specify a distance cutoff
                    'stringency': 1.0,  ## Scaling factor applied to d_cutoff. Smaller numbers result in more pruning.
                    'convert_to_probability': False,  ## Whether or not to convert the similarity matrix and distance matrix to a probability matrix
                },
                'cluster_method': {
                    'method': 'automatic',  ## 'automatic', 'hdbscan', or 'sequential_hungarian'. 'automatic' picks hdbscan for many sessions and sequential_hungarian for few; the switch point is set by 'n_sessions_switch'.
                    'n_sessions_switch': 6,  ## Number of sessions to switch from sequential_hungarian to hdbscan
                },
                'hdbscan': {
                    'min_cluster_size': 2,  ## Minimum number of ROIs that can be considered a 'cluster'
                    'max_cluster_size': None,  ## Maximum cluster size. None defaults to n_sessions (one ROI per session).
                    'min_samples': None,  ## Number of neighbors for core-point density. None defaults to min_cluster_size. Lower values reduce noise points.
                    'cluster_selection_method': 'eom',  ## (advanced) Method of cluster selection for HDBSCAN (see hdbscan documentation)
                    'cluster_selection_persistence': 0.0,  ## (advanced) Minimum stability for a cluster to survive. Higher values = fewer, more stable clusters.
                    'd_clusterMerge': None,  ## Distance below which clusters are merged. None defaults to d_cutoff (pruning threshold).
                    'rescue_noise': True,  ## Post-HDBSCAN noise rescue: assign noise ROIs to nearby clusters via Kruskal DSU with cannot-link constraints.
                    'n_iter_violationCorrection': 6,  ## Number of times to redo clustering sweep after removing violations
                    'split_intraSession_clusters': True,  ## Whether or not to split clusters with ROIs from the same session
                    'alpha': 0.999,  ## (advanced) Scalar applied to distance matrix in HDBSCAN (see hdbscan documentation)
                    'discard_failed_pruning': True,  ## (advanced) Whether or not to set all ROIs that could be separated from clusters with ROIs from the same sessions to label=-1
                    'n_steps_clusterSplit': 100,  ## (advanced) How finely to step through distances to remove violations
                },
                'sequential_hungarian': {
                    'thresh_cost': 0.6,  ## Threshold for the cost matrix. Lower numbers result in more clusters.
                },
            },
            'results_saving': {
                'dir_save': None,  ## Directory to save results to. If None, will not save.
                'prefix_name_save': str(datetime.datetime.now().strftime("%Y%m%d_%H%M%S")),  ## Prefix prepended to the names of saved files (defaults to the current date/time)
                'richfile_backend': 'zip',  ## Backend for saving richfile data. Options: 'directory', 'sqlar', 'zip' (default), 'tar'. Archive backends produce a single file instead of a directory tree.
                'gif_frame_rate': 10.0,  ## Frame rate for any GIFs saved
            },
        }

    ## Keep only the sections this pipeline uses. Deep-copy so callers can
    ## mutate the result freely, then install the pipeline-specific network.
    out = copy.deepcopy({key: defaults[key] for key in keys_pipeline[pipeline]})
    out['ROInet']['network'] = network_pipeline[pipeline]
    return out
[docs]
def system_info(verbose: bool = False,) -> Dict:
"""
Checks and prints the versions of various important software packages.
RH 2022
Args:
verbose (bool):
Whether to print the software versions.
(Default is ``False``)
Returns:
(Dict):
versions (Dict):
Dictionary containing the versions of various software packages.
"""
## Operating system and version
import platform
def try_fns(fn):
try:
return fn()
except:
return None
fns = {key: val for key, val in platform.__dict__.items() if (callable(val) and key[0] != '_')}
operating_system = {key: try_fns(val) for key, val in fns.items() if (callable(val) and key[0] != '_')}
print(f'== Operating System ==: {operating_system["uname"]}') if verbose else None
## CPU info
try:
import cpuinfo
import multiprocessing as mp
cpu_info_raw = cpuinfo.get_cpu_info()
cpu_n_cores = mp.cpu_count()
cpu_brand = cpu_info_raw.get('brand_raw', 'Unknown')
cpu_info = {'n_cores': cpu_n_cores, 'brand': cpu_brand}
if 'flags' in cpu_info_raw:
cpu_info['flags'] = 'omitted'
except Exception as e:
warnings.warn(f'RH WARNING: unable to get cpu info. Got error: {e}')
cpu_info = 'ROICaT Error: Failed to get'
print(f'== CPU Info ==: {cpu_info}') if verbose else None
## RAM
import psutil
ram = psutil.virtual_memory()
print(f'== RAM ==: {ram}') if verbose else None
## User
import getpass
user = getpass.getuser()
## GPU
try:
gpu_info = helpers.list_available_devices()
except Exception as e:
warnings.warn(f'RH WARNING: unable to get gpu info. Got error: {e}')
gpu_info = 'ROICaT Error: Failed to get'
print(f'== GPU Info ==: {gpu_info}') if verbose else None
## Conda Environment
import os
if 'CONDA_DEFAULT_ENV' not in os.environ:
conda_env = 'None'
else:
conda_env = os.environ['CONDA_DEFAULT_ENV']
print(f'== Conda Environment ==: {conda_env}') if verbose else None
## Python
import sys
python_version = sys.version.split(' ')[0]
print(f'== Python Version ==: {python_version}') if verbose else None
## GCC
import subprocess
try:
gcc_version = subprocess.check_output(['gcc', '--version']).decode('utf-8').split('\n')[0].split(' ')[-1]
except Exception as e:
warnings.warn(f'RH WARNING: unable to get gcc version. Got error: {e}')
gcc_version = 'Faled to get'
print(f'== GCC Version ==: {gcc_version}') if verbose else None
## PyTorch
import torch
torch_version = str(torch.__version__)
print(f'== PyTorch Version ==: {torch_version}') if verbose else None
## CUDA
if torch.cuda.is_available():
cuda_version = torch.version.cuda
cudnn_version = torch.backends.cudnn.version()
torch_devices = [f'device {i}: Name={torch.cuda.get_device_name(i)}, Memory={torch.cuda.get_device_properties(i).total_memory / 1e9} GB' for i in range(torch.cuda.device_count())]
print(f"== CUDA Version ==: {cuda_version}, CUDNN Version: {cudnn_version}, Number of Devices: {torch.cuda.device_count()}, Devices: {torch_devices}, ") if verbose else None
else:
cuda_version = None
cudnn_version = None
torch_devices = None
print('== CUDA is not available ==') if verbose else None
## all packages in environment
import importlib.metadata
pkgs_dict = {dist.metadata['Name'].lower(): dist.version for dist in importlib.metadata.distributions()}
## roicat
import time
roicat_version = importlib.metadata.version("roicat")
roicat_fileDate = time.ctime(os.path.getctime(importlib.metadata.distribution("roicat").locate_file('')))
roicat_stuff = {'version': roicat_version, 'date_installed': roicat_fileDate}
print(f'== ROICaT Version ==: {roicat_version}') if verbose else None
print(f'== ROICaT date installed ==: {roicat_fileDate}') if verbose else None
## get datetime
from datetime import datetime
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
versions = {
'datetime': dt,
'roicat': roicat_stuff,
'operating_system': operating_system,
'cpu_info': cpu_info, ## This is the slow one.
'user': user,
'ram': ram,
'gpu_info': gpu_info,
'conda_env': conda_env,
'python': python_version,
'gcc': gcc_version,
'torch': torch_version,
'cuda': cuda_version,
'cudnn': cudnn_version,
'torch_devices': torch_devices,
'pkgs': pkgs_dict,
}
def conv_str(obj):
if isinstance(obj, (dict, collections.OrderedDict)):
return {key: conv_str(val) for key, val in obj.items()}
elif isinstance(obj, (list, tuple, set, frozenset)):
return [conv_str(val) for val in obj]
elif isinstance(obj, (int, float, bool, type(None))):
return obj
else:
return str(obj)
versions = conv_str(versions)
return versions
[docs]
def set_random_seed(seed=None, deterministic=False):
"""
Set random seed for reproducibility.
RH 2023
Args:
seed (int, optional):
Random seed.
If None, a random seed (spanning int32 integer range) is generated.
deterministic (bool, optional):
Whether to make packages deterministic.
Returns:
(int):
seed (int):
Random seed.
"""
### random seed (note that optuna requires a random seed to be set within the pipeline)
import numpy as np
seed = int(np.random.randint(0, 2**31 - 1, dtype=np.uint32)) if seed is None else seed
np.random.seed(seed)
import torch
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
import random
random.seed(seed)
import cv2
cv2.setRNGSeed(seed)
## Make torch deterministic
torch.use_deterministic_algorithms(deterministic)
## Make cudnn deterministic
torch.backends.cudnn.deterministic = deterministic
torch.backends.cudnn.benchmark = not deterministic
return seed
[docs]
class ROICaT_Module:
    """
    Super class for ROICaT modules.
    RH 2023

    Attributes:
        _system_info (object):
            System information, as returned by ``system_info()``.
        params (dict):
            Parameters of the module. Initialized to an empty dict.
    """
    def __init__(self) -> None:
        """
        Initializes the ROICaT_Module class by gathering system information.
        """
        self._system_info = system_info()
        self.params = {}

    @property
    def serializable_dict(self) -> Dict[str, Any]:
        """
        Returns a serializable dictionary that can be saved to disk. This
        method recursively goes through all items in ``self.__dict__`` and
        keeps those that are picklable and come from an allowed library.
        Conversion rules:
            * list / tuple / set / frozenset -> list of converted items
            * dict (including subclasses such as OrderedDict) -> dict of
              converted values
            * object from an allowed library with a ``__dict__`` -> dict of
              its converted attributes (attributes that fail to convert are
              dropped)
            * other picklable object from an allowed library -> kept as is
            * anything else -> ``{'__repr__': repr(obj)}``

        Returns:
            (Dict[str, Any]):
                serializable_dict (Dict[str, Any]):
                    Dictionary containing serializable items.
        """
        from functools import partial
        import pickle

        ## Root package names whose objects may be kept in the output.
        allowed_libraries = [
            'roicat',
            'builtins',
            'collections',
            'datetime',
            'itertools',
            'math',
            'numbers',
            'os',
            'pathlib',
            'string',
            'time',
            'numpy',
            'scipy',
            'sklearn',
        ]

        def is_library_allowed(obj):
            """
            Return True if the root package of ``obj``'s module is in
            ``allowed_libraries``. The module of ``obj.__class__`` takes
            priority; ``obj.__module__`` is the fallback. Objects whose module
            cannot be determined are not allowed.
            """
            for get_module in (
                lambda: obj.__class__.__module__,
                lambda: obj.__module__,
            ):
                try:
                    return get_module().split('.')[0] in allowed_libraries
                except Exception:
                    continue
            return False

        def to_repr(obj):
            """Fallback representation for objects that cannot be kept as is."""
            if hasattr(obj, '__repr__'):
                return {'__repr__': repr(obj)}
            if hasattr(obj, '__str__'):
                return {'__str__': str(obj)}
            return None

        def make_serializable_dict(obj, depth=0, max_depth=100, name=None):
            """
            Recursively convert ``obj`` following the rules in the property
            docstring. ``name`` is a path-like label for the current object,
            kept for debugging.
            """
            msd_partial = partial(make_serializable_dict, depth=depth+1, max_depth=max_depth)
            if depth > max_depth:
                raise Exception(f'RH ERROR: max_depth of {max_depth} reached with object: {obj}')

            ## Containers are checked before `__dict__`: some containers (e.g.
            ## collections.OrderedDict) also carry an empty instance
            ## `__dict__`, which would otherwise cause their contents to be lost.
            if isinstance(obj, (list, tuple, set, frozenset)):
                return [msd_partial(v, name=f'{name}_{ii}') for ii, v in enumerate(obj)]
            if isinstance(obj, dict):
                return {k: msd_partial(v, name=f'{name}_{k}') for k, v in obj.items()}
            if hasattr(obj, '__dict__') and is_library_allowed(obj):
                out = {}
                for key, val in obj.__dict__.items():
                    try:
                        out[key] = msd_partial(val, name=key)
                    except Exception:
                        ## Best effort: drop attributes that cannot be converted.
                        pass
                return out

            ## Leaf object: keep it only if it is allowed and picklable.
            if not is_library_allowed(obj):
                return to_repr(obj)
            try:
                pickle.dumps(obj)
            except Exception:
                return to_repr(obj)
            return obj

        return make_serializable_dict(self, depth=0, max_depth=100, name='self')

    def _locals_to_params(
        self,
        locals_dict: Dict[str, Any],
        keys: List[str],
    ) -> Dict[str, Any]:
        """
        Returns a dictionary of the local variables with the specified keys.

        Args:
            locals_dict (Dict[str, Any]):
                Dictionary of local variables.
            keys (List[str]):
                List of keys to extract from the local variables.

        Returns:
            (Dict[str, Any]):
                params (Dict[str, Any]):
                    ``{key: locals_dict[key]}`` for every key in ``keys``.
                    Keys missing from ``locals_dict`` map to ``None`` and emit
                    a warning.
        """
        def safe_getitem(d, key):
            try:
                return d[key]
            except KeyError:
                warnings.warn(f'RH WARNING: key={key} not found in locals_dict. Skipping.')
                return None
        return {key: safe_getitem(locals_dict, key) for key in keys}
[docs]
class RichFile_ROICaT(rf.RichFile):
"""
RichFile subclass with ROICaT-specific type registrations (numpy arrays,
scipy sparse matrices, torch tensors, optuna studies, pandas DataFrames,
etc.).
Args:
path (Optional[Union[str, Path]]):
Path to save/load the richfile.
check (Optional[bool]):
Whether to perform validation checks.
safe_save (Optional[bool]):
Whether to use atomic save with temporary file.
backend (Optional[str]):
Storage backend. One of:
* ``'auto'``: auto-detect from existing path, or default to
``'directory'`` for new saves.
* ``'directory'``: classic richfile directory tree.
* ``'sqlar'``: single-file SQLite archive (``.sqlar``).
* ``'zip'``: single-file ZIP archive (``.zip``, stored/no
compression).
* ``'tar'``: single-file plain TAR archive (``.tar``).
"""
def __init__(
self,
path: Optional[Union[str, Path]] = None,
check: Optional[bool] = True,
safe_save: Optional[bool] = True,
backend: Optional[str] = "auto",
):
super().__init__(path=path, check=check, safe_save=safe_save, backend=backend)
## NUMPY ARRAY
import numpy as np
def save_npy_array(
obj: np.ndarray,
path: Union[str, Path],
**kwargs,
) -> None:
"""
Saves a NumPy array to the given path.
"""
np.save(path, obj, **kwargs)
def load_npy_array(
path: Union[str, Path],
**kwargs,
) -> np.ndarray:
"""
Loads an array from the given path.
"""
return np.load(path, **kwargs)
## SCIPY SPARSE MATRIX
import scipy.sparse
def save_sparse_array(
obj: Union[scipy.sparse.spmatrix, scipy.sparse.sparray],
path: Union[str, Path],
**kwargs,
) -> None:
"""
Saves a SciPy sparse matrix/array to the given path.
"""
scipy.sparse.save_npz(path, obj, **kwargs)
def load_sparse_array(
path: Union[str, Path],
**kwargs,
) -> scipy.sparse.sparray:
"""
Loads a sparse array from the given path.
"""
return scipy.sparse.load_npz(path, **kwargs)
## JSON DICT
import collections
import json
def save_json_dict(
obj: collections.UserDict,
path: Union[str, Path],
**kwargs,
) -> None:
"""
Saves a dictionary to the given path.
"""
with open(path, 'w') as f:
json.dump(dict(obj), f, **kwargs)
def load_json_dict(
path: Union[str, Path],
**kwargs,
) -> collections.UserDict:
"""
Loads a dictionary from the given path.
"""
with open(path, 'r') as f:
return JSON_Dict(json.load(f, **kwargs))
## JSON LIST
def save_json_list(
obj: collections.UserList,
path: Union[str, Path],
**kwargs,
) -> None:
"""
Saves a list to the given path.
"""
with open(path, 'w') as f:
json.dump(list(obj), f, **kwargs)
def load_json_list(
path: Union[str, Path],
**kwargs,
) -> collections.UserList:
"""
Loads a list from the given path.
"""
with open(path, 'r') as f:
return JSON_List(json.load(f, **kwargs))
## OPTUNA STUDY
import optuna
import pickle
## load and save functions for optuna study
def save_optuna_study(
obj: optuna.study.Study,
path: Union[str, Path],
**kwargs,
) -> None:
"""
Saves an Optuna study to the given path.
"""
with open(path, 'wb') as f:
pickle.dump(obj, f, **kwargs)
def load_optuna_study(
path: Union[str, Path],
**kwargs,
) -> optuna.study.Study:
"""
Loads an Optuna study from the given path.
"""
with open(path, 'rb') as f:
return pickle.load(f, **kwargs)
## TORCH TENSOR
import torch
def save_torch_tensor(
obj: torch.Tensor,
path: Union[str, Path],
**kwargs,
) -> None:
"""
Saves a PyTorch tensor to the given path as a NumPy array.
"""
np.save(path, obj.detach().cpu().numpy(), **kwargs)
def load_torch_tensor(
path: Union[str, Path],
**kwargs,
) -> torch.Tensor:
"""
Loads a PyTorch tensor from the given path.
"""
return torch.from_numpy(np.load(path, **kwargs))
## REPR
def save_repr(
obj: object,
path: Union[str, Path],
**kwargs,
) -> None:
"""
Saves the repr of an object to the given path.
"""
with open(path, 'w') as f:
f.write(repr(obj))
def load_repr(
path: Union[str, Path],
**kwargs,
) -> object:
"""
Loads the repr of an object from the given path.
"""
with open(path, 'r') as f:
return f.read()
## HDBSCAN OBJECT (fast_hdbscan or legacy)
import fast_hdbscan
_hdbscan_class = fast_hdbscan.HDBSCAN
def save_hdbscan(
obj,
path: Union[str, Path],
**kwargs,
) -> None:
"""
Save a fast_hdbscan.HDBSCAN object by extracting serializable
attributes into a dict of numpy arrays and JSON scalars.
"""
import json
attrs = {}
## Use vars(obj) instead of dir(obj) to avoid triggering sklearn's
## __dir__ which calls hasattr on every attribute, including broken
## properties like condensed_tree_ that raise NameError in fast_hdbscan.
## Also include sklearn get_params() for constructor parameters.
all_attrs = dict(vars(obj))
try:
all_attrs.update(obj.get_params())
except Exception:
pass
for attr, val in sorted(all_attrs.items()):
if attr.startswith('_'):
continue
if callable(val):
continue
if val is None:
attrs[attr] = None
elif isinstance(val, np.ndarray):
attrs[attr] = {'_type': 'ndarray', 'data': val.tolist(), 'dtype': str(val.dtype)}
elif scipy.sparse.issparse(val):
csr = val.tocsr()
attrs[attr] = {
'_type': 'sparse',
'data': csr.data.tolist(),
'indices': csr.indices.tolist(),
'indptr': csr.indptr.tolist(),
'shape': list(csr.shape),
}
elif isinstance(val, (int, float, str, bool)):
attrs[attr] = val
elif isinstance(val, np.integer):
attrs[attr] = int(val)
elif isinstance(val, np.floating):
attrs[attr] = float(val)
## Skip non-serializable objects (CondensedTree, SingleLinkageTree)
Path(path).parent.mkdir(parents=True, exist_ok=True)
with open(path, 'w') as f:
json.dump(attrs, f)
def load_hdbscan(
path: Union[str, Path],
**kwargs,
) -> object:
"""
Load a fast_hdbscan.HDBSCAN-like dict from JSON.
Returns a dict of the stored attributes (not a live HDBSCAN object).
"""
import json
with open(path, 'r') as f:
content = f.read()
try:
attrs = json.loads(content)
except json.JSONDecodeError:
## Old format: file contains repr string, not JSON
return content
## Reconstruct numpy arrays and sparse matrices
for key, val in attrs.items():
if isinstance(val, dict) and val.get('_type') == 'ndarray':
attrs[key] = np.array(val['data'], dtype=val['dtype'])
elif isinstance(val, dict) and val.get('_type') == 'sparse':
attrs[key] = scipy.sparse.csr_array(
(np.array(val['data']), np.array(val['indices']), np.array(val['indptr'])),
shape=tuple(val['shape']),
)
return attrs
## SCIPY OPTIMIZE RESULT
import scipy.optimize
def save_optimize_result(
obj: scipy.optimize.OptimizeResult,
path: Union[str, Path],
**kwargs,
) -> None:
"""
Saves a scipy.optimize.OptimizeResult as JSON.
Extracts standard fields into a JSON-serializable dict.
"""
d = {}
for key in ('x', 'fun', 'nfev', 'nit', 'success', 'message'):
if key in obj:
val = obj[key]
if isinstance(val, np.ndarray):
d[key] = val.tolist()
elif isinstance(val, (np.integer,)):
d[key] = int(val)
elif isinstance(val, (np.floating,)):
d[key] = float(val)
elif isinstance(val, np.bool_):
d[key] = bool(val)
else:
d[key] = val
with open(path, 'w') as f:
json.dump(d, f)
def load_optimize_result(
path: Union[str, Path],
**kwargs,
) -> scipy.optimize.OptimizeResult:
"""
Loads a scipy.optimize.OptimizeResult from JSON.
"""
with open(path, 'r') as f:
d = json.load(f)
if 'x' in d and isinstance(d['x'], list):
d['x'] = np.array(d['x'])
return scipy.optimize.OptimizeResult(**d)
## PANDAS DATAFRAME
import pandas as pd
def save_pandas_dataframe(
obj: pd.DataFrame,
path: Union[str, Path],
**kwargs,
) -> None:
"""
Saves a Pandas DataFrame to the given path.
"""
## Save as a CSV file
obj.to_csv(path, index=True, **kwargs)
def load_pandas_dataframe(
path: Union[str, Path],
**kwargs,
) -> pd.DataFrame:
"""
Loads a Pandas DataFrame from the given path.
"""
## Load as a CSV file
return pd.read_csv(path, index_col=0, **kwargs)
roicat_module_tds = [rf.functions.Container(
type_name=type_name,
object_class=object_class,
suffix="roicat",
library="roicat",
versions_supported=[">=1.1", "<2"],
) for type_name, object_class in [
# ("data_suite2p", data_importing.Data_suite2p),
# ("data_caiman", data_importing.Data_caiman),
# ("data_roiextractors", data_importing.Data_roiextractors),
# ("data_roicat", data_importing.Data_roicat),
# ("aligner", alignment.Aligner),
# ("blurrer", blurring.ROI_Blurrer),
# ("roinet", ROInet.ROInet_embedder),
# ("swt", scatteringWaveletTransformer.SWT),
# ("similarity_graph", similarity_graph.ROI_graph),
# ("clusterer", clustering.Clusterer),
("toeplitz_conv", helpers.Toeplitz_convolution2d),
("convergence_checker_optuna", helpers.Convergence_checker_optuna),
("image_alignment_checker", helpers.ImageAlignmentChecker),
]]
## sparse_convolution.Toeplitz_convolution2d: store as JSON dict
## (its dtype attr is a type object that richfile can't serialize as-is)
import sparse_convolution
def _save_sparse_conv(obj, path, **kwargs):
d = {
'x_shape': list(obj.x_shape),
'k': obj.k.tolist(),
'mode': obj.mode,
'method': obj.method,
'dtype': np.dtype(obj.dtype).str,
}
with open(path, 'w') as f:
json.dump(d, f)
def _load_sparse_conv(path, **kwargs):
with open(path, 'r') as f:
d = json.load(f)
return sparse_convolution.Toeplitz_convolution2d(
x_shape=tuple(d['x_shape']),
k=np.array(d['k']),
mode=d['mode'],
method=d['method'],
dtype=np.dtype(d['dtype']),
)
sparse_conv_dict = {
'type_name': 'sparse_conv',
'object_class': sparse_convolution.Toeplitz_convolution2d,
'suffix': 'json',
'library': 'sparse_convolution',
'versions_supported': [],
'function_save': _save_sparse_conv,
'function_load': _load_sparse_conv,
}
# roicat_module_tds = []
## SIMILARITY METRIC (dataclass → JSON)
from .tracking.similarity_graph import SimilarityMetric
def save_similarity_metric(
obj: SimilarityMetric,
path: Union[str, Path],
**kwargs,
) -> None:
"""Saves a SimilarityMetric dataclass as JSON via to_dict()."""
with open(path, 'w') as f:
json.dump(obj.to_dict(), f)
def load_similarity_metric(
path: Union[str, Path],
**kwargs,
) -> SimilarityMetric:
"""Loads a SimilarityMetric from a JSON dict."""
with open(path, 'r') as f:
d = json.load(f)
## Convert power_bounds from list back to tuple (JSON round-trip)
if 'power_bounds' in d and isinstance(d['power_bounds'], list):
d['power_bounds'] = tuple(d['power_bounds'])
return SimilarityMetric.from_dict(d)
type_dicts = [
{
"type_name": "numpy_array",
"function_load": load_npy_array,
"function_save": save_npy_array,
"object_class": np.ndarray,
"suffix": "npy",
"library": "numpy",
"versions_supported": [],
},
{
"type_name": "numpy_scalar",
"function_load": load_npy_array,
"function_save": save_npy_array,
"object_class": np.number,
"suffix": "npy",
"library": "numpy",
"versions_supported": [],
},
{
"type_name": "scipy_sparse_array",
"function_load": load_sparse_array,
"function_save": save_sparse_array,
"object_class": scipy.sparse.spmatrix,
"suffix": "npz",
"library": "scipy",
"versions_supported": [],
},
## scipy >= 1.14 returns csr_array instead of csr_matrix from many
## operations. csr_array inherits from sparray, not spmatrix.
{
"type_name": "scipy_sparray",
"function_load": load_sparse_array,
"function_save": save_sparse_array,
"object_class": scipy.sparse.sparray,
"suffix": "npz",
"library": "scipy",
"versions_supported": [],
},
{
"type_name": "json_dict",
"function_load": load_json_dict,
"function_save": save_json_dict,
"object_class": JSON_Dict,
"suffix": "json",
"library": "python",
"versions_supported": [],
},
{
"type_name": "json_list",
"function_load": load_json_list,
"function_save": save_json_list,
"object_class": JSON_List,
"suffix": "json",
"library": "python",
"versions_supported": [],
},
{
"type_name": "optuna_study",
"function_load": load_optuna_study,
"function_save": save_optuna_study,
"object_class": optuna.study.Study,
"suffix": "optuna",
"library": "optuna",
"versions_supported": [],
},
{
"type_name": "torch_tensor",
"function_load": load_torch_tensor,
"function_save": save_torch_tensor,
"object_class": torch.Tensor,
"suffix": "npy",
"library": "torch",
"versions_supported": [],
},
{
"type_name": "model_swt",
"function_load": load_repr,
"function_save": save_repr,
"object_class": Model_SWT,
"suffix": "swt",
"library": "onnx2torch",
"versions_supported": [],
},
{
"type_name": "torch_module",
"function_load": load_repr,
"function_save": save_repr,
"object_class": torch.nn.Module,
"suffix": "torch_module",
"library": "torch",
"versions_supported": [],
},
{
"type_name": "torch_sequence",
"function_load": load_repr,
"function_save": save_repr,
"object_class": torch.nn.Sequential,
"suffix": "torch_sequence",
"library": "torch",
"versions_supported": [],
},
{
"type_name": "torch_dataset",
"function_load": load_repr,
"function_save": save_repr,
"object_class": torch.utils.data.Dataset,
"suffix": "torch_dataset",
"library": "torch",
"versions_supported": [],
},
{
"type_name": "torch_dataloader",
"function_load": load_repr,
"function_save": save_repr,
"object_class": torch.utils.data.DataLoader,
"suffix": "torch_dataloader",
"library": "torch",
"versions_supported": [],
},
*([{
"type_name": "hdbscan",
"function_load": load_hdbscan,
"function_save": save_hdbscan,
"object_class": _hdbscan_class,
"suffix": "json",
"library": "fast_hdbscan",
"versions_supported": [],
}] if _hdbscan_class is not None else []),
{
"type_name": "pandas_dataframe",
"function_load": load_pandas_dataframe,
"function_save": save_pandas_dataframe,
"object_class": pd.DataFrame,
"suffix": "csv",
"library": "pandas",
"versions_supported": [],
},
{
"type_name": "scipy_optimize_result",
"function_load": load_optimize_result,
"function_save": save_optimize_result,
"object_class": scipy.optimize.OptimizeResult,
"suffix": "json",
"library": "scipy",
"versions_supported": [],
},
{
"type_name": "similarity_metric",
"function_load": load_similarity_metric,
"function_save": save_similarity_metric,
"object_class": SimilarityMetric,
"suffix": "json",
"library": "roicat",
"versions_supported": [],
},
] + [t.get_property_dict() for t in roicat_module_tds] + [sparse_conv_dict]
[self.register_type_from_dict(d) for d in type_dicts]
######################################
######## CUSTOM DATA CLASSES #########
######################################
[docs]
class JSON_Dict(dict):
    """
    Plain ``dict`` subclass used as a type marker: richfile is configured to
    save and load instances of this class as ``.json`` files.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
[docs]
class JSON_List(list):
    """
    Plain ``list`` subclass used as a type marker: richfile is configured to
    save and load instances of this class as ``.json`` files.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
## Wrapper for SWT
[docs]
class Model_SWT(torch.nn.Module):
    """
    Thin ``torch.nn.Module`` wrapper around the SWT model. The wrapped module
    is registered as the child module ``model`` and ``forward`` delegates to
    it unchanged.

    Args:
        model (torch.nn.Module):
            The module to wrap.
    """
    def __init__(self, model: torch.nn.Module):
        super().__init__()
        ## add_module registers the child so its parameters/buffers are
        ## tracked by this wrapper (state_dict, .to(), etc.).
        self.add_module('model', model)

    def forward(self, x):
        """
        Runs the wrapped model on ``x`` and returns its output as-is.
        """
        return self.model(x)
[docs]
def make_session_bool(n_roi: np.ndarray,) -> np.ndarray:
    """
    Generates a boolean array representing ROIs (Region Of Interest) per session from an array of ROI counts.
    Args:
        n_roi (np.ndarray):
            Array representing the number of ROIs per session.
            *shape*: *(n_sessions,)*
    Returns:
        (np.ndarray):
            session_bool (np.ndarray):
                Boolean array of shape *(n_roi_total, n_session)* where each column represents a session
                and each row corresponds to an ROI. ``session_bool[i, s]`` is
                ``True`` iff ROI ``i`` belongs to session ``s``. An empty
                ``n_roi`` yields an array of shape *(0, 0)*.
    Example:
        .. highlight:: python
        .. code-block:: python
            n_roi = np.array([3, 4, 2])
            session_bool = make_session_bool(n_roi)
    """
    n_roi = np.asarray(n_roi, dtype=np.int64)
    ## Session s owns the half-open ROI range [bounds[s], bounds[s+1]).
    bounds = np.concatenate([[0], np.cumsum(n_roi)]).astype(np.int64)
    idx_roi = np.arange(bounds[-1], dtype=np.int64)[:, None]
    ## Broadcast ROI indices (column) against every session's bounds (row).
    ## Unlike stacking per-session columns, this also works for zero sessions.
    session_bool = (bounds[None, :-1] <= idx_roi) & (idx_roi < bounds[None, 1:])
    return session_bool
[docs]
def split_iby_session(
    x: Any,
    n_roi_per_session: Union[np.ndarray, List[int]],
):
    """
    Splits an array or other sliceable sequence into consecutive per-session
    chunks based on the number of ROIs per session.
    Args:
        x (Any):
            Array (or any object supporting ``x[start:stop]``) to split along
            its first dimension.
        n_roi_per_session (Union[np.ndarray, List[int]]):
            Number of ROIs per session. Chunk ``ii`` is
            ``x[sum(n[:ii]) : sum(n[:ii+1])]``.
    Returns:
        (List[Any]):
            List of slices of ``x``, one per session, in session order.
    """
    ## Keep a running offset instead of re-summing the prefix for every
    ## session (which made the split O(n_sessions**2)).
    chunks = []
    start = 0
    for n_roi in n_roi_per_session:
        stop = start + n_roi
        chunks.append(x[start:stop])
        start = stop
    return chunks
##########################################################################################################################
############################################### UCID handling ############################################################
##########################################################################################################################
[docs]
def check_dataStructure__list_ofListOrArray_ofDtype(
    lolod: Union[List[List[Union[int, float]]], List[np.ndarray]],
    dtype: Type = np.int64,
    fix: bool = True,
    verbose: bool = True,
) -> Union[List[List[Union[int, float]]], List[np.ndarray]]:
    """
    Verifies and optionally corrects the data structure of 'lolod' (list of list
    of dtype).
    The structure should be a list of lists of dtypes or a list of numpy arrays
    of dtypes.
    Args:
        lolod (Union[List[List[Union[int, float]]], List[np.ndarray]]):
            * The data structure to check. It should be a list of lists of
              dtypes or a list of numpy arrays of dtypes.
        dtype (Type):
            * The expected dtype of the elements in 'lolod'. (Default is
              ``np.int64``)
        fix (bool):
            * If ``True``, attempts to correct the data structure if it is not
              as expected. The corrections are as follows: \n
                * If 'lolod' is an array, it will be cast to [lolod]
                * If 'lolod' is a numpy object, it will be cast to
                  [np.array(lolod, dtype=dtype)]
                * If 'lolod' is a list of lists of numbers (int or float), it
                  will be cast to [np.array(lod, dtype=dtype) for lod in lolod]
                * If 'lolod' is a list of arrays of wrong dtype, it will be cast
                  to [np.array(lod, dtype=dtype) for lod in lolod] \n
            * If ``False``, raises an error if the structure is not as expected.
              (Default is ``True``)
        verbose (bool):
            * If ``True``, prints warnings when the structure is not as expected
              and is corrected. (Default is ``True``)
    Returns:
        (Union[List[List[Union[int, float]]], List[np.ndarray]]):
            lolod (Union[List[List[Union[int, float]]], List[np.ndarray]]):
                The verified or corrected data structure. A list of arrays that
                already has the expected dtype is returned unchanged (same
                object, no copy).
    Raises:
        ValueError:
            If the structure cannot be interpreted, or if it needs correcting
            and ``fix`` is ``False``.
    """
    ## switch case for if it is a list or np.ndarray
    if isinstance(lolod, list):
        if all(isinstance(lod, list) for lod in lolod):
            ## list of lists: the inner elements should be numbers or `dtype`
            if all(
                all(isinstance(l, (int, float, np.integer, np.floating)) for l in lod)
                for lod in lolod
            ):
                if not fix:
                    raise ValueError('ROICaT ERROR: lolod is a list of lists of numbers (int or float).')
                if verbose:
                    print('ROICaT WARNING: lolod is a list of lists of numbers (int or float). Converting to np.ndarray.')
                lolod = [np.array(lod, dtype=dtype) for lod in lolod]
            elif all(all(isinstance(l, dtype) for l in lod) for lod in lolod):
                pass
            else:
                raise ValueError('ROICaT ERROR: lolod is a list of lists, but the elements are not all numbers (int or float) or dtype.')
        elif all(isinstance(lod, np.ndarray) for lod in lolod):
            if all(np.issubdtype(lod.dtype, dtype) for lod in lolod):
                ## Already the expected structure. This must be an `elif` chain:
                ## otherwise correct input falls through to the conversion
                ## branch below (needless copy, and a spurious error if fix=False).
                pass
            elif all(not np.issubdtype(lod.dtype, np.object_) for lod in lolod):
                if not fix:
                    raise ValueError('ROICaT ERROR: lolod is a list of np.ndarray of numbers (int or float).')
                if verbose:
                    print('ROICaT WARNING: lolod is a list of np.ndarray of numbers (int or float). Converting to np.ndarray.')
                lolod = [np.array(lod, dtype=dtype) for lod in lolod]
            else:
                raise ValueError('ROICaT ERROR: lolod is a list of np.ndarray, but the elements are not all numbers (int or float) or dtype.')
        elif all(isinstance(lod, (int, float)) for lod in lolod):
            if not fix:
                raise ValueError('ROICaT ERROR: lolod is a list of numbers (int or float).')
            if verbose:
                print('ROICaT WARNING: lolod is a list of numbers (int or float). Converting to np.ndarray.')
            ## NOTE(review): each number becomes its own 0-d array here, whereas
            ## the list-of-`dtype` branch below wraps the whole list as a single
            ## session. Kept as-is for compatibility; confirm which is intended.
            lolod = [np.array(lod, dtype=dtype) for lod in lolod]
        elif all(isinstance(lod, dtype) for lod in lolod):
            if not fix:
                raise ValueError('ROICaT ERROR: lolod is a list of dtype.')
            if verbose:
                print('ROICaT WARNING: lolod is a list of dtype. Converting to np.ndarray.')
            lolod = [np.array(lolod, dtype=dtype)]
        else:
            raise ValueError('ROICaT ERROR: lolod is a list, but the elements are not all lists or np.ndarray or numbers (int or float).')
    elif isinstance(lolod, np.ndarray):
        ## A bare array is never the expected structure (it must be wrapped in
        ## a list), so with fix=False it is reported rather than silently
        ## returned unwrapped.
        if np.issubdtype(lolod.dtype, dtype):
            if not fix:
                raise ValueError('ROICaT ERROR: lolod is a np.ndarray of dtype.')
            if verbose:
                print('ROICaT WARNING: lolod is a np.ndarray of dtype. Converting to list of np.ndarray of dtype.')
            lolod = [lolod]
        elif not np.issubdtype(lolod.dtype, np.object_):
            if not fix:
                raise ValueError('ROICaT ERROR: lolod is a np.ndarray of numbers (int or float).')
            if verbose:
                print('ROICaT WARNING: lolod is a np.ndarray of numbers (int or float). Converting to list of np.ndarray of dtype.')
            lolod = [np.array(lolod, dtype=dtype)]
        else:
            raise ValueError('ROICaT ERROR: lolod is a np.ndarray, but the elements are not all numbers (int or float) or dtype.')
    else:
        raise ValueError('ROICaT ERROR: lolod is not a list or np.ndarray.')
    return lolod
[docs]
def mask_UCIDs_with_iscell(
    ucids: List[Union[List[int], np.ndarray]],
    iscell: List[Union[List[bool], np.ndarray]]
) -> List[Union[List[int], np.ndarray]]:
    """
    Masks the UCIDs with the **iscell** array. If ``iscell`` is False, then the
    UCID is set to -1.
    Args:
        ucids (List[Union[List[int], np.ndarray]]):
            List of lists of UCIDs for each session.\n
            Shape outer list: *(n_sessions,)*\n
            Shape inner list: *(n_roi_in_session,)*
        iscell (List[Union[List[bool], np.ndarray]]):
            List of lists of boolean indicators for each UCID.\n
            ``True`` means that ROI is a cell, ``False`` means that ROI is not a
            cell.\n
            Shape outer list: *(n_sessions,)*\n
            Shape inner list: *(n_roi_in_session,)*
    Returns:
        (List[Union[List[int], np.ndarray]]):
            ucids_out (List[Union[List[int], np.ndarray]]):
                Masked list of lists of UCIDs. Elements that are not cells are
                set to -1 in each session. The input ``ucids`` is not modified.
    Raises:
        ValueError:
            If ``ucids`` and ``iscell`` contain a different number of sessions.
    """
    ## Work on a copy: the masking below is done in place.
    ucids_out = copy.deepcopy(ucids)
    ucids_out = check_dataStructure__list_ofListOrArray_ofDtype(
        lolod=ucids_out,
        dtype=np.int64,
        fix=True,
        verbose=False,
    )
    iscell = check_dataStructure__list_ofListOrArray_ofDtype(
        lolod=iscell,
        dtype=bool,
        fix=True,
        verbose=False,
    )
    ## Count sessions on the normalized structure: a bare ndarray input is
    ## wrapped into a one-session list, so len(ucids) would count ROIs instead.
    if len(iscell) != len(ucids_out):
        raise ValueError(f'ROICaT ERROR: ucids has {len(ucids_out)} sessions but iscell has {len(iscell)} sessions.')
    for ucids_sesh, iscell_sesh in zip(ucids_out, iscell):
        ucids_sesh[~iscell_sesh] = -1
    return ucids_out
[docs]
def mask_UCIDs_by_label(
    ucids: List[Union[List[int], np.ndarray]],
    labels: Union[List[int], np.ndarray],
) -> List[Union[List[int], np.ndarray]]:
    """
    Keeps only the UCIDs listed in **labels**; every other UCID is replaced by
    -1.\n
    RH 2024
    Args:
        ucids (List[Union[List[int], np.ndarray]]):
            List of lists of UCIDs for each session.\n
            Shape outer list: *(n_sessions,)*\n
            Shape inner list: *(n_roi_in_session,)*
        labels (Union[List[int], np.ndarray]):
            Array of labels to keep. All other labels are set to -1.
            Shape: *(n_labels,)*
    Returns:
        (List[Union[List[int], np.ndarray]]):
            ucids_out (List[Union[List[int], np.ndarray]]):
                Masked list of lists of UCIDs. Elements that are not in the
                **labels** array are set to -1 in each session.
    Example:
        .. highlight:: python
        .. code-block:: python
            ucids = [[1, 2, 3], [2, -1, 4], [3, 0, 5]]
            labels = [2, 3]
            ucids_out = mask_UCIDs_by_label(ucids, labels)
            # ucids_out = [[-1, 2, 3], [2, -1, -1], [3, -1, -1]]
    """
    ## Normalize a private copy of the UCIDs into a list of int64 arrays.
    ucids_clean = check_dataStructure__list_ofListOrArray_ofDtype(
        lolod=copy.deepcopy(ucids),
        dtype=np.int64,
        fix=True,
        verbose=False,
    )
    labels_keep = np.array(labels, dtype=np.int64)

    ## An ROI survives exactly when its UCID is one of the requested labels.
    keep_masks = []
    for ucids_sesh in ucids_clean:
        keep_masks.append(np.isin(ucids_sesh, labels_keep))

    return mask_UCIDs_with_iscell(ucids_clean, keep_masks)
[docs]
def discard_UCIDs_with_fewer_matches(
    ucids: List[Union[List[int], np.ndarray]],
    n_sesh_thresh: Union[int, str] = 'all',
    verbose: bool = True
) -> List[Union[List[int], np.ndarray]]:
    """
    Discards UCIDs that do not appear in at least **n_sesh_thresh** sessions. If
    ``n_sesh_thresh='all'``, then only UCIDs that appear in all sessions are
    kept.
    Args:
        ucids (List[Union[List[int], np.ndarray]]):
            List of lists of UCIDs for each session.
        n_sesh_thresh (Union[int, str]):
            Number of sessions that a UCID must appear in to be kept. If
            ``'all'``, then only UCIDs that appear in all sessions are kept.
            (Default is ``'all'``)
        verbose (bool):
            If ``True``, print verbose output. (Default is ``True``)
    Returns:
        (List[Union[List[int], np.ndarray]]):
            ucids_out (List[Union[List[int], np.ndarray]]):
                List of lists of UCIDs with UCIDs that do not appear in at least
                **n_sesh_thresh** sessions set to -1. Each session is a list of
                ``np.int64`` values.
    Raises:
        TypeError:
            If ``n_sesh_thresh`` is neither ``'all'`` nor an integer.
    """
    ucids_out = copy.deepcopy(ucids)
    ucids_out = check_dataStructure__list_ofListOrArray_ofDtype(
        lolod=ucids_out,
        dtype=np.int64,
        fix=True,
        verbose=False,
    )
    ## Count sessions on the normalized structure (a bare ndarray input is
    ## wrapped into a single session; len(ucids) would count ROIs instead).
    n_sesh = len(ucids_out)
    if isinstance(n_sesh_thresh, str) and n_sesh_thresh == 'all':
        n_sesh_thresh = n_sesh
    if not isinstance(n_sesh_thresh, (int, np.integer)):
        raise TypeError(f'ROICaT ERROR: n_sesh_thresh must be an int or "all", but got {n_sesh_thresh!r}.')
    n_sesh_thresh = int(n_sesh_thresh)
    if n_sesh == 0:
        return []

    ## Number of sessions each UCID appears in: de-duplicate within each
    ## session, then count occurrences across sessions.
    ucids_unique, n_sesh_present = np.unique(
        np.concatenate([np.unique(u_sesh) for u_sesh in ucids_out], axis=0),
        return_counts=True,
    )
    ucids_keep = ucids_unique[n_sesh_present >= n_sesh_thresh]

    if verbose:
        n_valid = (ucids_unique >= 0).sum()
        fraction = (ucids_keep >= 0).sum() / n_valid if n_valid > 0 else 0.0
        print(f'INFO: {fraction*100:.2f}% of UCIDs that appear in at least {n_sesh_thresh} sessions.')

    ## Kept UCIDs pass through unchanged; everything else (including -1)
    ## becomes -1. Sessions stay lists of np.int64 values, as before.
    ucids_out = [list(np.where(np.isin(u_sesh, ucids_keep), u_sesh, -1)) for u_sesh in ucids_out]
    return ucids_out
[docs]
def squeeze_UCID_labels(
    ucids: List[Union[List[int], np.ndarray]],
    return_array: bool = False,
) -> List[Union[List[int], np.ndarray]]:
    """
    Squeezes the UCID labels. Finds all the unique UCIDs across all sessions,
    then removes gaps in the UCID labels by mapping the unique UCIDs to new
    values. Output UCIDs are contiguous integers starting at 0, and elements
    with UCID=-1 are kept as -1.
    Args:
        ucids (List[Union[List[int], np.ndarray]]):
            List of lists of UCIDs for each session.
        return_array (bool):
            If ``True``, then each session in the output will be a numpy array.
            (Default is ``False``)
    Returns:
        (List[Union[List[int], np.ndarray]]):
            ucids_out (List[Union[List[int], np.ndarray]]):
                Squeezed UCIDs, one entry per session. Each non-negative UCID is
                replaced by its rank among all unique non-negative UCIDs across
                sessions (0, 1, 2, ...); negative UCIDs become -1. Sessions are
                lists of Python ints, or int64 numpy arrays if
                ``return_array=True``.
    """
    ucids_out = copy.deepcopy(ucids)
    ucids_out = check_dataStructure__list_ofListOrArray_ofDtype(
        lolod=ucids_out,
        dtype=np.int64,
        fix=True,
        verbose=False,
    )
    if len(ucids_out) == 0:
        return []

    ## Sorted unique non-negative UCIDs across all sessions; a UCID's new
    ## label is its position in this array (np.unique already sorts).
    uniques_all = np.unique(np.concatenate(ucids_out, axis=0))
    uniques_all = uniques_all[uniques_all >= 0]

    squeezed = []
    for u_sesh in ucids_out:
        u_sesh = np.asarray(u_sesh, dtype=np.int64)
        u_new = np.full(u_sesh.shape, -1, dtype=np.int64)
        valid = u_sesh >= 0
        ## Every valid UCID is in uniques_all, so searchsorted returns its rank.
        u_new[valid] = np.searchsorted(uniques_all, u_sesh[valid])
        squeezed.append(u_new)

    if return_array:
        return squeezed
    return [u.tolist() for u in squeezed]
[docs]
def match_arrays_with_ucids(
    arrays: Union[np.ndarray, List[np.ndarray]],
    ucids: Union[List[np.ndarray], List[List[int]]],
    return_indices: bool = False,
    squeeze: bool = False,
    force_sparse: bool = False,
    prog_bar: bool = False,
) -> List[Union[np.ndarray, scipy.sparse.lil_array]]:
    """
    Matches the indices of the arrays using the UCIDs. Array indices with UCIDs
    corresponding to -1 are set to ``np.nan``. This is useful for aligning
    Fluorescence and Spiking data across sessions using UCIDs.
    Args:
        arrays (Union[np.ndarray, List[np.ndarray]]):
            List of numpy arrays for each session. Matching is done along the
            first dimension. Must have a floating dtype (NaN-compatible).
        ucids (Union[List[np.ndarray], List[List[int]]]):
            List of lists of UCIDs for each session.
        return_indices (bool):
            If ``True``, then the indices of the UCIDs will also be returned.
            The indices will be of dtype np.float32 because it may contain NaNs.
            (Default is ``False``)
        squeeze (bool):
            If ``True``, then UCIDs are squeezed to be contiguous integers.
            (Default is ``False``)
        force_sparse (bool):
            If ``True``, then the output will be a list of sparse matrices.
            (Default is ``False``)
        prog_bar (bool):
            If ``True``, then a progress bar will be displayed. (Default is
            ``False``)
    Returns:
        (List[Union[np.ndarray, scipy.sparse.lil_array]]):
            arrays_out (List[Union[np.ndarray, scipy.sparse.lil_array]]):
                List of arrays for each session. Each array will have shape:
                *(n_ucids if squeeze==True OR max_ucid + 1 if squeeze==False,
                *array.shape[1:])*. UCIDs are used as the index of the first
                dimension. Rows with no matching ROI are ``np.nan`` for dense
                outputs and implicit zeros for sparse (``lil_array``) outputs.
            If ``return_indices`` is ``True``, the return value is the tuple
            ``(arrays_out, indices_out)`` where ``indices_out`` holds, per
            session, the source row index for each UCID (NaN where unmatched).
    """
    import scipy.sparse
    arrays = [arrays] if not isinstance(arrays, list) else arrays
    ucids_tu = check_dataStructure__list_ofListOrArray_ofDtype(
        lolod=ucids,
        dtype=np.int64,
        fix=True,
        verbose=False,
    )
    ## Error if dtype is not NaN compatible
    if not np.issubdtype(arrays[0].dtype, np.floating):
        raise ValueError(f'ROICaT ERROR: This function requires inputs to be of a dtype that is compatible with NaNs, like np.floating types: np.float32, np.float64, etc.')
    ## Squeeze UCIDs
    ucids_tu = squeeze_UCID_labels(ucids_tu) if squeeze else ucids_tu
    max_ucid = int(np.concatenate(ucids_tu, axis=0).max()) + 1
    ## Per session: UCID -> source row. For duplicate UCIDs within a session
    ## the last ROI wins (dict semantics); negative UCIDs are skipped below.
    dicts_ucids = [{u: i for i, u in enumerate(u_sesh)} for u_sesh in ucids_tu]

    ## make ndarrays filled with np.nan (or empty sparse arrays) for each session
    flag_dense = isinstance(arrays[0], np.ndarray) and not force_sparse
    if flag_dense:
        arrays_out = [np.full((max_ucid, *a.shape[1:]), np.nan, dtype=arrays[0].dtype) for a in arrays]
    elif scipy.sparse.issparse(arrays[0]) or force_sparse:
        arrays_out = [scipy.sparse.lil_array((max_ucid, *a.shape[1:]), dtype=a.dtype) for a in arrays]
    else:
        raise ValueError(f'ROICaT ERROR: arrays[0] is not a numpy array or scipy.sparse matrix.')

    ## fill in the arrays with the data
    n_sesh = len(arrays)
    for i_sesh in tqdm(range(n_sesh), disable=not prog_bar):
        d = dicts_ucids[i_sesh]
        if flag_dense:
            ## Vectorized row scatter. Dict keys are unique, so there are no
            ## conflicting writes and the result matches row-by-row assignment.
            u_rows = np.fromiter(d.keys(), dtype=np.int64, count=len(d))
            idx_src = np.fromiter(d.values(), dtype=np.int64, count=len(d))
            valid = u_rows >= 0
            arrays_out[i_sesh][u_rows[valid]] = np.asarray(arrays[i_sesh])[idx_src[valid]]
        else:
            ## lil_array: assign row by row.
            for u, idx in d.items():
                if u >= 0:
                    arrays_out[i_sesh][u] = arrays[i_sesh][idx]

    if not return_indices:
        return arrays_out
    ## Use the original array lengths so we can recover the correct source
    ## indices even when the per-session arrays contain more ROIs than there
    ## are UCIDs.
    lens_arrays = [a.shape[0] if hasattr(a, 'shape') else len(a) for a in arrays]
    return arrays_out, match_arrays_with_ucids(
        arrays=[np.arange(n, dtype=np.float32) for n in lens_arrays],
        ucids=ucids,
        return_indices=False,
        squeeze=squeeze,
        force_sparse=False,
        prog_bar=False,
    )
[docs]
def match_arrays_with_ucids_inverse(
    arrays: Union[np.ndarray, List[np.ndarray]],
    ucids: Union[List[np.ndarray], List[List[int]]],
    unsqueeze: bool = True,
) -> List[Union[np.ndarray, scipy.sparse.lil_array]]:
    """
    Inverts the matching of the indices of the arrays using the UCIDs. Arrays
    should have indices that correspond to the UCID values. The return will be a
    list of arrays with indices that correspond to the original indices of the
    arrays / ucids. Essentially, this function undoes the matching done by
    match_arrays_with_ucids().
    Args:
        arrays (Union[np.ndarray, List[np.ndarray]]):
            List of numpy arrays for each session.
        ucids (Union[List[np.ndarray], List[List[int]]]):
            List of lists of UCIDs for each session. Must be the same UCIDs
            that were passed to match_arrays_with_ucids().
        unsqueeze (bool):
            If ``True``, then this algorithm assumes that the arrays were
            squeezed to remove unused UCIDs. This corresponds to and should
            match the argument ``squeeze`` used in match_arrays_with_ucids().
    Returns:
        (List[Union[np.ndarray, scipy.sparse.lil_array]]):
            arrays_out (List[Union[np.ndarray, scipy.sparse.lil_array]]):
                List of arrays with indices that correspond to the original
                indices of the arrays / ucids. ROIs with a negative UCID get
                NaN entries.
    Raises:
        ValueError:
            If the array lengths are inconsistent with ``unsqueeze``.
    """
    arrays = [arrays] if not isinstance(arrays, list) else arrays
    ## Make a mapping of the UCIDs to the original indices ('aranges_matched')
    ucids_clean = copy.deepcopy(ucids)
    ucids_clean = check_dataStructure__list_ofListOrArray_ofDtype(
        lolod=ucids_clean,
        dtype=np.int64,
        fix=True,
        verbose=False,
    )
    aranges = [np.arange(len(u), dtype=np.float32) for u in ucids_clean]
    aranges_matched = match_arrays_with_ucids(
        arrays=aranges,
        ucids=ucids_clean,
        squeeze=False,
    )

    ## UCID values present in at least one session. This is exactly the set
    ## (and ascending order) that squeezing maps to 0, 1, 2, ... It must be
    ## computed across ALL sessions: a per-session rank gives the wrong
    ## squeezed row whenever a smaller UCID is absent from this session but
    ## present in another one.
    n_ucid_slots = len(aranges_matched[0]) if len(aranges_matched) > 0 else 0
    present = np.zeros(n_ucid_slots, dtype=bool)
    for a in aranges_matched:
        present |= ~np.isnan(a)

    ## Make sure that unsqueeze is consistent with the arrays. (When UCIDs are
    ## already contiguous, both lengths coincide and either setting is valid.)
    n_expected = int(present.sum()) if unsqueeze else n_ucid_slots
    lens = [len(a) for a in arrays]
    if any(n != n_expected for n in lens):
        raise ValueError(
            f'ROICaT ERROR: With unsqueeze={unsqueeze}, each array should have length {n_expected}, '
            f'but got lengths {lens}. Check that `unsqueeze` matches the `squeeze` argument used in match_arrays_with_ucids().'
        )

    # Unsqueeze arrays
    if unsqueeze:
        ## Row of each UCID value in the squeezed arrays; NaN where absent.
        idx_unsq = (np.cumsum(present) - 1).astype(np.float32)
        idx_unsq[~present] = np.nan
        ## Pass a copy per session so no index array is shared between calls.
        arrays_unsq = [helpers.index_with_nans(a, idx_unsq.copy()) for a in arrays]
    else:
        arrays_unsq = arrays

    ## Invert the matching
    def negOne_to_nan(x):
        ## Negative UCIDs (unmatched ROIs) become NaN indices. All negatives
        ## are treated alike, consistent with match_arrays_with_ucids().
        tmp = np.array(x, dtype=np.float32)
        np.place(arr=tmp, mask=tmp < 0, vals=np.nan)
        return tmp
    ucids_clean_nan = [negOne_to_nan(u) for u in ucids_clean]
    arrays_inv = [helpers.index_with_nans(a, o) for a, o in zip(arrays_unsq, ucids_clean_nan)]
    return arrays_inv
[docs]
def labels_to_labelsBySession(labels, n_roi_bySession):
    """
    Splits a flat sequence of labels (concatenated across sessions) into one
    array of labels per session.
    RH 2024
    Args:
        labels (list or np.ndarray):
            1D list/array of labels, concatenated across sessions.
        n_roi_bySession (list or np.ndarray):
            1D list/array with the number of ROIs in each session. Must sum to
            ``len(labels)``.
    Returns:
        (list):
            List of np.ndarray slices of ``labels``, one per session.
    """
    for name_arg, value_arg in (('labels', labels), ('n_roi_bySession', n_roi_bySession)):
        assert isinstance(value_arg, (list, np.ndarray)), f'{name_arg} is not a list or np.ndarray. {name_arg}={value_arg}'

    labels_arr = np.array(labels)
    n_roi_arr = np.array(n_roi_bySession, dtype=np.int64)

    assert labels_arr.ndim == 1, f'labels.ndim={labels_arr.ndim}, but should be 1.'
    assert n_roi_arr.ndim == 1, f'n_roi_bySession.ndim={n_roi_arr.ndim}, but should be 1.'
    n_total = np.sum(n_roi_arr)
    assert n_total == len(labels_arr), f'np.sum(n_roi_bySession)={n_total} != len(labels)={len(labels_arr)}'

    return split_iby_session(x=labels_arr, n_roi_per_session=n_roi_arr)
[docs]
def invert_ucids(
    ucids: Union[np.ndarray, list],
    max_ucid: Optional[int] = None,
) -> Union[np.ndarray, list]:
    """
    Invert UCIDs to make ucids_inverse where ucids_inverse[i] =
    argwhere(ucids == i) for all i in range(max_ucid + 1).
    Elements with a negative UCID (e.g. -1) are discarded. UCID values that do
    not occur in ``ucids`` are set to -1.
    RH 2025
    Args:
        ucids (Union[np.ndarray, list]):
            UCIDs to invert. Should be a 1D array or list of integers.
            Should be the ucids from a single session.
        max_ucid (int, optional):
            Maximum UCID value to use. If not provided, it will be inferred from
            the input UCIDs (an input with no non-negative UCIDs gives an empty
            result). If provided, it should be greater than or equal to the
            maximum UCID in the input. This is useful if you are combining
            multiple sessions with different UCID ranges.
            (Default is ``None``)
    Returns:
        (Union[np.ndarray, list]):
            Inverted UCIDs of length ``max_ucid + 1`` where ucids_inverse[i] is
            the position of UCID ``i`` in ``ucids`` (or -1 if absent). A list
            input gives a list output; an ndarray input gives an int64 ndarray.
    Raises:
        ValueError:
            If ``ucids`` is not a list or ndarray, if ``max_ucid`` is smaller
            than the largest input UCID, or if a non-negative UCID is repeated.
    """
    if isinstance(ucids, list):
        flag_list = True
    elif isinstance(ucids, np.ndarray):
        flag_list = False
    else:
        raise ValueError(f'ROICaT ERROR: ucids must be a list or numpy array, but got {type(ucids)}.')
    ## Normalize both lists and (possibly float) arrays to int64.
    ucids = np.asarray(ucids, dtype=np.int64)

    ## -1 for empty input, so the inferred output length is 0 (not 1).
    max_in = int(ucids.max()) if ucids.size > 0 else -1
    if max_ucid is None:
        max_ucid = max_in
    elif max_ucid < max_in:
        raise ValueError(f'ROICaT ERROR: max_ucid={max_ucid} is less than the maximum UCID in the input ucids={max_in}. Please provide a larger max_ucid value.')

    ucids_inverse = np.full((int(max_ucid) + 1,), -1, dtype=np.int64)
    for i_roi, ucid in enumerate(ucids.tolist()):
        if ucid < 0:
            continue
        if ucids_inverse[ucid] != -1:
            raise ValueError(f'ROICaT ERROR: Duplicate UCID {ucid} found at indices {ucids_inverse[ucid]} and {i_roi}. UCIDs must be unique.')
        ucids_inverse[ucid] = i_roi

    return ucids_inverse.tolist() if flag_list else ucids_inverse