Source code for roicat.util

from pathlib import Path
import warnings
import copy
from typing import Dict, Any, Optional, Union, List, Tuple, Callable, Iterable, Iterator, Type
import datetime
import collections
import importlib

import numpy as np
import scipy.sparse
from tqdm.auto import tqdm
import torch

import richfile as rf

from . import helpers


## Lower bound for the denominators (row max, row sum, ...) used when
## normalizing sparse matrices such as ROI spatial footprints.
## A denominator below this bound is regarded as zero for normalization, which
##   (1) keeps a division by ~0 from producing inf, and
##   (2) keeps negligible numerical residue (e.g. entries around 1e-15) from
##       being stretched to full scale by max-normalization.
SPARSE_NORMALIZATION_FLOOR = 1e-12


def get_roicat_version() -> str:
    """
    Retrieves the version of the installed roicat package.

    Returns:
        (str):
            version (str):
                The version string recorded in the installed roicat
                distribution's metadata.

    Raises:
        importlib.metadata.PackageNotFoundError:
            If roicat is not installed as a distribution.
    """
    ## A bare ``import importlib`` does not guarantee that the ``metadata``
    ## submodule has been loaded, so import it explicitly here.
    import importlib.metadata
    return importlib.metadata.version('roicat')
def get_default_parameters(
    pipeline='tracking',
    path_defaults=None
):
    """
    Builds the parameter dictionary used to run one of the ROICaT pipelines.
    RH 2023

    Args:
        pipeline (str):
            Which pipeline to build parameters for. Supported values:
                * 'tracking': Tracking pipeline.
                * 'classification_inference': Classification inference
                  pipeline (TODO).
                * 'classification_training': Classification training
                  pipeline (TODO).
            Any other value (including the planned 'model_training') raises
            ``NotImplementedError``.
        path_defaults (str):
            Optional path to a yaml file holding a full parameters
            dictionary. When given, it is used instead of the built-in
            defaults. Note that the ``['ROInet']['network']`` entry is still
            replaced by the pipeline-specific network settings. If None, the
            built-in defaults are used.

    Returns:
        (dict):
            params (dict):
                A deep copy of the top-level sections relevant to
                ``pipeline``.
    """
    def _builtin_defaults():
        ## Built only when needed, so the 'prefix_name_save' timestamp
        ## reflects the time of this call.
        general = {
            'use_GPU': True,
            'verbose': True,
            'random_seed': None,
        }

        data_loading = {
            'data_kind': 'data_suite2p',  ## Which loader / sub-dict below to use (e.g. 'data_suite2p' or 'data_roicat'). See documentation and/or notebook on custom data loading for more details.
            'dir_outer': None,  ## Directory that contains the directories holding the data files.
            'reMatch_in_path': None,  ## Regular expression that must appear in the parent path of any discovered file. See helpers.find_paths.
            'common': {
                'um_per_pixel': 1.0,  ## Microns per pixel of the imaging data. Need not be exact; used for resizing ROIs. Inspect the resized ROI images to tweak.
                'centroid_method': 'centerOfMass',  ## 'centerOfMass' or 'median'.
                'out_height_width': [36,36],  ## Height and width of the small ROI_images. Should be a little larger than the largest ROIs.
            },
            'data_suite2p': {
                'new_or_old_suite2p': 'new',  ## 'new' = Python Suite2p, 'old' = MATLAB Suite2p.
                'type_meanImg': 'meanImgE',  ## 'meanImg' (plain mean image) or 'meanImgE' (contrast-enhanced mean image).
            },
            'data_roicat': {
                'filename_search': r'data_roicat.richfile',  ## Regex for the name of the single saved Data_roicat file inside 'dir_outer'.
            },
        }

        alignment = {
            'initialization': {
                'use_match_search': True,  ## Use the match-search algorithm to initialize the alignment.
                'all_to_all': False,  ## Match every pair of sessions: O(N^2) instead of O(N); slower but more accurate.
                'radius_in': 4.0,  ## Micrometers; maximum offset for two images to count as aligned. Larger is more lenient.
                'radius_out': 20.0,  ## Micrometers; minimum offset for two images to count as misaligned.
                'z_threshold': 4.0,  ## Z-score needed to call two images aligned. Larger is stricter.
            },
            'augment': {
                'normalize_FOV_intensities': True,  ## Normalize FOV_images to the max value across all FOV images.
                'roi_FOV_mixing_factor': 0.5,  ## Fraction of the ROI max-intensity projection blended into the FOV image (0.0 = FOV only).
                'use_CLAHE': True,  ## Contrast Limited Adaptive Histogram Equalization. Useful when params['data_loading']['data_suite2p']['type_meanImg'] is not already contrast enhanced (unlike 'meanImgE').
                'CLAHE_grid_block_size': 10,  ## Grid block size for CLAHE. Smaller means more local contrast enhancement.
                'CLAHE_clipLimit': 1.0,  ## CLAHE clip limit. Higher means more contrast.
                'CLAHE_normalize': True,  ## Normalize the CLAHE output image.
            },
            'fit_geometric': {
                'template': 0.5,  ## Template session: a float is a fractional position in the session list, an int is a 0-based index.
                'template_method': 'image',  ## 'sequential' (previous session is the template) or 'image' (session chosen by 'template').
                'mask_borders': [0, 0, 0, 0],  ## Pixels masked at the FOV_image borders (top, bottom, left, right), e.g. to hide edge artifacts.
                'method': 'RoMa',  ## Accuracy, best to worst: RoMa (slow without a GPU), LoFTR, DISK_LightGlue, ECC_cv2, then (not recommended) SIFT, ORB.
                'kwargs_method': {
                    'RoMa': {
                        'model_type': 'outdoor',
                        'n_points': 10000,  ## More points helps larger FOV_images but is slower.
                        'batch_size': 1000,
                    },
                    'DISK_LightGlue': {
                        'num_features': 3000,  ## Features to extract/match; ~2048 has empirically worked best.
                        'threshold_confidence': 0.2,  ## Higher gives fewer but better matches.
                    },
                    'LoFTR': {
                        'model_type': 'indoor_new',
                        'threshold_confidence': 0.2,  ## Higher gives fewer but better matches.
                    },
                    'ECC_cv2': {
                        'mode_transform': 'euclidean',  ## One of {'translation', 'affine', 'euclidean', 'homography'}; see cv2.findTransformECC.
                        'n_iter': 200,
                        'termination_eps': 1e-09,  ## Termination criterion of the registration.
                        'gaussFiltSize': 1,  ## Gaussian smoothing of the FOV_image before registration. Larger smooths more.
                        'auto_fix_gaussFilt_step': 10,  ## On failure, the gaussian filter size is changed by this step and registration is retried.
                    },
                    'PhaseCorrelation': {
                        'bandpass_freqs': [1, 30],
                        'order': 5,
                    },
                    'SIFT': {
                        'nfeatures': 10000,
                        'contrastThreshold': 0.04,
                        'edgeThreshold': 10,
                        'sigma': 1.6,
                    },
                    'ORB': {
                        'nfeatures': 1000,
                        'scaleFactor': 1.2,
                        'nlevels': 8,
                        'edgeThreshold': 31,
                        'firstLevel': 0,
                        'WTA_K': 2,
                        'scoreType': 0,
                        'patchSize': 31,
                        'fastThreshold': 20,
                    },
                    'NullRegistration': {},  ## No registration and no warping.
                },
                'constraint': 'affine',  ## One of {'rigid', 'euclidean', 'similarity', 'affine', 'homography'}; use the simplest that fits.
                'kwargs_RANSAC': {  ## RANSAC settings for point/descriptor based methods.
                    'inl_thresh': 3.0,  ## Inlier threshold. Larger admits more inliers.
                    'max_iter': 100,  ## Maximum RANSAC iterations.
                    'confidence': 0.99,  ## RANSAC confidence level.
                },
            },
            'fit_nonrigid': {
                'template': 0.5,  ## Same meaning as in 'fit_geometric'.
                'template_method': 'image',  ## Same meaning as in 'fit_geometric'.
                'method': 'DeepFlow',
                'kwargs_method': {
                    'RoMa': {
                        'model_type': 'outdoor',
                    },
                    'DeepFlow': {},
                    'OpticalFlowFarneback': {
                        'pyr_scale': 0.7,
                        'levels': 5,
                        'winsize': 128,
                        'iterations': 15,
                        'poly_n': 5,
                        'poly_sigma': 1.5,
                    },
                    'NullRegistration': {},
                },
            },
            'transform_ROIs': {
                'normalize': True,  ## Normalize the transformed spatial footprints to sum to 1.
            },
        }

        blurring = {
            'kernel_halfWidth': 2.0,  ## Half-width of the cosine blur kernel; match to expected ROI motion between sessions.
        }

        roinet = {
            'network': {
                'download_method': 'check_local_first',  ## Reuse an already-downloaded model if its hash matches.
                'download_url': 'https://osf.io/x3fd2/download',  ## Model URL.
                'download_hash': '7a5fb8ad94b110037785a46b9463ea94',  ## Model file hash.
                'forward_pass_version': 'latent',  ## How data is passed through the network.
            },
            'dataloader': {
                'jit_script_transforms': False,  ## (advanced) Use torch.jit.script on the transforms.
                'batchSize_dataloader': 8,  ## (advanced) DataLoader batch_size.
                'pinMemory_dataloader': True,  ## (advanced) DataLoader pin_memory.
                'numWorkers_dataloader': -1,  ## (advanced) DataLoader num_workers; -1 means all cores.
                'persistentWorkers_dataloader': True,  ## (advanced) DataLoader persistent_workers.
                'prefetchFactor_dataloader': 2,  ## (advanced) DataLoader prefetch_factor.
            },
        }

        swt = {
            'kwargs_Scattering2D': {'J': 2, 'L': 12},  ## 'J': number of convolutional layers; 'L': number of wavelet angles.
            'batch_size': 100,  ## Smaller uses less memory but is slower.
        }

        similarity_graph = {
            'sparsification': {
                'n_workers': -1,  ## CPU cores to use; -1 for all.
                'block_height': 128,  ## Block size.
                'block_width': 128,  ## Block size.
                'algorithm_nearestNeigbors_spatialFootprints': 'brute',  ## Nearest-neighbor algorithm for s_sf ('brute' is slow but exact).
            },
            'compute_similarity': {
                'spatialFootprint_maskPower': 1.0,  ## Exponent on spatial footprints to weight bright pixels more or less.
            },
            'normalization': {
                'k_max': 100,  ## Max nearest neighbors * n_sessions for the normalizing distribution.
                'k_min': 10,  ## Min nearest neighbors * n_sessions for the normalizing distribution.
                'algo_NN': 'kd_tree',  ## Nearest-neighbor algorithm.
            },
        }

        clustering = {
            'mixing_method': 'automatic',  ## 'automatic' (NB calibration + freeze-sigmoid DE) or 'manual'.
            'parameters_automatic_mixing': {
                'n_bins': None,  ## Histogram bins; None = heuristic.
                'smoothing_window_bins': None,  ## Distribution smoothing window; None = heuristic.
                'subsample_pairs': None,  ## Subsample this many pairs for speed; None = all.
                'bounds_findParameters': {
                    'power_nn': [0.0, 2.],  ## Bounds of the exponent on s_nn.
                    'power_swt': [0.0, 2.],  ## Bounds of the exponent on s_swt.
                    'p_norm': [-5, -0.1],  ## Bounds of the p-norm (Minkowski) mixing parameter.
                },
                'de_kwargs': {
                    'maxiter': 100,  ## Max differential-evolution generations.
                    'tol': 1e-6,  ## Convergence tolerance.
                    'popsize': 15,  ## Population size multiplier.
                    'mutation': [0.5, 1.5],  ## Mutation range.
                    'recombination': 0.7,  ## Crossover probability.
                    'polish': True,  ## L-BFGS-B polish after DE.
                },
            },
            'parameters_manual_mixing': {
                'power_sf': 1.0,  ## s_sf**power_sf; higher = more sensitive to spatial overlap.
                'power_nn': 0.5,  ## s_nn**power_nn; higher = more sensitive to visual similarity.
                'power_swt': 0.5,  ## s_swt**power_swt; higher = more sensitive to visual similarity.
                'p_norm': -1.0,  ## norm([s_sf, s_nn, s_swt], p=p_norm); higher requires all metrics to be high.
                'sig_sf_kwargs': None,  ## Sigmoid parameters for s_sf (mu = center, b = slope).
                'sig_nn_kwargs': {'mu': 0.5, 'b': 1.0},  ## Sigmoid parameters for s_nn (mu = center, b = slope).
                'sig_swt_kwargs': {'mu': 0.5, 'b': 1.0},  ## Sigmoid parameters for s_swt (mu = center, b = slope).
            },
            'pruning': {
                'd_cutoff': None,  ## Optional manual distance cutoff.
                'stringency': 1.0,  ## Scaling factor on d_cutoff; smaller prunes more.
                'convert_to_probability': False,  ## Convert the similarity and distance matrices to probabilities.
            },
            'cluster_method': {
                'method': 'automatic',  ## 'automatic', 'hdbscan', or 'sequential_hungarian'. 'automatic' picks hdbscan for many sessions and sequential_hungarian for few.
                'n_sessions_switch': 6,  ## Session count at which 'automatic' switches from sequential_hungarian to hdbscan.
            },
            'hdbscan': {
                'min_cluster_size': 2,  ## Smallest number of ROIs counted as a cluster.
                'max_cluster_size': None,  ## Largest cluster size; None means n_sessions (one ROI per session).
                'min_samples': None,  ## Neighbors for core-point density; None means min_cluster_size. Lower gives fewer noise points.
                'cluster_selection_method': 'eom',  ## (advanced) HDBSCAN cluster selection method.
                'cluster_selection_persistence': 0.0,  ## (advanced) Minimum cluster stability; higher gives fewer, more stable clusters.
                'd_clusterMerge': None,  ## Merge clusters closer than this; None means d_cutoff (pruning threshold).
                'rescue_noise': True,  ## Assign noise ROIs to nearby clusters afterwards (Kruskal DSU with cannot-link constraints).
                'n_iter_violationCorrection': 6,  ## Re-clustering sweeps after removing violations.
                'split_intraSession_clusters': True,  ## Split clusters that contain several ROIs from one session.
                'alpha': 0.999,  ## (advanced) Scalar applied to the distance matrix in HDBSCAN.
                'discard_failed_pruning': True,  ## (advanced) Label ROIs that could not be separated from same-session clusters as -1.
                'n_steps_clusterSplit': 100,  ## (advanced) Distance step count used to remove violations.
            },
            'sequential_hungarian': {
                'thresh_cost': 0.6,  ## Cost-matrix threshold; lower gives more clusters.
            },
        }

        results_saving = {
            'dir_save': None,  ## Output directory; None disables saving.
            'prefix_name_save': str(datetime.datetime.now().strftime("%Y%m%d_%H%M%S")),  ## Prefix for saved file names.
            'richfile_backend': 'zip',  ## 'directory', 'sqlar', 'zip' (default) or 'tar'. Archive backends write a single file.
            'gif_frame_rate': 10.0  ## Frame rate of saved GIFs.
        }

        return {
            'general': general,
            'data_loading': data_loading,
            'alignment': alignment,
            'blurring': blurring,
            'ROInet': roinet,
            'SWT': swt,
            'similarity_graph': similarity_graph,
            'clustering': clustering,
            'results_saving': results_saving,
        }

    if path_defaults is None:
        defaults = _builtin_defaults()
    else:
        defaults = helpers.yaml_load(path_defaults)

    ## Top-level sections that each pipeline needs.
    keys_pipeline = {
        'tracking': [
            'general',
            'data_loading',
            'alignment',
            'blurring',
            'ROInet',
            'SWT',
            'similarity_graph',
            'clustering',
            'results_saving',
        ],
        'classification_inference': [
            'general',
            'data_loading',
            'ROInet',
            'results_saving',
        ],
        'classification_training': [
            'general',
            'data_loading',
            'ROInet',
            'results_saving',
        ],
    }

    ## ROInet model used by each pipeline: tracking uses the latent-space
    ## model, classification uses the model with a classification head.
    network_by_pipeline = {
        'tracking': {
            'download_method': 'check_local_first',
            'download_url': 'https://osf.io/x3fd2/download',
            'download_hash': '7a5fb8ad94b110037785a46b9463ea94',
            'forward_pass_version': 'latent',
        },
        'classification_inference': {
            'download_method': 'check_local_first',
            'download_url': 'https://osf.io/c8m3b/download',
            'download_hash': '357a8d9b630ec79f3e015d0056a4c2d5',
            'forward_pass_version': 'head',
        },
        'classification_training': {
            'download_method': 'check_local_first',
            'download_url': 'https://osf.io/c8m3b/download',
            'download_hash': '357a8d9b630ec79f3e015d0056a4c2d5',
            'forward_pass_version': 'head',
        },
    }

    ## Equality-based membership test (list), like the original if/elif chain.
    if pipeline not in list(keys_pipeline):
        raise NotImplementedError(f'pipeline={pipeline}, which is not implemented or not recognized. Should be one of: {list(keys_pipeline.keys())}')

    params = copy.deepcopy({section: defaults[section] for section in keys_pipeline[pipeline]})
    params['ROInet']['network'] = network_by_pipeline[pipeline]
    return params
def system_info(verbose: bool = False,) -> Dict:
    """
    Checks and prints the versions of various important software packages.
    RH 2022

    Every section is collected on a best-effort basis. If a piece of
    information cannot be obtained, a warning is issued and the placeholder
    string ``'ROICaT Error: Failed to get'`` is stored instead of raising.
    This matters because the function is called whenever a ROICaT module is
    constructed.

    Args:
        verbose (bool):
            Whether to print the software versions. (Default is ``False``)

    Returns:
        (Dict):
            versions (Dict):
                Dictionary containing the versions of various software
                packages. Values are recursively converted by ``conv_str``
                into dicts, lists, ints, floats, bools, ``None`` or ``str``.
    """
    import os
    import sys
    import time
    import getpass
    import platform
    import subprocess
    import importlib.metadata
    from datetime import datetime

    ## Placeholder stored for any section that could not be collected.
    failed = 'ROICaT Error: Failed to get'

    ## Operating system and version
    def try_fns(fn):
        ## Many public ``platform`` callables need arguments or only work on
        ## some OSes; calling them without arguments is best-effort.
        try:
            return fn()
        except Exception:
            return None

    ## Snapshot the callables before calling any of them. Some platform
    ## functions may populate module-level caches, which would otherwise
    ## mutate platform.__dict__ during iteration.
    fns = {key: val for key, val in platform.__dict__.items() if (callable(val) and key[0] != '_')}
    operating_system = {key: try_fns(val) for key, val in fns.items()}
    if verbose:
        print(f'== Operating System ==: {operating_system["uname"]}')

    ## CPU info
    try:
        import cpuinfo
        import multiprocessing as mp
        cpu_info_raw = cpuinfo.get_cpu_info()
        cpu_info = {'n_cores': mp.cpu_count(), 'brand': cpu_info_raw.get('brand_raw', 'Unknown')}
        if 'flags' in cpu_info_raw:
            cpu_info['flags'] = 'omitted'  ## the flag list is long and not useful here
    except Exception as e:
        warnings.warn(f'RH WARNING: unable to get cpu info. Got error: {e}')
        cpu_info = failed
    if verbose:
        print(f'== CPU Info ==: {cpu_info}')

    ## RAM
    try:
        import psutil
        ram = psutil.virtual_memory()
    except Exception as e:
        warnings.warn(f'RH WARNING: unable to get ram info. Got error: {e}')
        ram = failed
    if verbose:
        print(f'== RAM ==: {ram}')

    ## User (getuser can fail, e.g. when no login name can be determined)
    try:
        user = getpass.getuser()
    except Exception as e:
        warnings.warn(f'RH WARNING: unable to get user. Got error: {e}')
        user = failed

    ## GPU
    try:
        gpu_info = helpers.list_available_devices()
    except Exception as e:
        warnings.warn(f'RH WARNING: unable to get gpu info. Got error: {e}')
        gpu_info = failed
    if verbose:
        print(f'== GPU Info ==: {gpu_info}')

    ## Conda Environment
    conda_env = os.environ.get('CONDA_DEFAULT_ENV', 'None')
    if verbose:
        print(f'== Conda Environment ==: {conda_env}')

    ## Python
    python_version = sys.version.split(' ')[0]
    if verbose:
        print(f'== Python Version ==: {python_version}')

    ## GCC
    try:
        gcc_version = subprocess.check_output(['gcc', '--version']).decode('utf-8').split('\n')[0].split(' ')[-1]
    except Exception as e:
        warnings.warn(f'RH WARNING: unable to get gcc version. Got error: {e}')
        gcc_version = failed
    if verbose:
        print(f'== GCC Version ==: {gcc_version}')

    ## PyTorch
    import torch
    torch_version = str(torch.__version__)
    if verbose:
        print(f'== PyTorch Version ==: {torch_version}')

    ## CUDA
    if torch.cuda.is_available():
        cuda_version = torch.version.cuda
        cudnn_version = torch.backends.cudnn.version()
        torch_devices = [f'device {i}: Name={torch.cuda.get_device_name(i)}, Memory={torch.cuda.get_device_properties(i).total_memory / 1e9} GB' for i in range(torch.cuda.device_count())]
        if verbose:
            print(f"== CUDA Version ==: {cuda_version}, CUDNN Version: {cudnn_version}, Number of Devices: {torch.cuda.device_count()}, Devices: {torch_devices}, ")
    else:
        cuda_version = None
        cudnn_version = None
        torch_devices = None
        if verbose:
            print('== CUDA is not available ==')

    ## All packages in the environment. Distributions with broken metadata
    ## can report a missing 'Name'; skip those instead of crashing.
    pkgs_dict = {}
    for dist in importlib.metadata.distributions():
        pkg_name = dist.metadata['Name']
        if pkg_name:
            pkgs_dict[pkg_name.lower()] = dist.version

    ## roicat (may be missing, e.g. when running from an uninstalled source tree)
    try:
        roicat_version = importlib.metadata.version("roicat")
        roicat_fileDate = time.ctime(os.path.getctime(importlib.metadata.distribution("roicat").locate_file('')))
    except Exception as e:
        warnings.warn(f'RH WARNING: unable to get roicat version info. Got error: {e}')
        roicat_version = failed
        roicat_fileDate = failed
    roicat_stuff = {'version': roicat_version, 'date_installed': roicat_fileDate}
    if verbose:
        print(f'== ROICaT Version ==: {roicat_version}')
        print(f'== ROICaT date installed ==: {roicat_fileDate}')

    ## get datetime
    dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    versions = {
        'datetime': dt,
        'roicat': roicat_stuff,
        'operating_system': operating_system,
        'cpu_info': cpu_info,  ## This is the slow one.
        'user': user,
        'ram': ram,
        'gpu_info': gpu_info,
        'conda_env': conda_env,
        'python': python_version,
        'gcc': gcc_version,
        'torch': torch_version,
        'cuda': cuda_version,
        'cudnn': cudnn_version,
        'torch_devices': torch_devices,
        'pkgs': pkgs_dict,
    }

    def conv_str(obj):
        ## Recursively keep containers and simple scalars; stringify anything else.
        if isinstance(obj, (dict, collections.OrderedDict)):
            return {key: conv_str(val) for key, val in obj.items()}
        elif isinstance(obj, (list, tuple, set, frozenset)):
            return [conv_str(val) for val in obj]
        elif isinstance(obj, (int, float, bool, type(None))):
            return obj
        else:
            return str(obj)

    return conv_str(versions)
def set_random_seed(seed=None, deterministic=False):
    """
    Seeds numpy, torch (CPU and CUDA), Python's ``random`` and OpenCV, and
    optionally switches torch / cuDNN into deterministic mode.
    RH 2023

    Args:
        seed (int, optional):
            Seed to use. If None, one is drawn with numpy from the
            non-negative int32 range [0, 2**31 - 1).
        deterministic (bool, optional):
            If True, request deterministic torch algorithms and cuDNN kernels
            and disable cuDNN benchmarking. If False, the reverse.

    Returns:
        (int):
            seed (int):
                The seed that was applied.
    """
    ## Each library is imported right before it is seeded, so the seeding
    ## steps happen in a fixed order.
    import numpy as np

    if seed is None:
        ## A seed is always chosen (optuna requires one to be set within the pipeline).
        seed = int(np.random.randint(0, 2**31 - 1, dtype=np.uint32))
    np.random.seed(seed)

    import torch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seed_cuda(seed)

    import random
    random.seed(seed)

    import cv2
    cv2.setRNGSeed(seed)

    ## Determinism switches for torch and cuDNN.
    torch.use_deterministic_algorithms(deterministic)
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic

    return seed
class ROICaT_Module:
    """
    Super class for ROICaT modules.
    RH 2023

    Attributes:
        _system_info (Dict):
            System information gathered by ``system_info()`` at construction.
        params (Dict):
            Parameters of the module. Starts empty.
    """
    def __init__(self) -> None:
        """
        Initializes the ROICaT_Module class by gathering system information.
        """
        self._system_info = system_info()
        self.params = {}

    @property
    def serializable_dict(self) -> Dict[str, Any]:
        """
        Returns a serializable dictionary that can be saved to disk.

        Walks ``self.__dict__`` recursively:
            * Objects with a ``__dict__`` whose module is in the allow-list
              are expanded attribute by attribute. Attributes that fail to
              convert are dropped.
            * Lists / tuples / sets become lists; dicts stay dicts.
            * Any other object is kept as-is if its module is allowed and
              it can be pickled. Otherwise it is replaced by
              ``{'__repr__': repr(obj)}``.

        Returns:
            (Dict[str, Any]):
                serializable_dict (Dict[str, Any]):
                    Dictionary containing serializable items.
        """
        from functools import partial
        import pickle

        ## Top-level packages whose objects may be serialized.
        allowed_libraries = [
            'roicat',
            'builtins',
            'collections',
            'datetime',
            'itertools',
            'math',
            'numbers',
            'os',
            'pathlib',
            'string',
            'time',
            'numpy',
            'scipy',
            'sklearn',
        ]

        def is_library_allowed(obj) -> bool:
            """
            Return True if the top-level package of ``obj`` is allow-listed.
            When both lookups succeed, ``obj.__class__.__module__`` takes
            precedence over ``obj.__module__``. If neither succeeds, returns
            False.
            """
            module_name = None
            for get_module in (lambda: obj.__module__, lambda: obj.__class__.__module__):
                try:
                    module_name = get_module().split('.')[0]
                except Exception:
                    pass
            return module_name in allowed_libraries

        def to_repr(obj):
            """Fallback representation for objects that cannot be stored directly."""
            if hasattr(obj, '__repr__'):
                return {'__repr__': repr(obj)}
            if hasattr(obj, '__str__'):
                return {'__str__': str(obj)}
            return None

        def make_serializable_dict(obj, depth=0, max_depth=100, name=None):
            """
            Recursively convert ``obj`` into serializable containers (see the
            rules in the enclosing docstring).
            """
            msd_partial = partial(make_serializable_dict, depth=depth+1, max_depth=max_depth)
            if depth > max_depth:
                raise Exception(f'RH ERROR: max_depth of {max_depth} reached with object: {obj}')

            if hasattr(obj, '__dict__') and is_library_allowed(obj):
                serializable_dict = {}
                for key, val in obj.__dict__.items():
                    try:
                        serializable_dict[key] = msd_partial(val, name=key)
                    except Exception:
                        ## Deliberate best-effort: attributes that cannot be
                        ## converted are left out of the result.
                        pass
                return serializable_dict
            if isinstance(obj, (list, tuple, set, frozenset)):
                return [msd_partial(v, name=f'{name}_{ii}') for ii, v in enumerate(obj)]
            if isinstance(obj, dict):
                return {k: msd_partial(v, name=f'{name}_{k}') for k, v in obj.items()}

            ## Explicit check instead of ``assert``, which is stripped under
            ## ``python -O`` and would let disallowed objects through.
            if not is_library_allowed(obj):
                return to_repr(obj)
            try:
                pickle.dumps(obj)
            except Exception:
                return to_repr(obj)
            return obj

        return make_serializable_dict(self, depth=0, max_depth=100, name='self')

    def _locals_to_params(
        self,
        locals_dict: Dict[str, Any],
        keys: List[str],
    ) -> Dict[str, Any]:
        """
        Returns a dictionary of the local variables with the specified keys.

        Args:
            locals_dict (Dict[str, Any]):
                Dictionary of local variables.
            keys (List[str]):
                List of keys to extract from the local variables.

        Returns:
            (Dict[str, Any]):
                params (Dict[str, Any]):
                    ``{key: locals_dict[key]}`` for every key. Keys missing
                    from ``locals_dict`` map to ``None`` and trigger a warning.
        """
        def safe_getitem(d, key):
            try:
                return d[key]
            except KeyError:
                warnings.warn(f'RH WARNING: key={key} not found in locals_dict. Setting value to None.')
                return None

        return {key: safe_getitem(locals_dict, key) for key in keys}
class RichFile_ROICaT(rf.RichFile):
    """
    RichFile subclass with ROICaT-specific type registrations (numpy arrays,
    scipy sparse matrices, torch tensors, optuna studies, pandas DataFrames,
    etc.).

    ``__init__`` defines one save/load function pair per supported type,
    collects them into the list ``type_dicts`` (type name, object class, file
    suffix, library, supported versions, save/load functions), and registers
    every entry via ``self.register_type_from_dict``.

    Args:
        path (Optional[Union[str, Path]]):
            Path to save/load the richfile.
        check (Optional[bool]):
            Whether to perform validation checks.
        safe_save (Optional[bool]):
            Whether to use atomic save with temporary file.
        backend (Optional[str]):
            Storage backend. One of:
                * ``'auto'``: auto-detect from existing path, or default to
                  ``'directory'`` for new saves.
                * ``'directory'``: classic richfile directory tree.
                * ``'sqlar'``: single-file SQLite archive (``.sqlar``).
                * ``'zip'``: single-file ZIP archive (``.zip``, stored/no
                  compression).
                * ``'tar'``: single-file plain TAR archive (``.tar``).
    """
    def __init__(
        self,
        path: Optional[Union[str, Path]] = None,
        check: Optional[bool] = True,
        safe_save: Optional[bool] = True,
        backend: Optional[str] = "auto",
    ):
        super().__init__(path=path, check=check, safe_save=safe_save, backend=backend)

        ## All save/load functions below are closures local to __init__.
        ## Imports are repeated per section; a name imported in one section
        ## (e.g. `json` in the JSON DICT section) is also used by functions
        ## defined in later sections of this method.

        ## NUMPY ARRAY
        import numpy as np
        def save_npy_array(
            obj: np.ndarray,
            path: Union[str, Path],
            **kwargs,
        ) -> None:
            """
            Saves a NumPy array to the given path with ``np.save``.
            """
            np.save(path, obj, **kwargs)
        def load_npy_array(
            path: Union[str, Path],
            **kwargs,
        ) -> np.ndarray:
            """
            Loads an array from the given path with ``np.load``.
            NOTE(review): this loader is also used for the 'numpy_scalar'
            type; it presumably returns a 0-d array rather than a scalar.
            Confirm whether callers rely on scalar type.
            """
            return np.load(path, **kwargs)

        ## SCIPY SPARSE MATRIX
        import scipy.sparse
        def save_sparse_array(
            obj: Union[scipy.sparse.spmatrix, scipy.sparse.sparray],
            path: Union[str, Path],
            **kwargs,
        ) -> None:
            """
            Saves a SciPy sparse matrix/array to the given path with
            ``scipy.sparse.save_npz``.
            """
            scipy.sparse.save_npz(path, obj, **kwargs)
        def load_sparse_array(
            path: Union[str, Path],
            **kwargs,
        ) -> scipy.sparse.sparray:
            """
            Loads a sparse array from the given path with
            ``scipy.sparse.load_npz``.
            """
            return scipy.sparse.load_npz(path, **kwargs)

        ## JSON DICT
        import collections
        import json
        def save_json_dict(
            obj: collections.UserDict,
            path: Union[str, Path],
            **kwargs,
        ) -> None:
            """
            Saves a dictionary to the given path as JSON. ``obj`` is
            converted to a plain ``dict`` first; ``kwargs`` go to
            ``json.dump``.
            """
            with open(path, 'w') as f:
                json.dump(dict(obj), f, **kwargs)
        def load_json_dict(
            path: Union[str, Path],
            **kwargs,
        ) -> collections.UserDict:
            """
            Loads a dictionary from the given path and wraps it in
            ``JSON_Dict`` (a ``dict`` subclass defined later in this module;
            the name is resolved when this function is called).
            """
            with open(path, 'r') as f:
                return JSON_Dict(json.load(f, **kwargs))

        ## JSON LIST
        def save_json_list(
            obj: collections.UserList,
            path: Union[str, Path],
            **kwargs,
        ) -> None:
            """
            Saves a list to the given path as JSON. ``obj`` is converted to a
            plain ``list`` first; ``kwargs`` go to ``json.dump``.
            """
            with open(path, 'w') as f:
                json.dump(list(obj), f, **kwargs)
        def load_json_list(
            path: Union[str, Path],
            **kwargs,
        ) -> collections.UserList:
            """
            Loads a list from the given path and wraps it in ``JSON_List``
            (a ``list`` subclass defined later in this module).
            """
            with open(path, 'r') as f:
                return JSON_List(json.load(f, **kwargs))

        ## OPTUNA STUDY
        import optuna
        import pickle
        ## load and save functions for optuna study
        def save_optuna_study(
            obj: optuna.study.Study,
            path: Union[str, Path],
            **kwargs,
        ) -> None:
            """
            Saves an Optuna study to the given path using ``pickle``.
            """
            with open(path, 'wb') as f:
                pickle.dump(obj, f, **kwargs)
        def load_optuna_study(
            path: Union[str, Path],
            **kwargs,
        ) -> optuna.study.Study:
            """
            Loads an Optuna study from the given path using ``pickle``.
            NOTE(review): unpickling executes arbitrary code; only load
            richfiles from trusted sources.
            """
            with open(path, 'rb') as f:
                return pickle.load(f, **kwargs)

        ## TORCH TENSOR
        import torch
        def save_torch_tensor(
            obj: torch.Tensor,
            path: Union[str, Path],
            **kwargs,
        ) -> None:
            """
            Saves a PyTorch tensor to the given path as a NumPy array. The
            tensor is detached and moved to CPU first, so device and autograd
            information are not stored.
            """
            np.save(path, obj.detach().cpu().numpy(), **kwargs)
        def load_torch_tensor(
            path: Union[str, Path],
            **kwargs,
        ) -> torch.Tensor:
            """
            Loads a PyTorch tensor (on CPU) from a NumPy file at the given
            path.
            """
            return torch.from_numpy(np.load(path, **kwargs))

        ## REPR
        def save_repr(
            obj: object,
            path: Union[str, Path],
            **kwargs,
        ) -> None:
            """
            Saves ``repr(obj)`` as text to the given path. This is a
            human-readable record only; the object cannot be reconstructed
            from it.
            """
            with open(path, 'w') as f:
                f.write(repr(obj))
        def load_repr(
            path: Union[str, Path],
            **kwargs,
        ) -> object:
            """
            Loads the repr of an object from the given path. Returns the
            stored text as a ``str``, not the original object.
            """
            with open(path, 'r') as f:
                return f.read()

        ## HDBSCAN OBJECT (fast_hdbscan or legacy)
        import fast_hdbscan
        _hdbscan_class = fast_hdbscan.HDBSCAN
        def save_hdbscan(
            obj,
            path: Union[str, Path],
            **kwargs,
        ) -> None:
            """
            Save a fast_hdbscan.HDBSCAN object by extracting serializable
            attributes into a dict of numpy arrays and JSON scalars.

            Public (non-underscore), non-callable attributes are written.
            None and Python scalars are written as-is. NumPy integer/float
            scalars are converted to ``int``/``float``. ``np.ndarray`` values
            become ``{'_type': 'ndarray', ...}`` and sparse matrices become
            ``{'_type': 'sparse', ...}`` (CSR components). Any other value
            type is silently skipped.
            """
            import json
            attrs = {}
            ## Use vars(obj) instead of dir(obj) to avoid triggering sklearn's
            ## __dir__ which calls hasattr on every attribute, including broken
            ## properties like condensed_tree_ that raise NameError in fast_hdbscan.
            ## Also include sklearn get_params() for constructor parameters.
            all_attrs = dict(vars(obj))
            try:
                all_attrs.update(obj.get_params())
            except Exception:
                ## Best-effort: objects without a working get_params() still
                ## get their instance attributes saved.
                pass
            for attr, val in sorted(all_attrs.items()):
                if attr.startswith('_'):
                    continue
                if callable(val):
                    continue
                if val is None:
                    attrs[attr] = None
                elif isinstance(val, np.ndarray):
                    attrs[attr] = {'_type': 'ndarray', 'data': val.tolist(), 'dtype': str(val.dtype)}
                elif scipy.sparse.issparse(val):
                    csr = val.tocsr()
                    attrs[attr] = {
                        '_type': 'sparse',
                        'data': csr.data.tolist(),
                        'indices': csr.indices.tolist(),
                        'indptr': csr.indptr.tolist(),
                        'shape': list(csr.shape),
                    }
                elif isinstance(val, (int, float, str, bool)):
                    attrs[attr] = val
                elif isinstance(val, np.integer):
                    attrs[attr] = int(val)
                elif isinstance(val, np.floating):
                    attrs[attr] = float(val)
                ## Skip non-serializable objects (CondensedTree, SingleLinkageTree)
            Path(path).parent.mkdir(parents=True, exist_ok=True)
            with open(path, 'w') as f:
                json.dump(attrs, f)
        def load_hdbscan(
            path: Union[str, Path],
            **kwargs,
        ) -> object:
            """
            Load a fast_hdbscan.HDBSCAN-like dict from JSON.
            Returns a dict of the stored attributes (not a live HDBSCAN object).
            ``'_type': 'ndarray'`` entries are rebuilt as ``np.ndarray`` and
            ``'_type': 'sparse'`` entries as ``scipy.sparse.csr_array``.
            If the file is not valid JSON (old repr-based format), the raw
            file text is returned instead.
            """
            import json
            with open(path, 'r') as f:
                content = f.read()
            try:
                attrs = json.loads(content)
            except json.JSONDecodeError:
                ## Old format: file contains repr string, not JSON
                return content
            ## Reconstruct numpy arrays and sparse matrices
            for key, val in attrs.items():
                if isinstance(val, dict) and val.get('_type') == 'ndarray':
                    attrs[key] = np.array(val['data'], dtype=val['dtype'])
                elif isinstance(val, dict) and val.get('_type') == 'sparse':
                    attrs[key] = scipy.sparse.csr_array(
                        (np.array(val['data']), np.array(val['indices']), np.array(val['indptr'])),
                        shape=tuple(val['shape']),
                    )
            return attrs

        ## SCIPY OPTIMIZE RESULT
        import scipy.optimize
        def save_optimize_result(
            obj: scipy.optimize.OptimizeResult,
            path: Union[str, Path],
            **kwargs,
        ) -> None:
            """
            Saves a scipy.optimize.OptimizeResult as JSON.
            Extracts standard fields into a JSON-serializable dict.
            Only the keys 'x', 'fun', 'nfev', 'nit', 'success' and 'message'
            are stored (when present). NumPy arrays and scalars are converted
            to Python lists and scalars; other values are passed to
            ``json.dump`` unchanged.
            """
            d = {}
            for key in ('x', 'fun', 'nfev', 'nit', 'success', 'message'):
                if key in obj:
                    val = obj[key]
                    if isinstance(val, np.ndarray):
                        d[key] = val.tolist()
                    elif isinstance(val, (np.integer,)):
                        d[key] = int(val)
                    elif isinstance(val, (np.floating,)):
                        d[key] = float(val)
                    elif isinstance(val, np.bool_):
                        d[key] = bool(val)
                    else:
                        d[key] = val
            with open(path, 'w') as f:
                json.dump(d, f)
        def load_optimize_result(
            path: Union[str, Path],
            **kwargs,
        ) -> scipy.optimize.OptimizeResult:
            """
            Loads a scipy.optimize.OptimizeResult from JSON. A list-valued
            'x' is converted back to ``np.ndarray``; other fields keep their
            JSON types.
            """
            with open(path, 'r') as f:
                d = json.load(f)
            if 'x' in d and isinstance(d['x'], list):
                d['x'] = np.array(d['x'])
            return scipy.optimize.OptimizeResult(**d)

        ## PANDAS DATAFRAME
        import pandas as pd
        def save_pandas_dataframe(
            obj: pd.DataFrame,
            path: Union[str, Path],
            **kwargs,
        ) -> None:
            """
            Saves a Pandas DataFrame to the given path as CSV, including the
            index as the first column.
            """
            ## Save as a CSV file
            obj.to_csv(path, index=True, **kwargs)
        def load_pandas_dataframe(
            path: Union[str, Path],
            **kwargs,
        ) -> pd.DataFrame:
            """
            Loads a Pandas DataFrame from the given CSV path, using the first
            column as the index (mirrors ``save_pandas_dataframe``).
            """
            ## Load as a CSV file
            return pd.read_csv(path, index_col=0, **kwargs)

        ## ROICaT / helpers classes handled by richfile's generic Container
        ## mechanism. Commented entries are currently disabled.
        roicat_module_tds = [rf.functions.Container(
            type_name=type_name,
            object_class=object_class,
            suffix="roicat",
            library="roicat",
            versions_supported=[">=1.1", "<2"],
        ) for type_name, object_class in [
            # ("data_suite2p", data_importing.Data_suite2p),
            # ("data_caiman", data_importing.Data_caiman),
            # ("data_roiextractors", data_importing.Data_roiextractors),
            # ("data_roicat", data_importing.Data_roicat),
            # ("aligner", alignment.Aligner),
            # ("blurrer", blurring.ROI_Blurrer),
            # ("roinet", ROInet.ROInet_embedder),
            # ("swt", scatteringWaveletTransformer.SWT),
            # ("similarity_graph", similarity_graph.ROI_graph),
            # ("clusterer", clustering.Clusterer),
            ("toeplitz_conv", helpers.Toeplitz_convolution2d),
            ("convergence_checker_optuna", helpers.Convergence_checker_optuna),
            ("image_alignment_checker", helpers.ImageAlignmentChecker),
        ]]

        ## sparse_convolution.Toeplitz_convolution2d: store as JSON dict
        ## (its dtype attr is a type object that richfile can't serialize as-is)
        import sparse_convolution
        def _save_sparse_conv(obj, path, **kwargs):
            ## Store only the constructor arguments; the dtype is stored as
            ## its NumPy type-string (e.g. '<f8') so it survives JSON.
            d = {
                'x_shape': list(obj.x_shape),
                'k': obj.k.tolist(),
                'mode': obj.mode,
                'method': obj.method,
                'dtype': np.dtype(obj.dtype).str,
            }
            with open(path, 'w') as f:
                json.dump(d, f)
        def _load_sparse_conv(path, **kwargs):
            ## Rebuild the convolution object from the stored constructor
            ## arguments.
            with open(path, 'r') as f:
                d = json.load(f)
            return sparse_convolution.Toeplitz_convolution2d(
                x_shape=tuple(d['x_shape']),
                k=np.array(d['k']),
                mode=d['mode'],
                method=d['method'],
                dtype=np.dtype(d['dtype']),
            )
        sparse_conv_dict = {
            'type_name': 'sparse_conv',
            'object_class': sparse_convolution.Toeplitz_convolution2d,
            'suffix': 'json',
            'library': 'sparse_convolution',
            'versions_supported': [],
            'function_save': _save_sparse_conv,
            'function_load': _load_sparse_conv,
        }
        # roicat_module_tds = []

        ## SIMILARITY METRIC (dataclass → JSON)
        from .tracking.similarity_graph import SimilarityMetric
        def save_similarity_metric(
            obj: SimilarityMetric,
            path: Union[str, Path],
            **kwargs,
        ) -> None:
            """Saves a SimilarityMetric dataclass as JSON via to_dict()."""
            with open(path, 'w') as f:
                json.dump(obj.to_dict(), f)
        def load_similarity_metric(
            path: Union[str, Path],
            **kwargs,
        ) -> SimilarityMetric:
            """Loads a SimilarityMetric from a JSON dict."""
            with open(path, 'r') as f:
                d = json.load(f)
            ## Convert power_bounds from list back to tuple (JSON round-trip)
            if 'power_bounds' in d and isinstance(d['power_bounds'], list):
                d['power_bounds'] = tuple(d['power_bounds'])
            return SimilarityMetric.from_dict(d)

        ## Registration table. NOTE(review): the order of entries may matter
        ## if richfile matches objects by isinstance in list order (e.g.
        ## Model_SWT and torch.nn.Sequential are both torch.nn.Module
        ## subclasses). Confirm before reordering.
        type_dicts = [
            {
                "type_name": "numpy_array",
                "function_load": load_npy_array,
                "function_save": save_npy_array,
                "object_class": np.ndarray,
                "suffix": "npy",
                "library": "numpy",
                "versions_supported": [],
            },
            {
                "type_name": "numpy_scalar",
                "function_load": load_npy_array,
                "function_save": save_npy_array,
                "object_class": np.number,
                "suffix": "npy",
                "library": "numpy",
                "versions_supported": [],
            },
            {
                "type_name": "scipy_sparse_array",
                "function_load": load_sparse_array,
                "function_save": save_sparse_array,
                "object_class": scipy.sparse.spmatrix,
                "suffix": "npz",
                "library": "scipy",
                "versions_supported": [],
            },
            ## scipy >= 1.14 returns csr_array instead of csr_matrix from many
            ## operations. csr_array inherits from sparray, not spmatrix.
            {
                "type_name": "scipy_sparray",
                "function_load": load_sparse_array,
                "function_save": save_sparse_array,
                "object_class": scipy.sparse.sparray,
                "suffix": "npz",
                "library": "scipy",
                "versions_supported": [],
            },
            {
                "type_name": "json_dict",
                "function_load": load_json_dict,
                "function_save": save_json_dict,
                "object_class": JSON_Dict,
                "suffix": "json",
                "library": "python",
                "versions_supported": [],
            },
            {
                "type_name": "json_list",
                "function_load": load_json_list,
                "function_save": save_json_list,
                "object_class": JSON_List,
                "suffix": "json",
                "library": "python",
                "versions_supported": [],
            },
            {
                "type_name": "optuna_study",
                "function_load": load_optuna_study,
                "function_save": save_optuna_study,
                "object_class": optuna.study.Study,
                "suffix": "optuna",
                "library": "optuna",
                "versions_supported": [],
            },
            {
                "type_name": "torch_tensor",
                "function_load": load_torch_tensor,
                "function_save": save_torch_tensor,
                "object_class": torch.Tensor,
                "suffix": "npy",
                "library": "torch",
                "versions_supported": [],
            },
            ## The torch model/data entries below are stored as repr text only
            ## (save_repr / load_repr); loading returns a string.
            {
                "type_name": "model_swt",
                "function_load": load_repr,
                "function_save": save_repr,
                "object_class": Model_SWT,
                "suffix": "swt",
                "library": "onnx2torch",
                "versions_supported": [],
            },
            {
                "type_name": "torch_module",
                "function_load": load_repr,
                "function_save": save_repr,
                "object_class": torch.nn.Module,
                "suffix": "torch_module",
                "library": "torch",
                "versions_supported": [],
            },
            {
                "type_name": "torch_sequence",
                "function_load": load_repr,
                "function_save": save_repr,
                "object_class": torch.nn.Sequential,
                "suffix": "torch_sequence",
                "library": "torch",
                "versions_supported": [],
            },
            {
                "type_name": "torch_dataset",
                "function_load": load_repr,
                "function_save": save_repr,
                "object_class": torch.utils.data.Dataset,
                "suffix": "torch_dataset",
                "library": "torch",
                "versions_supported": [],
            },
            {
                "type_name": "torch_dataloader",
                "function_load": load_repr,
                "function_save": save_repr,
                "object_class": torch.utils.data.DataLoader,
                "suffix": "torch_dataloader",
                "library": "torch",
                "versions_supported": [],
            },
            ## Conditional entry. In the code above `_hdbscan_class` is always
            ## set from `fast_hdbscan.HDBSCAN`, so the empty branch is never
            ## taken unless that assignment is changed.
            *([{
                "type_name": "hdbscan",
                "function_load": load_hdbscan,
                "function_save": save_hdbscan,
                "object_class": _hdbscan_class,
                "suffix": "json",
                "library": "fast_hdbscan",
                "versions_supported": [],
            }] if _hdbscan_class is not None else []),
            {
                "type_name": "pandas_dataframe",
                "function_load": load_pandas_dataframe,
                "function_save": save_pandas_dataframe,
                "object_class": pd.DataFrame,
                "suffix": "csv",
                "library": "pandas",
                "versions_supported": [],
            },
            {
                "type_name": "scipy_optimize_result",
                "function_load": load_optimize_result,
                "function_save": save_optimize_result,
                "object_class": scipy.optimize.OptimizeResult,
                "suffix": "json",
                "library": "scipy",
                "versions_supported": [],
            },
            {
                "type_name": "similarity_metric",
                "function_load": load_similarity_metric,
                "function_save": save_similarity_metric,
                "object_class": SimilarityMetric,
                "suffix": "json",
                "library": "roicat",
                "versions_supported": [],
            },
        ] + [t.get_property_dict() for t in roicat_module_tds] + [sparse_conv_dict]

        ## Comprehension used only for its side effect: register every entry.
        [self.register_type_from_dict(d) for d in type_dicts]
###################################### ######## CUSTOM DATA CLASSES ######### ######################################
class JSON_Dict(dict):
    """
    ``dict`` subclass used as a type marker for serialization.

    ``RichFile_ROICaT`` registers this class under the ``'json_dict'`` type,
    so instances are written as JSON files. It is constructed exactly like a
    ``dict``.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
class JSON_List(list):
    """
    ``list`` subclass used as a type marker for serialization.

    ``RichFile_ROICaT`` registers this class under the ``'json_list'`` type,
    so instances are written as JSON files. It is constructed exactly like a
    ``list``.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
## Wrapper for SWT
class Model_SWT(torch.nn.Module):
    """
    Thin ``torch.nn.Module`` wrapper that registers another module as the
    child ``'model'`` and delegates ``forward`` to it.

    Args:
        model (torch.nn.Module):
            The module to wrap.
    """
    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.add_module(name='model', module=model)

    def forward(self, x):
        """Runs the wrapped module on ``x`` and returns its output."""
        wrapped = self.model
        return wrapped(x)
def make_session_bool(n_roi: np.ndarray,) -> np.ndarray:
    """
    Generates a boolean array representing ROIs (Region Of Interest) per
    session from an array of ROI counts.

    Args:
        n_roi (np.ndarray):
            Number of ROIs in each session. Lists (and a scalar, treated as a
            single session) are also accepted.
            *shape*: *(n_sessions,)*

    Returns:
        (np.ndarray):
            session_bool (np.ndarray):
                Boolean array of shape *(n_roi_total, n_sessions)*, where
                ``session_bool[i, j]`` is ``True`` iff ROI ``i`` belongs to
                session ``j``. ROIs are assigned to sessions in order. An
                empty ``n_roi`` yields an array of shape *(0, 0)*.

    Example:
        .. highlight:: python
        .. code-block:: python

            n_roi = np.array([3, 4, 2])
            session_bool = make_session_bool(n_roi)
    """
    n_roi = np.asarray(n_roi, dtype=np.int64).reshape(-1)
    ## Session j owns the half-open ROI index range [bounds[j], bounds[j+1]).
    bounds = np.concatenate([np.zeros(1, dtype=np.int64), np.cumsum(n_roi)])
    idx_roi = np.arange(bounds[-1], dtype=np.int64)[:, None]
    ## Broadcast (n_roi_total, 1) against (n_sessions,) rather than stacking
    ## one row per session; this also works when there are no sessions.
    session_bool = (bounds[:-1] <= idx_roi) & (idx_roi < bounds[1:])
    return session_bool
def split_iby_session(
    x: Any,
    n_roi_per_session: Union[np.ndarray, List[int]],
):
    """
    Splits an array or sliceable sequence into consecutive per-session chunks
    based on the number of ROIs per session.

    Args:
        x (Any):
            Array or sequence to split along its first axis. Must support
            slicing (e.g. ``list``, ``np.ndarray``).
        n_roi_per_session (Union[np.ndarray, List[int]]):
            Number of ROIs per session. The chunk for session ``i`` is
            ``x[sum(n[:i]) : sum(n[:i+1])]``.

    Returns:
        (List[Any]):
            List of slices of ``x``, one per session, in session order.
    """
    ## Running offsets computed once, instead of re-summing the prefix for
    ## every session (which is quadratic in the number of sessions).
    bounds = [0]
    for n in n_roi_per_session:
        bounds.append(bounds[-1] + n)
    return [x[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])]
########################################################################################################################## ############################################### UCID handling ############################################################ ##########################################################################################################################
def check_dataStructure__list_ofListOrArray_ofDtype(
    lolod: Union[List[List[Union[int, float]]], List[np.ndarray]],
    dtype: Type = np.int64,
    fix: bool = True,
    verbose: bool = True,
) -> Union[List[List[Union[int, float]]], List[np.ndarray]]:
    """
    Verifies and optionally corrects the data structure of 'lolod' (list of
    list of dtype). The structure should be a list of lists of dtypes or a
    list of numpy arrays of dtypes.

    Args:
        lolod (Union[List[List[Union[int, float]]], List[np.ndarray]]):
            * The data structure to check. It should be a list of lists of
              dtypes or a list of numpy arrays of dtypes.
        dtype (Type):
            * The expected dtype of the elements in 'lolod'. (Default is
              ``np.int64``)
        fix (bool):
            * If ``True``, attempts to correct the data structure if it is
              not as expected. The corrections are as follows: \n
                * If 'lolod' is an array of ``dtype``, it will be cast to
                  [lolod]
                * If 'lolod' is an array of another non-object dtype, it will
                  be cast to [np.array(lolod, dtype=dtype)]
                * If 'lolod' is a list of numbers (int or float) or of
                  ``dtype`` scalars, it will be cast to
                  [np.array(lolod, dtype=dtype)] (a single session)
                * If 'lolod' is a list of lists of numbers (int or float), it
                  will be cast to [np.array(lod, dtype=dtype) for lod in lolod]
                * If 'lolod' is a list of arrays of wrong (non-object) dtype,
                  it will be cast to [np.array(lod, dtype=dtype) for lod in lolod]
                * A list of arrays that already have ``dtype`` is returned
                  unchanged. \n
            * If ``False``, raises an error whenever one of the corrections
              above would be needed. (Default is ``True``)
        verbose (bool):
            * If ``True``, prints warnings when the structure is not as
              expected and is corrected. (Default is ``True``)

    Returns:
        (Union[List[List[Union[int, float]]], List[np.ndarray]]):
            lolod (Union[List[List[Union[int, float]]], List[np.ndarray]]):
                The verified or corrected data structure.

    Raises:
        ValueError:
            If the structure is not recognized, or if it needs a correction
            and ``fix`` is ``False``.
    """
    if isinstance(lolod, list):
        if all(isinstance(lod, list) for lod in lolod):
            ## List of lists: elements must be numbers (converted) or dtype.
            if all(all(isinstance(l, (int, float, np.integer, np.floating)) for l in lod) for lod in lolod):
                if not fix:
                    raise ValueError('ROICaT ERROR: lolod is a list of lists of numbers (int or float).')
                if verbose:
                    print('ROICaT WARNING: lolod is a list of lists of numbers (int or float). Converting to np.ndarray.')
                lolod = [np.array(lod, dtype=dtype) for lod in lolod]
            elif all(all(isinstance(l, dtype) for l in lod) for lod in lolod):
                pass
            else:
                raise ValueError('ROICaT ERROR: lolod is a list of lists, but the elements are not all numbers (int or float) or dtype.')
        elif all(isinstance(lod, np.ndarray) for lod in lolod):
            if all(np.issubdtype(lod.dtype, dtype) for lod in lolod):
                ## Already the expected structure: return as-is (no copy,
                ## no warning, and no error when fix=False).
                pass
            elif all(not np.issubdtype(lod.dtype, np.object_) for lod in lolod):
                if not fix:
                    raise ValueError('ROICaT ERROR: lolod is a list of np.ndarray of numbers (int or float).')
                if verbose:
                    print('ROICaT WARNING: lolod is a list of np.ndarray of numbers (int or float). Converting to np.ndarray.')
                lolod = [np.array(lod, dtype=dtype) for lod in lolod]
            else:
                raise ValueError('ROICaT ERROR: lolod is a list of np.ndarray, but the elements are not all numbers (int or float) or dtype.')
        elif all(isinstance(lod, (int, float)) for lod in lolod):
            ## A flat list of numbers is a single session; wrap it as one
            ## array (matching the 'list of dtype' branch below) rather than
            ## a list of 0-d arrays.
            if not fix:
                raise ValueError('ROICaT ERROR: lolod is a list of numbers (int or float).')
            if verbose:
                print('ROICaT WARNING: lolod is a list of numbers (int or float). Converting to np.ndarray.')
            lolod = [np.array(lolod, dtype=dtype)]
        elif all(isinstance(lod, dtype) for lod in lolod):
            if not fix:
                raise ValueError('ROICaT ERROR: lolod is a list of dtype.')
            if verbose:
                print('ROICaT WARNING: lolod is a list of dtype. Converting to np.ndarray.')
            lolod = [np.array(lolod, dtype=dtype)]
        else:
            raise ValueError('ROICaT ERROR: lolod is a list, but the elements are not all lists or np.ndarray or numbers (int or float).')
    elif isinstance(lolod, np.ndarray):
        if np.issubdtype(lolod.dtype, dtype):
            if not fix:
                raise ValueError('ROICaT ERROR: lolod is a np.ndarray of dtype.')
            if verbose:
                print('ROICaT WARNING: lolod is a np.ndarray of dtype. Converting to list of np.ndarray of dtype.')
            lolod = [lolod]
        elif not np.issubdtype(lolod.dtype, np.object_):
            if not fix:
                raise ValueError('ROICaT ERROR: lolod is a np.ndarray of numbers (int or float).')
            if verbose:
                print('ROICaT WARNING: lolod is a np.ndarray of numbers (int or float). Converting to list of np.ndarray of dtype.')
            lolod = [np.array(lolod, dtype=dtype)]
        else:
            raise ValueError('ROICaT ERROR: lolod is a np.ndarray, but the elements are not all numbers (int or float) or dtype.')
    else:
        raise ValueError('ROICaT ERROR: lolod is not a list or np.ndarray.')
    return lolod
def mask_UCIDs_with_iscell(
    ucids: List[Union[List[int], np.ndarray]],
    iscell: List[Union[List[bool], np.ndarray]]
) -> List[Union[List[int], np.ndarray]]:
    """
    Masks the UCIDs with the **iscell** array. If ``iscell`` is False, then
    the UCID is set to -1.

    Args:
        ucids (List[Union[List[int], np.ndarray]]):
            List of lists of UCIDs for each session. A single array is
            treated as one session.\n
            Shape outer list: *(n_sessions,)*\n
            Shape inner list: *(n_roi_in_session,)*
        iscell (List[Union[List[bool], np.ndarray]]):
            List of lists of boolean indicators for each UCID.\n
            ``True`` means that ROI is a cell, ``False`` means that ROI is
            not a cell.\n
            Shape outer list: *(n_sessions,)*\n
            Shape inner list: *(n_roi_in_session,)*

    Returns:
        (List[Union[List[int], np.ndarray]]):
            ucids_out (List[Union[List[int], np.ndarray]]):
                Masked list of ``np.int64`` arrays of UCIDs. Elements that
                are not cells are set to -1 in each session. The input is not
                modified.

    Raises:
        ValueError:
            If ``ucids`` and ``iscell`` have a different number of sessions.
    """
    ## Deep copy so the in-place masking below never touches the caller's data.
    ucids_out = copy.deepcopy(ucids)
    ucids_out = check_dataStructure__list_ofListOrArray_ofDtype(
        lolod=ucids_out,
        dtype=np.int64,
        fix=True,
        verbose=False,
    )
    iscell = check_dataStructure__list_ofListOrArray_ofDtype(
        lolod=iscell,
        dtype=bool,
        fix=True,
        verbose=False,
    )
    ## Compare the normalized structures: a bare array input is one session,
    ## so len(ucids) of the raw input would be the ROI count, not 1.
    if len(ucids_out) != len(iscell):
        raise ValueError(f'ROICaT ERROR: ucids has {len(ucids_out)} sessions, but iscell has {len(iscell)} sessions.')

    for u_sesh, c_sesh in zip(ucids_out, iscell):
        u_sesh[~c_sesh] = -1
    return ucids_out
def mask_UCIDs_by_label(
    ucids: List[Union[List[int], np.ndarray]],
    labels: Union[List[int], np.ndarray],
) -> List[Union[List[int], np.ndarray]]:
    """
    Keeps only the UCIDs listed in **labels**; every other UCID becomes -1.\n
    RH 2024

    Args:
        ucids (List[Union[List[int], np.ndarray]]):
            Per-session UCIDs.\n
            Shape outer list: *(n_sessions,)*\n
            Shape inner list: *(n_roi_in_session,)*
        labels (Union[List[int], np.ndarray]):
            UCID values to keep. Shape: *(n_labels,)*

    Returns:
        (List[Union[List[int], np.ndarray]]):
            ucids_out (List[Union[List[int], np.ndarray]]):
                One ``np.int64`` array per session, where UCIDs not present
                in **labels** are replaced by -1. The input is not modified.

    Example:
        .. highlight:: python
        .. code-block:: python

            ucids = [[1, 2, 3], [2, -1, 4], [3, 0, 5]]
            labels = [2, 3]
            ucids_out = mask_UCIDs_by_label(ucids, labels)
            # ucids_out = [[-1, 2, 3], [2, -1, -1], [3, -1, -1]]
    """
    sessions = check_dataStructure__list_ofListOrArray_ofDtype(
        lolod=copy.deepcopy(ucids),
        dtype=np.int64,
        fix=True,
        verbose=False,
    )
    keep = np.array(labels, dtype=np.int64)
    ## Keep a UCID where it is one of the requested labels, otherwise mark -1.
    return [np.where(np.isin(sesh, keep), sesh, -1) for sesh in sessions]
def discard_UCIDs_with_fewer_matches(
    ucids: List[Union[List[int], np.ndarray]],
    n_sesh_thresh: Union[int, str] = 'all',
    verbose: bool = True
) -> List[Union[List[int], np.ndarray]]:
    """
    Discards UCIDs that do not appear in at least **n_sesh_thresh** sessions.
    If ``n_sesh_thresh='all'``, then only UCIDs that appear in all sessions
    are kept.

    Args:
        ucids (List[Union[List[int], np.ndarray]]):
            List of lists of UCIDs for each session. A single array is
            treated as one session.
        n_sesh_thresh (Union[int, str]):
            Number of sessions that a UCID must appear in to be kept. If
            ``'all'``, then only UCIDs that appear in all sessions are kept.
            (Default is ``'all'``)
        verbose (bool):
            If ``True``, print the percentage of non-negative UCIDs that are
            kept. (Default is ``True``)

    Returns:
        (List[Union[List[int], np.ndarray]]):
            ucids_out (List[Union[List[int], np.ndarray]]):
                One list per session of UCIDs (NumPy integer scalars), where
                UCIDs that do not appear in at least **n_sesh_thresh**
                sessions are set to -1. The input is not modified.

    Raises:
        ValueError:
            If **n_sesh_thresh** is neither ``'all'`` nor an integer.
    """
    ucids_out = check_dataStructure__list_ofListOrArray_ofDtype(
        lolod=ucids,
        dtype=np.int64,
        fix=True,
        verbose=False,
    )
    ## Count sessions after normalization: a bare array becomes one session.
    n_sesh = len(ucids_out)
    if isinstance(n_sesh_thresh, str) and n_sesh_thresh == 'all':
        n_sesh_thresh = n_sesh
    if not isinstance(n_sesh_thresh, (int, np.integer)):
        raise ValueError(f"ROICaT ERROR: n_sesh_thresh must be 'all' or an integer, but got {n_sesh_thresh!r}.")
    if n_sesh == 0:
        return []

    ucids_unique = np.unique(np.concatenate([np.asarray(u, dtype=np.int64) for u in ucids_out], axis=0))
    ## Number of sessions in which each unique UCID appears: one vectorized
    ## pass per session instead of one isin call per (UCID, session) pair.
    n_sesh_present = np.zeros(ucids_unique.shape[0], dtype=np.int64)
    for u_sesh in ucids_out:
        n_sesh_present += np.isin(ucids_unique, u_sesh)
    ucids_keep = ucids_unique[n_sesh_present >= n_sesh_thresh]

    if verbose:
        n_valid = (ucids_unique >= 0).sum()
        ## Guard against 0/0 when there are no non-negative UCIDs at all.
        fraction = (ucids_keep >= 0).sum() / n_valid if n_valid > 0 else 0.0
        print(f'INFO: {fraction*100:.2f}% of UCIDs that appear in at least {n_sesh_thresh} sessions.')

    ## UCIDs in `ucids_keep` are kept; everything else becomes -1.
    return [list(np.where(np.isin(u_sesh, ucids_keep), u_sesh, -1)) for u_sesh in ucids_out]
def squeeze_UCID_labels(
    ucids: List[Union[List[int], np.ndarray]],
    return_array: bool = False,
) -> List[Union[List[int], np.ndarray]]:
    """
    Squeezes the UCID labels. Finds all the unique non-negative UCIDs across
    all sessions, then removes gaps in the labels by mapping them, in
    ascending order, to contiguous integers starting at 0. Elements with a
    negative UCID (the unassigned marker, -1) are output as -1.

    Args:
        ucids (List[Union[List[int], np.ndarray]]):
            List of lists of UCIDs for each session.
        return_array (bool):
            If ``True``, then each session in the output will be a numpy
            array. (Default is ``False``)

    Returns:
        (List[Union[List[int], np.ndarray]]):
            ucids_out (List[Union[List[int], np.ndarray]]):
                One entry per session containing the squeezed UCIDs: a list
                of Python ``int`` (``return_array=False``) or a numpy array
                (``return_array=True``). The same original UCID maps to the
                same new label in every session. The input is not modified.
    """
    ucids_out = check_dataStructure__list_ofListOrArray_ofDtype(
        lolod=ucids,
        dtype=np.int64,
        fix=True,
        verbose=False,
    )
    if len(ucids_out) == 0:
        return []

    sessions = [np.asarray(u, dtype=np.int64) for u in ucids_out]
    ## Sorted unique non-negative UCIDs across all sessions. The position of
    ## a UCID in this array is its new (squeezed) label.
    uniques_all = np.unique(np.concatenate(sessions, axis=0))
    uniques_all = uniques_all[uniques_all >= 0]

    squeezed = [
        np.where(u_sesh >= 0, np.searchsorted(uniques_all, u_sesh), -1).tolist()
        for u_sesh in sessions
    ]

    if not return_array:
        return squeezed
    return [np.array(u) for u in squeezed]
def match_arrays_with_ucids(
    arrays: Union[np.ndarray, List[np.ndarray]],
    ucids: Union[List[np.ndarray], List[List[int]]],
    return_indices: bool = False,
    squeeze: bool = False,
    force_sparse: bool = False,
    prog_bar: bool = False,
) -> List[Union[np.ndarray, scipy.sparse.lil_array]]:
    """
    Re-indexes per-session arrays so that row ``u`` of each output holds the
    data of the ROI whose UCID is ``u``. Useful for aligning Fluorescence and
    Spiking data across sessions.

    For dense outputs, rows whose UCID is absent from a session are
    ``np.nan``. For sparse outputs (sparse input or ``force_sparse=True``),
    such rows are left empty in the ``lil_array``, i.e. implicit zeros.

    Args:
        arrays (Union[np.ndarray, List[np.ndarray]]):
            One array per session (a bare array is treated as one session).
            Matching is done along the first dimension. The first array must
            have a floating dtype.
        ucids (Union[List[np.ndarray], List[List[int]]]):
            List of lists of UCIDs for each session. Entries < 0 are ignored.
        return_indices (bool):
            If ``True``, also return the matched source row indices (as
            np.float32, since unmatched rows are NaN). (Default is ``False``)
        squeeze (bool):
            If ``True``, UCIDs are first squeezed to contiguous integers.
            (Default is ``False``)
        force_sparse (bool):
            If ``True``, outputs are ``scipy.sparse.lil_array``.
            (Default is ``False``)
        prog_bar (bool):
            If ``True``, a progress bar is displayed. (Default is ``False``)

    Returns:
        (List[Union[np.ndarray, scipy.sparse.lil_array]]):
            arrays_out (List[Union[np.ndarray, scipy.sparse.lil_array]]):
                One array per session with shape
                *(max_ucid + 1, *array.shape[1:])*, where ``max_ucid`` is the
                largest (possibly squeezed) UCID over all sessions.
            If ``return_indices`` is ``True``, a tuple
            ``(arrays_out, indices_out)`` is returned instead.

    Raises:
        ValueError:
            If the first array's dtype cannot hold NaN, or if it is neither a
            numpy array nor a scipy.sparse matrix (and ``force_sparse`` is
            ``False``).
    """
    import scipy.sparse

    if not isinstance(arrays, list):
        arrays = [arrays]

    ucids_checked = check_dataStructure__list_ofListOrArray_ofDtype(
        lolod=ucids,
        dtype=np.int64,
        fix=True,
        verbose=False,
    )

    ## NaN is the dense fill value, so the input dtype must be able to hold it.
    first = arrays[0]
    if not np.issubdtype(first.dtype, np.floating):
        raise ValueError(f'ROICaT ERROR: This function requires inputs to be of a dtype that is compatible with NaNs, like np.floating types: np.float32, np.float64, etc.')

    if squeeze:
        ucids_checked = squeeze_UCID_labels(ucids_checked)

    ## Output rows span 0..max UCID (over all sessions, including -1).
    n_rows = int(np.concatenate(ucids_checked, axis=0).max()) + 1

    ## UCID -> source row, per session. Later duplicates overwrite earlier ones.
    lookup_per_session = [dict((u, i) for i, u in enumerate(u_sesh)) for u_sesh in ucids_checked]

    if isinstance(first, np.ndarray) and not force_sparse:
        arrays_out = [np.full((n_rows, *a.shape[1:]), np.nan, dtype=first.dtype) for a in arrays]
    elif scipy.sparse.issparse(first) or force_sparse:
        arrays_out = [scipy.sparse.lil_array((n_rows, *a.shape[1:]), dtype=a.dtype) for a in arrays]
    else:
        raise ValueError(f'ROICaT ERROR: arrays[0] is not a numpy array or scipy.sparse matrix.')

    for i_sesh in tqdm(range(len(arrays)), disable=not prog_bar):
        src = arrays[i_sesh]
        dst = arrays_out[i_sesh]
        for ucid, idx in lookup_per_session[i_sesh].items():
            if ucid < 0:
                continue
            dst[ucid] = src[idx]

    if not return_indices:
        return arrays_out

    ## Use the original array lengths so the correct source indices are
    ## recovered even when a session has more rows than UCIDs.
    n_rows_in = [a.shape[0] if hasattr(a, 'shape') else len(a) for a in arrays]
    indices_out = match_arrays_with_ucids(
        arrays=[np.arange(n, dtype=np.float32) for n in n_rows_in],
        ucids=ucids,
        return_indices=False,
        squeeze=squeeze,
        force_sparse=False,
        prog_bar=False,
    )
    return arrays_out, indices_out
def match_arrays_with_ucids_inverse(
    arrays: Union[np.ndarray, List[np.ndarray]],
    ucids: Union[List[np.ndarray], List[List[int]]],
    unsqueeze: bool = True,
) -> List[Union[np.ndarray, scipy.sparse.lil_array]]:
    """
    Inverts the matching of the indices of the arrays using the UCIDs.
    Arrays should have indices that correspond to the UCID values. The
    return will be a list of arrays with indices that correspond to the
    original indices of the arrays / ucids. Essentially, this function undoes
    the matching done by match_arrays_with_ucids().

    For ROI ``j`` of session ``i`` with UCID ``u >= 0``, the output row is
    ``arrays[i][u]`` when ``unsqueeze=False``, or ``arrays[i][rank(u)]`` when
    ``unsqueeze=True``. Here ``rank(u)`` is the squeezed label from
    ``squeeze_UCID_labels``: the position of ``u`` among the sorted unique
    non-negative UCIDs of *all* sessions. This is the same global squeeze
    used by ``match_arrays_with_ucids(squeeze=True)``. ROIs with a negative
    UCID get a NaN index in ``helpers.index_with_nans``.

    Args:
        arrays (Union[np.ndarray, List[np.ndarray]]):
            List of matched arrays, one per session (a bare array is treated
            as one session).
        ucids (Union[List[np.ndarray], List[List[int]]]):
            List of lists of UCIDs for each session.
        unsqueeze (bool):
            If ``True``, then this algorithm assumes that the arrays were
            squeezed to remove unused UCIDs. This corresponds to and should
            match the argument ``squeeze`` used in match_arrays_with_ucids().

    Returns:
        (List[Union[np.ndarray, scipy.sparse.lil_array]]):
            arrays_out (List[Union[np.ndarray, scipy.sparse.lil_array]]):
                List of arrays with indices that correspond to the original
                indices of the arrays / ucids.

    Raises:
        ValueError:
            If ``arrays`` and ``ucids`` have a different number of sessions,
            or if an array's first dimension does not match the row count
            implied by ``ucids`` and ``unsqueeze``.
    """
    arrays = [arrays] if not isinstance(arrays, list) else arrays

    ucids_clean = check_dataStructure__list_ofListOrArray_ofDtype(
        lolod=ucids,
        dtype=np.int64,
        fix=True,
        verbose=False,
    )
    if len(ucids_clean) != len(arrays):
        raise ValueError(f'ROICaT ERROR: Got {len(arrays)} arrays but {len(ucids_clean)} sessions of ucids.')
    if len(ucids_clean) == 0:
        return []

    ucids_all = np.concatenate([np.asarray(u, dtype=np.int64) for u in ucids_clean], axis=0)

    if unsqueeze:
        ## The squeeze must be global (ranked over all sessions), not
        ## per-session, to match match_arrays_with_ucids(squeeze=True).
        ucids_target = squeeze_UCID_labels(ucids_clean)
        n_rows_expected = int(np.unique(ucids_all[ucids_all >= 0]).size)
    else:
        ucids_target = ucids_clean
        n_rows_expected = max(int(ucids_all.max()) + 1, 0) if ucids_all.size > 0 else 0

    ## Explicit consistency check between `arrays` and the `unsqueeze` flag.
    for i_sesh, a in enumerate(arrays):
        n_rows = a.shape[0] if hasattr(a, 'shape') else len(a)
        if n_rows != n_rows_expected:
            raise ValueError(f'ROICaT ERROR: arrays[{i_sesh}] has {n_rows} rows, but {n_rows_expected} were expected from ucids with unsqueeze={unsqueeze}. Make sure unsqueeze matches the squeeze argument used in match_arrays_with_ucids().')

    def _ucids_to_float_index(u):
        """Float index array for index_with_nans; negative UCIDs become NaN."""
        idx = np.array(u, dtype=np.float32)
        idx[idx < 0] = np.nan
        return idx

    return [helpers.index_with_nans(a, _ucids_to_float_index(u)) for a, u in zip(arrays, ucids_target)]
def labels_to_labelsBySession(labels, n_roi_bySession):
    """
    Splits a flat label vector (all sessions concatenated in order) into one
    label array per session.
    RH 2024

    Args:
        labels (list or np.ndarray):
            1-D labels of every ROI across all sessions.
        n_roi_bySession (list or np.ndarray):
            1-D number of ROIs in each session. Must sum to ``len(labels)``.

    Returns:
        (list):
            List of ``np.ndarray`` label chunks, one per session.
    """
    assert isinstance(labels, (list, np.ndarray)), f'labels is not a list or np.ndarray. labels={labels}'
    assert isinstance(n_roi_bySession, (list, np.ndarray)), f'n_roi_bySession is not a list or np.ndarray. n_roi_bySession={n_roi_bySession}'

    labels_arr = np.array(labels)
    counts = np.array(n_roi_bySession, dtype=np.int64)

    assert labels_arr.ndim == 1, f'labels.ndim={labels_arr.ndim}, but should be 1.'
    assert counts.ndim == 1, f'n_roi_bySession.ndim={counts.ndim}, but should be 1.'
    assert np.sum(counts) == len(labels_arr), f'np.sum(n_roi_bySession)={np.sum(counts)} != len(labels)={len(labels_arr)}'

    ## Session i occupies labels_arr[offsets[i]:offsets[i+1]].
    offsets = np.concatenate([[0], np.cumsum(counts)])
    return [labels_arr[start:stop] for start, stop in zip(offsets[:-1], offsets[1:])]
def invert_ucids(
    ucids: Union[np.ndarray, list],
    max_ucid: int = None,
) -> Union[np.ndarray, list]:
    """
    Invert UCIDs to make ucids_inverse where
    ucids_inverse[u] = argwhere(ucids == u) for all u in range(max_ucid + 1).
    Elements with a negative UCID (e.g. -1) are discarded. UCID values that
    do not occur are set to -1.
    RH 2025

    Args:
        ucids (Union[np.ndarray, list]):
            UCIDs to invert. Should be a 1D array or list of integers. Should
            be the ucids from a single session.
        max_ucid (int, optional):
            Maximum UCID value to use (output length is ``max_ucid + 1``). If
            not provided, it is inferred from the input UCIDs (0 for an empty
            input). If provided, it must be greater than or equal to the
            maximum UCID in the input. This is useful if you are combining
            multiple sessions with different UCID ranges. (Default is
            ``None``)

    Returns:
        (Union[np.ndarray, list]):
            Inverted UCIDs (``np.int64`` array, or a list if ``ucids`` was a
            list).

    Raises:
        ValueError:
            If ``ucids`` is not a list/array, if ``max_ucid`` is smaller than
            the largest UCID, or if a non-negative UCID occurs more than once.
    """
    if isinstance(ucids, list):
        ucids = np.array(ucids, dtype=np.int64)
        flag_list = True
    elif isinstance(ucids, np.ndarray):
        flag_list = False
    else:
        raise ValueError(f'ROICaT ERROR: ucids must be a list or numpy array, but got {type(ucids)}.')

    ## np.max of an empty array raises, so only compute it when there is data.
    ucid_max_found = int(np.max(ucids)) if len(ucids) > 0 else None
    if max_ucid is None:
        max_ucid = ucid_max_found if ucid_max_found is not None else 0
    elif ucid_max_found is not None and max_ucid < ucid_max_found:
        raise ValueError(f'ROICaT ERROR: max_ucid={max_ucid} is less than the maximum UCID in the input ucids={ucid_max_found}. Please provide a larger max_ucid value.')

    ucids_inverse = np.full((max_ucid + 1,), -1, dtype=np.int64)
    for i, u in enumerate(ucids):
        if u < 0:
            continue
        if ucids_inverse[u] != -1:
            raise ValueError(f'ROICaT ERROR: Duplicate UCID {u} found at indices {ucids_inverse[u]} and {i}. UCIDs must be unique.')
        ucids_inverse[u] = i

    return ucids_inverse.tolist() if flag_list else ucids_inverse