Source code for roicat.pipelines

## Import basic libraries
from pathlib import Path
import copy
import tempfile
from IPython.display import display
import time

# import matplotlib.pyplot as plt
import numpy as np

## Import roicat submodules
from . import data_importing, ROInet, helpers, util, visualization, tracking, classification

def pipeline_tracking(params: dict, custom_data: data_importing.Data_roicat = None) -> tuple:
    """
    Pipeline for tracking ROIs across sessions.
    RH 2023

    Runs the following steps in order: data loading, FOV alignment, ROI
    blurring, ROInet embedding, scattering wavelet embedding, similarity graph
    construction, clustering, and quality metric computation. If
    ``params['results_saving']['dir_save']`` is not None, the results, the
    run data, and a set of visualizations (.png and .gif files) are also
    written to that directory.

    Args:
        params (dict):
            Dictionary of parameters. See
            ``roicat.util.get_default_parameters(pipeline='tracking')`` for
            details.
        custom_data (Optional[data_importing.Data_roicat]):
            Optional. If not None, then this is a custom roicat data object
            that will be used instead of loading data in known formats from
            disk. Be careful to ensure that the object is prepared for the
            tracking pipeline.

    Returns:
        (tuple): tuple containing:
            results (dict):
                Dictionary of results.
            run_data (dict):
                Dictionary containing the ``__dict__`` of each class object
                used in the pipeline.
            params (dict):
                Parameters used in the pipeline. See
                ``roicat.helpers.prepare_params()`` for details.

    Raises:
        AssertionError:
            If ``dir_outer`` is None for suite2p data, if the number of
            matching roicat data objects is not exactly one, if the data
            object is incomplete for tracking, if it has fewer than two
            sessions, or if the mixing / clustering method is not recognized.
        FileNotFoundError:
            If no stat.npy files are found for suite2p data.
        NotImplementedError:
            If ``params['data_loading']['data_kind']`` is not supported.
    """
    ## Start timer
    ## Each entry of tocs is (step_name, seconds elapsed since tic_start).
    tocs = []
    tic_start = time.time()
    tocs.append(('start_time_absolute', time.time() - tic_start))

    ## Prepare params
    defaults = util.get_default_parameters(pipeline='tracking')
    params = helpers.prepare_params(params, defaults, verbose=True)
    display(params)

    ## Prepare state variables
    VERBOSE = params['general']['verbose']
    DEVICE = helpers.set_device(use_GPU=params['general']['use_GPU'])
    SEED = util.set_random_seed(
        seed=params['general']['random_seed'],
        deterministic=params['general']['random_seed'] is not None,
    )

    def announce(step_name):
        """Print a banner for the pipeline step when VERBOSE is truthy."""
        if VERBOSE:
            print(f"\n{'='*50}\nStep: {step_name}\n{'='*50}")

    ## Load data
    data_kind = params['data_loading']['data_kind']
    data_kind_aliases = {
        'roicat': 'data_roicat',
        'data_roicat': 'data_roicat',
        'suite2p': 'data_suite2p',
        'data_suite2p': 'data_suite2p',
    }
    ## Unknown kinds pass through unchanged and hit the NotImplementedError below
    data_kind = data_kind_aliases.get(data_kind, data_kind)

    if custom_data is not None:
        print("Using custom data object.")
        data = custom_data
    elif data_kind == 'data_suite2p':
        dir_outer = params['data_loading']['dir_outer']
        assert dir_outer is not None, "params['data_loading']['dir_outer'] must be specified if params['data_loading']['data_kind'] is 'data_suite2p'."
        paths_allStat = helpers.find_paths(
            dir_outer=dir_outer,
            reMatch='stat.npy',
            reMatch_in_path=params['data_loading']['reMatch_in_path'],
            depth=6,
            find_files=True,
            find_folders=False,
            natsorted=True,
        )[:]
        ## Each ops.npy is expected to sit next to its stat.npy
        paths_allOps = [str(Path(path).resolve().parent / 'ops.npy') for path in paths_allStat]

        if len(paths_allStat) == 0:
            raise FileNotFoundError(f"No stat.npy files found in '{dir_outer}'")

        print("Found the following stat.npy files:")
        for path in paths_allStat:
            print(f" {path}")
        print("Found the following corresponding ops.npy files:")
        for path in paths_allOps:
            print(f" {path}")

        ## Import data
        data = data_importing.Data_suite2p(
            paths_statFiles=paths_allStat[:],
            paths_opsFiles=paths_allOps[:],
            verbose=VERBOSE,
            **{**params['data_loading']['common'], **params['data_loading']['data_suite2p']},
        )
        assert data.check_completeness(verbose=False)['tracking'], "Data object is missing attributes necessary for tracking."
    elif data_kind == 'data_roicat':
        ## Search for both directories (.richfile) and archive files (.sqlar, .zip, .tar)
        paths_folders = helpers.find_paths(
            dir_outer=params['data_loading']['dir_outer'],
            reMatch=params['data_loading']['data_roicat']['filename_search'],
            depth=6,
            find_files=False,
            find_folders=True,
            natsorted=True,
        )[:]
        paths_files = helpers.find_paths(
            dir_outer=params['data_loading']['dir_outer'],
            reMatch=params['data_loading']['data_roicat']['filename_search'],
            depth=6,
            find_files=True,
            find_folders=False,
            natsorted=True,
        )[:]
        paths_allDataObjs = paths_folders + paths_files
        assert len(paths_allDataObjs) == 1, f"ERROR: Found {len(paths_allDataObjs)} files matching the search pattern '{params['data_loading']['data_roicat']['filename_search']}' in '{params['data_loading']['dir_outer']}'. Exactly one file must be found."

        data = data_importing.Data_roicat()
        data.import_from_dict(
            dict_load=util.RichFile_ROICaT(path=paths_allDataObjs[0]).load(),
        )
    else:
        raise NotImplementedError(f"params['data_loading']['data_kind'] == '{params['data_loading']['data_kind']}' is not yet implemented.")

    ## These checks also cover the custom_data path
    assert data.check_completeness(verbose=False)['tracking'], "Data object is missing attributes necessary for tracking."
    assert data.n_sessions > 1, f"Data object must have more than one session to track (n_sessions={data.n_sessions})."
    tocs.append(('data_loading', time.time() - tic_start))

    ## Alignment
    announce('Alignment')
    use_nonrigid = params['alignment']['fit_nonrigid']['method']
    aligner = tracking.alignment.Aligner(
        um_per_pixel=data.um_per_pixel[0],  ## Single value for um_per_pixel. data.um_per_pixel is typically a list of floats, so index out just one value.
        verbose=VERBOSE,  ## Whether to print updates
        device=DEVICE,
        **params['alignment']['initialization']
    )
    FOV_images = aligner.augment_FOV_images(
        FOV_images=data.FOV_images,
        spatialFootprints=data.spatialFootprints,
        **params['alignment']['augment'],
    )
    aligner.fit_geometric(
        ims_moving=FOV_images,  ## input images
        verbose=VERBOSE,  ## Whether to print updates
        **params['alignment']['fit_geometric'],
    )
    aligner.transform_images_geometric(FOV_images)
    if use_nonrigid:
        aligner.fit_nonrigid(
            ims_moving=aligner.ims_registered_geo,  ## Input images. Typically the geometrically registered images
            remappingIdx_init=aligner.remappingIdx_geo,  ## The remappingIdx between the original images (and ROIs) and ims_moving
            **params['alignment']['fit_nonrigid'],
        )
        aligner.transform_images_nonrigid(FOV_images)
        aligner.transform_ROIs(
            ROIs=data.spatialFootprints,
            remappingIdx=aligner.remappingIdx_nonrigid,
            **params['alignment']['transform_ROIs'],
        )
    else:
        aligner.transform_ROIs(
            ROIs=data.spatialFootprints,
            remappingIdx=aligner.remappingIdx_geo,
            **params['alignment']['transform_ROIs'],
        )
    tocs.append(('alignment', time.time() - tic_start))
    helpers.clear_gpu_cache()

    ## Blur ROIs
    announce('Blurring')
    blurrer = tracking.blurring.ROI_Blurrer(
        frame_shape=(data.FOV_height, data.FOV_width),  ## FOV height and width
        plot_kernel=False,  ## Whether to visualize the 2D gaussian
        **params['blurring'],
    )
    blurrer.blur_ROIs(
        spatialFootprints=aligner.ROIs_aligned[:],
    )
    tocs.append(('blurring', time.time() - tic_start))

    ## ROInet embedding
    announce('ROInet Embedding')
    dir_temp = tempfile.gettempdir()
    roinet = ROInet.ROInet_embedder(
        device=DEVICE,  ## Which torch device to use ('cpu', 'cuda', etc.)
        dir_networkFiles=dir_temp,  ## Directory to download the pretrained network to
        verbose=VERBOSE,  ## Whether to print updates
        **params['ROInet']['network'],
    )
    roinet.generate_dataloader(
        ROI_images=data.ROI_images,  ## Input images of ROIs
        um_per_pixel=data.um_per_pixel,  ## Resolution of FOV
        pref_plot=False,  ## Whether or not to plot the ROI sizes
        **params['ROInet']['dataloader'],
    )
    roinet.generate_latents()
    tocs.append(('ROInet', time.time() - tic_start))
    helpers.clear_gpu_cache()

    ## Scattering wavelet embedding
    announce('Scattering Wavelet Transform')
    swt = tracking.scatteringWaveletTransformer.SWT(
        image_shape=data.ROI_images[0].shape[1:3],  ## size of a cropped ROI image
        device=DEVICE,  ## PyTorch device
        kwargs_Scattering2D=params['SWT']['kwargs_Scattering2D'],
    )
    swt.transform(
        ROI_images=roinet.ROI_images_rs,  ## All the cropped and resized ROI images
        batch_size=params['SWT']['batch_size'],
    )
    tocs.append(('SWT', time.time() - tic_start))
    helpers.clear_gpu_cache()

    ## Compute similarities
    announce('Similarity Graph')
    from .tracking.similarity_graph import DEFAULT_METRICS
    sim = tracking.similarity_graph.ROI_graph(
        frame_height=data.FOV_height,
        frame_width=data.FOV_width,
        verbose=VERBOSE,  ## Whether to print outputs
        metric_configs=DEFAULT_METRICS,
        **params['similarity_graph']['sparsification']
    )
    sim.compute_similarity_blockwise(
        spatialFootprints=blurrer.ROIs_blurred,  ## Mask spatial footprints
        ROI_session_bool=data.session_bool,  ## Boolean array of which ROIs belong to which sessions
        features={'nn': roinet.latents, 'swt': swt.latents},  ## Feature vectors per metric
        **params['similarity_graph']['compute_similarity'],
    )
    sim.make_normalized_similarities(
        centers_of_mass=data.centroids,  ## ROI centroid positions
        features={'nn': roinet.latents, 'swt': swt.latents},  ## Feature vectors for z-scoring
        device=DEVICE,
        k_max=data.n_sessions * params['similarity_graph']['normalization']['k_max'],
        k_min=data.n_sessions * params['similarity_graph']['normalization']['k_min'],
        algo_NN=params['similarity_graph']['normalization']['algo_NN'],
    )
    tocs.append(('similarity_graph', time.time() - tic_start))
    helpers.clear_gpu_cache()

    ## Clustering
    announce('Clustering')
    clusterer = tracking.clustering.Clusterer(
        similarities=sim.similarities_final,
        metric_configs=DEFAULT_METRICS,
        s_sesh=sim.s_sesh,
        session_bool=data.session_bool,
        verbose=VERBOSE,
    )
    mixing_method = params['clustering'].get('mixing_method', 'automatic')
    assert mixing_method in ('automatic', 'manual'), (
        f"clustering.mixing_method must be 'automatic' or 'manual', got '{mixing_method}'"
    )
    if mixing_method == 'manual':
        kwargs_makeConjunctiveDistanceMatrix_best = params['clustering']['parameters_manual_mixing']
    else:
        ## Default: NB calibration → freeze-sigmoid → 3-param DE
        kwargs_makeConjunctiveDistanceMatrix_best = clusterer.find_optimal_parameters_for_pruning(
            seed=SEED,
            **params['clustering'].get('parameters_automatic_mixing', {}),
        )
    clusterer.make_pruned_similarity_graphs(
        mixing_params=kwargs_makeConjunctiveDistanceMatrix_best,
        **params['clustering']['pruning'],
    )
    tocs.append(('make_conjunctive_distance', time.time() - tic_start))

    def choose_clustering_method(method='automatic', n_sessions_switch=8, n_sessions=None):
        """
        Resolve the clustering method name to 'HDBSCAN' or
        'SEQUENTIAL_HUNGARIAN'. With method='automatic', HDBSCAN is chosen when
        n_sessions >= n_sessions_switch; otherwise the given name is
        upper-cased and validated.
        """
        if method == 'automatic':
            method_out = 'HDBSCAN' if n_sessions >= n_sessions_switch else 'SEQUENTIAL_HUNGARIAN'
        else:
            method_out = method.upper()
        assert method_out in ('HDBSCAN', 'SEQUENTIAL_HUNGARIAN')
        return method_out

    method_clustering = choose_clustering_method(
        method=params['clustering']['cluster_method']['method'],
        n_sessions_switch=params['clustering']['cluster_method']['n_sessions_switch'],
        n_sessions=data.n_sessions,
    )

    if method_clustering == 'HDBSCAN':
        labels = clusterer.fit(
            d_conj=clusterer.dConj_pruned,  ## Input distance matrix
            session_bool=data.session_bool,  ## Boolean array of which ROIs belong to which sessions
            **params['clustering']['hdbscan'],
        )
    elif method_clustering == 'SEQUENTIAL_HUNGARIAN':
        labels = clusterer.fit_sequentialHungarian(
            d_conj=clusterer.dConj_pruned,  ## Input distance matrix
            session_bool=data.session_bool,  ## Boolean array of which ROIs belong to which sessions
            **params['clustering']['sequential_hungarian'],
        )
    else:
        raise ValueError('Clustering method not recognized. This should never happen.')
    tocs.append(('clustering', time.time() - tic_start))

    quality_metrics = clusterer.compute_quality_metrics()
    tocs.append(('quality_metrics', time.time() - tic_start))

    ## Collect results
    labels_squeezed, labels_bySession, labels_bool, labels_bool_bySession, labels_dict = tracking.clustering.make_label_variants(
        labels=labels,
        n_roi_bySession=data.n_roi,
    )

    results_clusters = {
        'labels': labels_squeezed,
        'labels_bySession': labels_bySession,
        'labels_dict': labels_dict,
        'quality_metrics': quality_metrics,
    }

    results_all = {
        "clusters": {
            "labels": util.JSON_List(labels_squeezed),
            "labels_bySession": util.JSON_List(labels_bySession),
            "labels_bool": labels_bool,
            "labels_bool_bySession": labels_bool_bySession,
            "labels_dict": util.JSON_Dict(labels_dict),
            "quality_metrics": util.JSON_Dict(clusterer.quality_metrics) if hasattr(clusterer, 'quality_metrics') else None,
        },
        "ROIs": {
            "ROIs_aligned": aligner.ROIs_aligned,
            "ROIs_raw": data.spatialFootprints,
            "frame_height": data.FOV_height,
            "frame_width": data.FOV_width,
            "idx_roi_session": np.where(data.session_bool)[1],
            "n_sessions": data.n_sessions,
        },
        "input_data": {
            ## Only some data objects carry these paths, so fall back to None
            "paths_stat": getattr(data, 'paths_stat', None),
            "paths_ops": getattr(data, 'paths_ops', None),
        },
        "other": {
            "run_times": util.JSON_Dict(tocs),
        },
    }

    run_data = {
        'data': data.__dict__,
        'aligner': aligner.__dict__,
        'blurrer': blurrer.__dict__,
        'roinet': roinet.__dict__,
        'swt': swt.__dict__,
        'sim': sim.__dict__,
        'clusterer': clusterer.__dict__,
    }
    ## Every object above is expected to expose a 'params' attribute
    params_used = {name: mod['params'] for name, mod in run_data.items()}

    ## Print some results
    print(f'Number of clusters: {len(np.unique(results_clusters["labels"]))}')
    print(f'Number of discarded ROIs: {(np.array(results_clusters["labels"])==-1).sum()}')

    ## Save results
    if params['results_saving']['dir_save'] is not None:
        dir_save = Path(params['results_saving']['dir_save']).resolve()
        name_save = str(params['results_saving']['prefix_name_save'])
        richfile_backend = params['results_saving'].get('richfile_backend', 'zip')
        ## Map backend to file extension: archive backends use .richfile.<ext>,
        ## directory backend uses just .richfile
        _backend_suffix = {'directory': 'richfile', 'sqlar': 'richfile.sqlar', 'zip': 'richfile.zip', 'tar': 'richfile.tar'}
        rf_ext = _backend_suffix.get(richfile_backend, 'richfile')
        print(f'dir_save: {dir_save}')

        paths_save = {
            'results_clusters': str(dir_save / f'{name_save}.tracking.results_clusters.json'),
            'params_used': str(dir_save / f'{name_save}.tracking.params_used.json'),
            'results_all': str(dir_save / f'{name_save}.tracking.results_all.{rf_ext}'),
            'run_data': str(dir_save / f'{name_save}.tracking.run_data.{rf_ext}'),
        }

        dir_save.mkdir(parents=True, exist_ok=True)

        helpers.json_save(obj=results_clusters, filepath=paths_save['results_clusters'])
        helpers.json_save(obj=params_used, filepath=paths_save['params_used'])
        util.RichFile_ROICaT(path=paths_save['results_all'], backend=richfile_backend).save(obj=results_all, overwrite=True)
        util.RichFile_ROICaT(path=paths_save['run_data'], backend=richfile_backend).save(obj=run_data, overwrite=True)

        ## Visualize results
        dir_vis = dir_save / 'visualization'
        gif_frame_rate = params['results_saving']['gif_frame_rate']

        def as_uint8(array, normalize=True):
            """
            Scale an image by 255 and cast it to uint8. With normalize=True the
            image is first divided by its maximum; an all-zero image is left
            as zeros instead of being divided by zero.
            """
            if normalize:
                peak = array.max()
                if peak != 0:
                    array = array / peak
            return (np.array(array) * 255).astype(np.uint8)

        def save_image(array, path, normalize=True):
            """Write one image to ``path`` as a file, creating parent folders."""
            ## Use PIL to save the image
            from PIL import Image
            Path(path).parent.mkdir(parents=True, exist_ok=True)
            Image.fromarray(as_uint8(array, normalize=normalize)).save(path)

        def save_image_series(arrays, name):
            """Save each image as visualization/<name>/<name>_<index>.png."""
            for ii, array in enumerate(arrays):
                save_image(array, str(dir_vis / name / f'{name}_{ii}.png'))

        def save_labeled_gif(frames, path):
            """Stamp each uint8 frame with its index and save the sequence as a looping gif."""
            helpers.save_gif(
                array=helpers.add_text_to_images(
                    images=frames,
                    ## One label per frame, sized from the frames being labelled
                    text=[[f"{ii}"] for ii in range(len(frames))],
                    font_size=3,
                    line_width=10,
                    position=(30, 90),
                ),
                path=str(path),
                frameRate=gif_frame_rate,
                loop=0,
            )

        ### Save some figures
        #### Save FOV_images as .png files
        save_image_series(data.FOV_images, 'FOV_images')
        save_image_series(aligner.ims_registered_geo, 'FOV_images_aligned_geometric')
        if use_nonrigid:
            save_image_series(aligner.ims_registered_nonrigid, 'FOV_images_aligned_nonrigid')
        save_image_series(data.get_maxIntensityProjection_spatialFootprints(), 'ROIs')
        save_image_series(aligner.get_ROIsAligned_maxIntensityProjection(normalize=True), 'ROIs_aligned')
        save_image_series(blurrer.get_ROIsBlurred_maxIntensityProjection(), 'ROIs_aligned_blurred')

        #### Save the image alignment checker images
        dir_alignment = dir_vis / 'alignment'
        dir_alignment.mkdir(parents=True, exist_ok=True)
        fig_all_to_all, fig_direct = aligner.plot_alignment_results_geometric()
        fig_all_to_all.savefig(str(dir_alignment / 'all_to_all_geometric.png'))
        if fig_direct is not None:
            fig_direct.savefig(str(dir_alignment / 'direct_geometric.png'))
        if use_nonrigid:
            fig_all_to_all, _ = aligner.plot_alignment_results_nonrigid()
            fig_all_to_all.savefig(str(dir_alignment / 'all_to_all_nonrigid.png'))

        #### Save some sample ROI images
        save_image_series(roinet.ROI_images_rs[:100], 'ROIs_sample')

        #### Save the similarity graph blocks
        dir_similarity = dir_vis / 'similarity_graph'
        fig = sim.visualize_blocks()
        dir_similarity.mkdir(parents=True, exist_ok=True)
        fig.savefig(str(dir_similarity / 'blocks.png'))

        #### Save the similarity / distance plots for the given conjunctive distance matrix kwargs
        dir_clustering = dir_vis / 'clustering'
        fig = clusterer.plot_distSame(mixing_params=kwargs_makeConjunctiveDistanceMatrix_best)
        dir_clustering.mkdir(parents=True, exist_ok=True)
        fig.savefig(str(dir_clustering / 'dist.png'))

        fig, axs = clusterer.plot_similarity_relationships(
            max_samples=100000,  ## Make smaller if it is running too slow
            kwargs_scatter={'s': 1, 'alpha': 0.2},
            mixing_params=kwargs_makeConjunctiveDistanceMatrix_best,
        )
        fig.savefig(str(dir_clustering / 'similarity_relationships.png'))

        #### Save the clustering results
        fig, axs = tracking.clustering.plot_quality_metrics(
            quality_metrics=quality_metrics,
            labels=labels_squeezed,
            n_sessions=data.n_sessions,
        )
        fig.savefig(str(dir_clustering / 'quality_metrics.png'))

        ### Save a gif of the ROIs
        FOV_clusters = visualization.compute_colored_FOV(
            spatialFootprints=[r.power(1.0) for r in results_all['ROIs']['ROIs_aligned']],  ## Spatial footprint sparse arrays
            FOV_height=results_all['ROIs']['frame_height'],
            FOV_width=results_all['ROIs']['frame_width'],
            labels=results_all["clusters"]["labels_bySession"],  ## cluster labels
            # labels=(np.array(results["clusters"]["labels"])!=-1).astype(np.int64),  ## cluster labels
            # alphas_labels=confidence*1.5,  ## Set brightness of each cluster based on some 1-D array
            # alphas_labels=(clusterer.quality_metrics['cluster_silhouette'] > 0) * (clusterer.quality_metrics['cluster_intra_means'] > 0.4),
            # alphas_sf=clusterer.quality_metrics['sample_silhouette'],  ## Set brightness of each ROI based on some 1-D array
        )
        save_labeled_gif(
            frames=[as_uint8(f, normalize=False) for f in FOV_clusters],
            path=dir_vis / 'FOV_clusters.gif',
        )

        ### Save gifs of the FOVs at different stages of alignment
        ## The target folders were already created by the .png saving above.
        save_labeled_gif(
            frames=[as_uint8(f) for f in FOV_images],
            path=dir_vis / 'FOV_images' / 'FOV_images.gif',
        )
        save_labeled_gif(
            frames=[as_uint8(f) for f in aligner.ims_registered_geo],
            path=dir_vis / 'FOV_images_aligned_geometric' / 'FOV_images_aligned_geometric.gif',
        )
        if use_nonrigid:
            save_labeled_gif(
                frames=[as_uint8(f) for f in aligner.ims_registered_nonrigid],
                path=dir_vis / 'FOV_images_aligned_nonrigid' / 'FOV_images_aligned_nonrigid.gif',
            )

    ## Print timing summary
    tocs.append(('total', time.time() - tic_start))
    if VERBOSE:
        print("\n" + "=" * 50)
        print("Pipeline Timing Summary")
        print("=" * 50)
        prev_time = 0
        for name, cumulative_time in tocs:
            if name == 'start_time_absolute':
                continue
            ## Entries are cumulative, so a step's duration is the difference
            ## from the previous entry. For the 'total' row this difference is
            ## the time spent after quality metrics (collecting / saving).
            step_time = cumulative_time - prev_time
            print(f" {name:30s} {step_time:8.1f}s ({cumulative_time:8.1f}s cumulative)")
            prev_time = cumulative_time
        print("=" * 50)

    return results_all, run_data, params