Source code for roicat.ROInet

# Documentation style guide:
## - Class init docstrings should be in the class definition docstring.
## - Use a style guide similar to Google's Python style guide except that argument definitions should start on a new indented line after the argument name.
## - If there is more than one argument, use multiple lines for the argument definition code.
## - Example parameters should start on a new line ('\n' should be used before the first one), should start with a dash, and the parameter definition should start on a new indented line.
## - All arguments should have type hints, accurately reflecting the expected type of the argument.
## - Special inputs or conditions related to the arguments should be highlighted using bold for emphasis, italic for optional aspects, and code for specific values or code-related inputs.
## - Keep the return variable name in the docstring for clarity.
## - Keep a consistent line length to improve readability of the docstring.
## - Ensure the clarity of argument descriptions through the use of clear sentence structure and punctuation.

"""
OSF.io links to ROInet versions:

* ROInet_tracking:
    * Info: This version does not include occlusions or large
      affine transformations.
    * Link: https://osf.io/x3fd2/download
    * Hash (MD5 hex): 7a5fb8ad94b110037785a46b9463ea94
* ROInet_classification:
    * Info: This version includes occlusions and large affine
      transformations.
    * Link: https://osf.io/c8m3b/download
    * Hash (MD5 hex): 357a8d9b630ec79f3e015d0056a4c2d5
"""


import sys
from pathlib import Path
import json
import os
import hashlib
import PIL
import multiprocessing as mp
from functools import partial
import gc
from typing import List, Tuple, Union, Optional, Dict, Any, Callable

import numpy as np
import torch
import torchvision
from torch.nn import Module
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import scipy.signal
import warnings

from . import util, helpers, data_importing

[docs] class Resizer_ROI_images(util.ROICaT_Module): """ Class for resizing ROIs. RH 2023-2024 Args: function_scaleFactor (Callable): The function used to convert ``um_per_pixel`` to a scale factor. (Default is ``lambda um_per_pixel, size_im: 1.2 * um_per_pixel * (size_im / 36)``) Where ``um_per_pixel`` is the number of microns per pixel and size_im is the edge length of the image. nan_to_num (bool): Whether to replace NaNs with a specific value. (Default is ``True``) nan_to_num_val (float): The value to replace NaNs with. (Default is *0.0*) verbose (bool): If True, print out extra information. (Default is ``False``) """ def __init__( self, function_scaleFactor: Callable[[float, int], float]=lambda um_per_pixel, size_im: 1.2 * um_per_pixel * (size_im / 36), nan_to_num: bool=True, nan_to_num_val: float=0.0, verbose: bool=True, batch_size: int=10000, ): super().__init__() self.nan_to_num = nan_to_num self.nan_to_num_val = nan_to_num_val self.batch_size = batch_size self._verbose = verbose ## Store parameter (but not data) args as attributes self.params['__init__'] = self._locals_to_params( locals_dict=locals(), keys=[ 'nan_to_num', 'nan_to_num_val', ], ) self.function_scaleFactor = function_scaleFactor def _check_ROI_images(self, ROI_images: np.ndarray): ### Check if any NaNs if np.any(np.isnan(ROI_images)): if self.nan_to_num: warnings.warn('ROICaT WARNING: NaNs detected. You should consider removing these before passing to the network. Using nan_to_num arguments.') else: raise ValueError('ROICaT ERROR: NaNs detected. You should consider removing these before passing to the network. Use nan_to_num=True to ignore this error.') if np.any(np.isinf(ROI_images)): warnings.warn('ROICaT WARNING: Infs detected. You should consider removing these before passing to the network.') ## Check if any images in any of the sessions are all zeros if np.any(np.all(ROI_images==0, axis=(1,2))): warnings.warn('ROICaT WARNING: Image(s) with all zeros detected. These can pass through the network, but may give weird results.')
[docs] def plot_resized_comparison(self, ROI_images_cat: np.ndarray, ROI_images_rs: np.ndarray): """ Plot a comparison of the ROI sizes before and after resizing. Args: ROI_images_cat (np.ndarray): Array of ROIs to resize. Shape should be (nROIs, height, width). ROI_images_rs (np.ndarray): Array of resized ROIs. Shape should be (nROIs, height, width). """ fig, axs = plt.subplots(2,1, figsize=(7,10)) axs[0].plot(np.mean(ROI_images_cat > 0, axis=(1,2))) axs[0].plot(scipy.signal.savgol_filter(np.mean(ROI_images_cat > 0, axis=(1,2)), 501, 3)) axs[0].set_xlabel('ROI number'); axs[0].set_ylabel('mean npix'); axs[0].set_title('ROI sizes raw') axs[1].plot(np.mean(ROI_images_rs > 0, axis=(1,2))) axs[1].plot(scipy.signal.savgol_filter(np.mean(ROI_images_rs > 0, axis=(1,2)), 501, 3)) axs[1].set_xlabel('ROI number'); axs[1].set_ylabel('mean npix'); axs[1].set_title('ROI sizes resized')
[docs] def resize_ROIs( self, ROI_images: np.ndarray, # Array of shape (n_rois, height, width) um_per_pixel: float, ) -> np.ndarray: """ Resizes the ROI (Region of Interest) images to prepare them for pass through network. Args: ROI_images (np.ndarray): The ROI images to resize. Array of shape *(n_rois, height, width)*. um_per_pixel (float): The number of microns per pixel. This value is used to rescale the ROI images so that they occupy a standard region of the image frame. Returns: (np.ndarray): ROI_images_rs (np.ndarray): The resized ROI images. """ ## Store parameter (but not data) args as attributes self.params['resize_ROIs'] = self._locals_to_params( locals_dict=locals(), keys=[ 'um_per_pixel', ], ) self._check_ROI_images(ROI_images) assert isinstance(um_per_pixel, (int, float)), f'um_per_pixel should be an int or float, but is {type(um_per_pixel)}' if self.nan_to_num: print(f'ROICaT: replacing NaNs with {self.nan_to_num_val}') if self._verbose else None ROI_images = np.nan_to_num(ROI_images, nan=self.nan_to_num_val) scale_forRS = self.function_scaleFactor(um_per_pixel=float(um_per_pixel), size_im=ROI_images.shape[1]) print(f'ROICaT: resizing ROIs') if self._verbose else None return np.stack([resize_affine(img, scale=scale_forRS, clamp_range=True) for img in tqdm(ROI_images, mininterval=5, disable=not self._verbose)], axis=0)
## Faster but slightly different results # return np.concatenate( # [resize_images( # batch, # scale=scale_forRS, # clamp_range=True, # ) for batch in tqdm( # helpers.make_batches(ROI_images, batch_size=self.batch_size), # total=np.ceil(len(ROI_images)/self.batch_size), # mininterval=5, # unit='images', # unit_scale=self.batch_size, # disable=not self._verbose, # )], axis=0)
[docs] class Dataloader_ROInet(util.ROICaT_Module): """ Class for creating a dataloader for the ROInet network. JZ, RH 2023 Args: ROI_images (np.ndarray): Array of ROIs to resize. Shape should be (nROIs, height, width). pref_plot (bool): If ``True``, plots the sizes of the ROI images before and after normalization. (Default is ``False``) batchSize_dataloader (int): The batch size to use for the DataLoader. (Default is *8*) pinMemory_dataloader (bool): If ``True``, pins the memory of the DataLoader, as per PyTorch's best practices. (Default is ``True``) numWorkers_dataloader (int): The number of worker processes for data loading. (Default is *-1*) persistentWorkers_dataloader (bool): If ``True``, uses persistent worker processes. (Default is ``True``) prefetchFactor_dataloader (int): The prefetch factor for data loading. (Default is *2*) transforms (Optional[Callable]): The transforms to use for the DataLoader. If ``None``, the function will only scale dynamic range (to 0-1), resize (to img_size_out dimensions), and tile channels (to 3) as a minimum to pass images through the network. (Default is ``None``) n_transforms (int): The number of times to apply the transforms to each image. Should be 1 for inference and 2 for training. (Default is *1*) img_size_out (Tuple[int, int]): The image output dimensions of DataLoader if transforms is ``None``. (Default is *(224, 224)*) jit_script_transforms (bool): If ``True``, converts the transforms pipeline into a TorchScript pipeline, potentially improving calculation speed but can cause problems with multiprocessing. (Default is ``False``) shuffle (bool): If ``True``, shuffles the data. Should be set to ``True`` for SimCLR training. (Default is ``False``) drop_last (bool): If ``True``, drops the last batch if it is not full. Should be set to ``True`` for SimCLR training. (Default is ``False``) verbose (bool): If ``True``, print out extra information. (Default is ``True``) """ def __init__( self, ROI_images: np.ndarray, batchSize_dataloader: int = 8, pinMemory_dataloader: bool = True, numWorkers_dataloader: int = -1, persistentWorkers_dataloader: bool = True, prefetchFactor_dataloader: int = 2, transforms: Optional[Callable] = None, n_transforms: int = 1, img_size_out: Tuple[int, int] = (224, 224), jit_script_transforms: bool = False, shuffle_dataloader: bool = False, drop_last_dataloader: bool = False, verbose: bool = True, ): super().__init__() self._verbose = verbose numWorkers_dataloader = mp.cpu_count() if numWorkers_dataloader == -1 else numWorkers_dataloader ## Store parameter (but not data) args as attributes self.params['__init__'] = self._locals_to_params( locals_dict=locals(), keys=[ 'batchSize_dataloader', 'pinMemory_dataloader', 'numWorkers_dataloader', 'persistentWorkers_dataloader', 'prefetchFactor_dataloader', 'n_transforms', 'img_size_out', 'jit_script_transforms', 'shuffle_dataloader', 'drop_last_dataloader', 'verbose', ], ) ## Type checking / correction if not isinstance(img_size_out, (tuple, list)): assert isinstance(img_size_out, int), f'img_size_out should be a tuple or list, but is {type(img_size_out)}' img_size_out = (img_size_out, img_size_out) transforms = torch.nn.Sequential( ScaleDynamicRange(scaler_bounds=(0,1)), torchvision.transforms.Resize( size=img_size_out, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True, ), TileChannels(dim=0, n_channels=3), ) if transforms is None else transforms if jit_script_transforms: if numWorkers_dataloader > 0: warnings.warn("\n\nWarning: Converting transforms to a jit-based script has been known to cause issues on Windows when numWorkers_dataloader > 0. If self.generate_latents() raises an Exception similar to 'Tried to serialize object __torch__.torch.nn.modules.container.Sequential which does not have a __getstate__ method defined!' consider setting numWorkers_dataloader=0 or jit_script_transforms=False.\n") self.transforms = torch.jit.script(transforms) else: self.transforms = transforms print(f'Defined image transformations: {transforms}') if self._verbose else None self.dataset = dataset_simCLR( X=torch.as_tensor(ROI_images, device='cpu', dtype=torch.float32), y=torch.as_tensor(torch.zeros(ROI_images.shape[0]), device='cpu', dtype=torch.float32), n_transforms=n_transforms, transform=self.transforms, DEVICE='cpu', dtype_X=torch.float32, ) print(f'Defined dataset') if self._verbose else None self.dataloader = torch.utils.data.DataLoader( self.dataset, batch_size=batchSize_dataloader, shuffle=shuffle_dataloader, drop_last=drop_last_dataloader, pin_memory=pinMemory_dataloader, num_workers=numWorkers_dataloader, persistent_workers=persistentWorkers_dataloader, prefetch_factor=prefetchFactor_dataloader, ) print(f'Defined dataloader') if self._verbose else None
[docs] class ROInet_embedder(util.ROICaT_Module): """ Class for loading the ROInet model, preparing data for it, and running it. RH, JZ 2022 OSF.io links to ROInet versions: * ROInet_tracking: * Info: This version does not include occlusions or large affine transformations. * Link: https://osf.io/x3fd2/download * Hash (MD5 hex): 7a5fb8ad94b110037785a46b9463ea94 * ROInet_classification: * Info: This version includes occlusions and large affine transformations. * Link: https://osf.io/c8m3b/download * Hash (MD5 hex): 357a8d9b630ec79f3e015d0056a4c2d5 Args: dir_networkFiles (str): Directory to find an existing ROInet.zip file or download and extract a new one into. device (str): Device to use for the model and data. (Default is ``'cpu'``) download_method (str): Approach to downloading the network files. Options are: \n * ``'check_local_first'``: Check if the network files are already in dir_networkFiles, if so, use them. * ``'force_download'``: Download an ROInet.zip file from download_url. * ``'force_local'``: Use an existing local copy of an ROInet.zip file, if they don't exist, raise an error. Hash checking is done and download_hash must be specified. \n (Default is ``'check_local_first'``) download_url (str): URL to download the ROInet.zip file from. (Default is https://osf.io/x3fd2/download) download_hash (dict): MD5 hash of the ROInet.zip file. This can be obtained from ROICaT documentation. If you don't have one, use download_method='force_download' and determine the hash using helpers.hash_file(). (Default is ``None``) names_networkFiles (dict): Names of the files in the ROInet.zip file. If uncertain, leave as None. The dictionary should have the form: \n ``{'params': 'params.json', 'model': 'model.py', 'state_dict': 'ConvNext_tiny__1_0_unfrozen__simCLR.pth',}`` \n Where 'params' is the parameters used to train the network (usually a .json file), 'model' is the model definition (usually a .py file), and 'state_dict' are the weights of the network (usually a .pth file). (Default is ``None``) forward_pass_version (str): Version of the forward pass to use. Options are 'latent' (return the post-head output latents, use this for tracking), 'head' (return the output of the head layers, use this for classification), and 'base' (return the output of the base model). (Default is ``'latent'``) verbose (bool): If True, print out extra information. (Default is ``True``) """ def __init__( self, dir_networkFiles: str, device: str = 'cpu', download_method: str = 'check_local_first', download_url: str = 'https://osf.io/x3fd2/download', download_hash: dict = None, names_networkFiles: dict = None, forward_pass_version: str = 'latent', verbose: bool = True, ): ## Imports super().__init__() ## Store parameter (but not data) args as attributes self.params['__init__'] = self._locals_to_params( locals_dict=locals(), keys=[ 'dir_networkFiles', 'device', 'download_method', 'download_url', 'download_hash', 'names_networkFiles', 'forward_pass_version', 'verbose', ], ) self._device = device self._verbose = verbose self._dir_networkFiles = dir_networkFiles self._download_url = download_url self._download_path_save = str(Path(self._dir_networkFiles).resolve() / 'ROInet.zip') fn_download = partial( helpers.download_file, path_save=self._download_path_save, hash_type='MD5', hash_hex=download_hash, mkdir=True, allow_overwrite=True, write_mode='wb', verbose=self._verbose, chunk_size=1024, ) ## Find or download network files if download_method == 'force_download': fn_download(url=self._download_url, check_local_first=False, check_hash=False) if download_method == 'check_local_first': # assert download_hash is not None, "if using download_method='check_local_first' download_hash cannot be None. Either determine the hash of the zip file or use download_method='force_download'." fn_download(url=self._download_url, check_local_first=True, check_hash=True) if download_method == 'force_local': # assert download_hash is not None, "if using download_method='force_local' download_hash cannot be None" assert Path(self._download_path_save).exists(), f"if using download_method='force_local' the network files must exist in {self._download_path_save}" fn_download(url=None, check_local_first=True, check_hash=True) ## Extract network files from zip paths_extracted = helpers.extract_zip( path_zip=self._download_path_save, path_extract=self._dir_networkFiles, verbose=self._verbose, ) ## Find network files if names_networkFiles is None: names_networkFiles = { 'params': 'params.json', 'model': 'model.py', 'state_dict': '.pth', } paths_networkFiles = {} paths_networkFiles['params'] = [p for p in paths_extracted if names_networkFiles['params'] in str(Path(p).name)][0] paths_networkFiles['model'] = [p for p in paths_extracted if names_networkFiles['model'] in str(Path(p).name)][0] paths_networkFiles['state_dict'] = [p for p in paths_extracted if names_networkFiles['state_dict'] in str(Path(p).name)][0] ## Import network files sys.path.append(str(Path(paths_networkFiles['model']).parent.resolve())) import model print(f"Imported model from {paths_networkFiles['model']}") if self._verbose else None with open(paths_networkFiles['params']) as f: self.params_model = json.load(f) print(f"Loaded params_model from {paths_networkFiles['params']}") if self._verbose else None self.net = model.make_model(fwd_version=forward_pass_version, **self.params_model) print(f"Generated network using params_model") if self._verbose else None ## Prep network and load state_dict for param in self.net.parameters(): param.requires_grad = False self.net.eval() self.net.load_state_dict(torch.load( f=paths_networkFiles['state_dict'], map_location=torch.device(self._device), weights_only=True, )) print(f'Loaded state_dict into network from {paths_networkFiles["state_dict"]}') if self._verbose else None self.net = self.net.to(self._device) print(f'Loaded network onto device {self._device}') if self._verbose else None def __repr__(self): device = self._device if hasattr(self, '_device') else '?' has_latents = hasattr(self, 'latents') n_latents = self.latents.shape[0] if has_latents else 0 return ( f"ROInet_embedder(device='{device}', " f"n_latents={n_latents if has_latents else 'not generated'})" )
[docs] def generate_dataloader( self, ROI_images: List[np.ndarray], um_per_pixel: Union[float, List[float]], resize_ROI_images: bool = True, nan_to_num: bool = True, nan_to_num_val: float = 0.0, pref_plot: bool = False, batchSize_dataloader: int = 8, pinMemory_dataloader: bool = True, numWorkers_dataloader: int = -1, persistentWorkers_dataloader: bool = True, prefetchFactor_dataloader: int = 2, transforms: Optional[Callable] = None, img_size_out: Tuple[int, int] = (224, 224), jit_script_transforms: bool = False, ): """ Generates a PyTorch DataLoader for a list of Region of Interest (ROI) images. Performs preprocessing such as rescaling, normalization, and resizing. Args: ROI_images (List[np.ndarray]): The ROI images to use for the dataloader. List of arrays, each array corresponds to a session and is of shape *(n_rois, height, width)*. um_per_pixel (Union[float, List[float]]): The conversion factor from pixels to microns. This is used to scale the ROI_images to a common size. Should either be a float or a list of floats, one for each session. resize_ROI_images (bool): If ``True``, resizes the ROI images to a common size. (Default is ``True``) nan_to_num (bool): Whether to replace NaNs with a specific value. (Default is ``True``) nan_to_num_val (float): The value to replace NaNs with. (Default is *0.0*) pref_plot (bool): If ``True``, plots the sizes of the ROI images before and after normalization. (Default is ``False``) batchSize_dataloader (int): The batch size to use for the DataLoader. (Default is *8*) pinMemory_dataloader (bool): If ``True``, pins the memory of the DataLoader, as per PyTorch's best practices. (Default is ``True``) numWorkers_dataloader (int): The number of worker processes for data loading. (Default is *-1*) persistentWorkers_dataloader (bool): If ``True``, uses persistent worker processes. (Default is ``True``) prefetchFactor_dataloader (int): The prefetch factor for data loading. (Default is *2*) transforms (Optional[Callable]): The transforms to use for the DataLoader. If ``None``, the function will only scale dynamic range (to 0-1), resize (to img_size_out dimensions), and tile channels (to 3) as a minimum to pass images through the network. (Default is ``None``) img_size_out (Tuple[int, int]): The image output dimensions of DataLoader if transforms is ``None``. (Default is *(224, 224)*) jit_script_transforms (bool): If ``True``, converts the transforms pipeline into a TorchScript pipeline, potentially improving calculation speed but can cause problems with multiprocessing. (Default is ``False``) Returns: (np.ndarray): ROI_images (np.ndarray): The ROI images after normalization and resizing. Shape is *(n_sessions, n_rois, n_channels, height, width)*. Example: .. highlight:: python .. code-block:: python dataloader = generate_dataloader(ROI_images) """ um_per_pixel = data_importing.Data_roicat._fix_um_per_pixel(um_per_pixel=um_per_pixel, n_sessions=len(ROI_images)) ROI_images = data_importing.Data_roicat._fix_ROI_images(ROI_images=ROI_images) ## Store parameter (but not data) args as attributes self.params['generate_dataloader'] = self._locals_to_params( locals_dict=locals(), keys=[ 'um_per_pixel', 'nan_to_num', 'nan_to_num_val', 'pref_plot', 'batchSize_dataloader', 'pinMemory_dataloader', 'numWorkers_dataloader', 'persistentWorkers_dataloader', 'prefetchFactor_dataloader', 'img_size_out', 'jit_script_transforms', ], ) if resize_ROI_images: print(f'Starting Image Resizer') if self._verbose else None roi_resizer = Resizer_ROI_images( nan_to_num=nan_to_num, nan_to_num_val=nan_to_num_val, verbose=False, ) self.ROI_images_rs = np.concatenate([ roi_resizer.resize_ROIs( ROI_images=ROI_images[ii], um_per_pixel=um_per_pixel[ii], ) for ii in range(len(ROI_images)) ], axis=0) roi_resizer.plot_resized_comparison( ROI_images_cat=np.concatenate(ROI_images, axis=0), ROI_images_rs=self.ROI_images_rs, ) if pref_plot else None else: self.ROI_images_rs = np.concatenate(ROI_images, axis=0) print(f'Creating dataloader') if self._verbose else None dataloader_generator = Dataloader_ROInet( ROI_images=self.ROI_images_rs, batchSize_dataloader=batchSize_dataloader, pinMemory_dataloader=pinMemory_dataloader, numWorkers_dataloader=numWorkers_dataloader, persistentWorkers_dataloader=persistentWorkers_dataloader, prefetchFactor_dataloader=prefetchFactor_dataloader, transforms=transforms, n_transforms=1, img_size_out=img_size_out, jit_script_transforms=jit_script_transforms, shuffle_dataloader=False, drop_last_dataloader=False, verbose=self._verbose, ) self.transforms = dataloader_generator.transforms self.dataset = dataloader_generator.dataset self.dataloader = dataloader_generator.dataloader return self.ROI_images_rs
[docs] def generate_latents(self) -> torch.Tensor: """ Passes the data in the dataloader through the network and generates latents. Returns: (torch.Tensor): latents (torch.Tensor): Latents for each ROI (Region of Interest). """ if hasattr(self, 'dataloader') == False: raise Exception('dataloader not defined. Call generate_dataloader() first.') print(f'starting: running data through network') self.latents = torch.cat([self.net(data[0][0].to(self._device)).detach() for data in tqdm(self.dataloader, mininterval=5)], dim=0).cpu() print(f'completed: running data through network') gc.collect() torch.cuda.empty_cache() gc.collect() torch.cuda.empty_cache() return self.latents
[docs] class ROInet_embedder_original(util.ROICaT_Module): """ Class for loading the ROInet model, preparing data for it, and running it. RH, JZ 2022 OSF.io links to ROInet versions: * ROInet_tracking: * Info: This version does not include occlusions or large affine transformations. * Link: https://osf.io/x3fd2/download * Hash (MD5 hex): 7a5fb8ad94b110037785a46b9463ea94 * ROInet_classification: * Info: This version includes occlusions and large affine transformations. * Link: https://osf.io/c8m3b/download * Hash (MD5 hex): 357a8d9b630ec79f3e015d0056a4c2d5 Args: dir_networkFiles (str): Directory to find an existing ROInet.zip file or download and extract a new one into. device (str): Device to use for the model and data. (Default is ``'cpu'``) download_method (str): Approach to downloading the network files. Options are: \n * ``'check_local_first'``: Check if the network files are already in dir_networkFiles, if so, use them. * ``'force_download'``: Download an ROInet.zip file from download_url. * ``'force_local'``: Use an existing local copy of an ROInet.zip file, if they don't exist, raise an error. Hash checking is done and download_hash must be specified. \n (Default is ``'check_local_first'``) download_url (str): URL to download the ROInet.zip file from. (Default is https://osf.io/x3fd2/download) download_hash (dict): MD5 hash of the ROInet.zip file. This can be obtained from ROICaT documentation. If you don't have one, use download_method='force_download' and determine the hash using helpers.hash_file(). (Default is ``None``) names_networkFiles (dict): Names of the files in the ROInet.zip file. If uncertain, leave as None. The dictionary should have the form: \n ``{'params': 'params.json', 'model': 'model.py', 'state_dict': 'ConvNext_tiny__1_0_unfrozen__simCLR.pth',}`` \n Where 'params' is the parameters used to train the network (usually a .json file), 'model' is the model definition (usually a .py file), and 'state_dict' are the weights of the network (usually a .pth file). (Default is ``None``) forward_pass_version (str): Version of the forward pass to use. Options are 'latent' (return the post-head output latents, use this for tracking), 'head' (return the output of the head layers, use this for classification), and 'base' (return the output of the base model). (Default is ``'latent'``) verbose (bool): If True, print out extra information. (Default is ``True``) """ def __init__( self, dir_networkFiles: str, device: str = 'cpu', download_method: str = 'check_local_first', download_url: str = 'https://osf.io/x3fd2/download', download_hash: dict = None, names_networkFiles: dict = None, forward_pass_version: str = 'latent', verbose: bool = True, ): ## Imports super().__init__() self._device = device self._verbose = verbose self._dir_networkFiles = dir_networkFiles self._download_url = download_url self._download_path_save = str(Path(self._dir_networkFiles).resolve() / 'ROInet.zip') fn_download = partial( helpers.download_file, path_save=self._download_path_save, hash_type='MD5', hash_hex=download_hash, mkdir=True, allow_overwrite=True, write_mode='wb', verbose=self._verbose, chunk_size=1024, ) ## Find or download network files if download_method == 'force_download': fn_download(url=self._download_url, check_local_first=False, check_hash=False) if download_method == 'check_local_first': # assert download_hash is not None, "if using download_method='check_local_first' download_hash cannot be None. Either determine the hash of the zip file or use download_method='force_download'." fn_download(url=self._download_url, check_local_first=True, check_hash=True) if download_method == 'force_local': # assert download_hash is not None, "if using download_method='force_local' download_hash cannot be None" assert Path(self._download_path_save).exists(), f"if using download_method='force_local' the network files must exist in {self._download_path_save}" fn_download(url=None, check_local_first=True, check_hash=True) ## Extract network files from zip paths_extracted = helpers.extract_zip( path_zip=self._download_path_save, path_extract=self._dir_networkFiles, verbose=self._verbose, ) ## Find network files if names_networkFiles is None: names_networkFiles = { 'params': 'params.json', 'model': 'model.py', 'state_dict': '.pth', } paths_networkFiles = {} paths_networkFiles['params'] = [p for p in paths_extracted if names_networkFiles['params'] in str(Path(p).name)][0] paths_networkFiles['model'] = [p for p in paths_extracted if names_networkFiles['model'] in str(Path(p).name)][0] paths_networkFiles['state_dict'] = [p for p in paths_extracted if names_networkFiles['state_dict'] in str(Path(p).name)][0] ## Import network files sys.path.append(str(Path(paths_networkFiles['model']).parent.resolve())) import model print(f"Imported model from {paths_networkFiles['model']}") if self._verbose else None with open(paths_networkFiles['params']) as f: self.params_model = json.load(f) print(f"Loaded params_model from {paths_networkFiles['params']}") if self._verbose else None self.net = model.make_model(fwd_version=forward_pass_version, **self.params_model) print(f"Generated network using params_model") if self._verbose else None ## Prep network and load state_dict for param in self.net.parameters(): param.requires_grad = False self.net.eval() self.net.load_state_dict(torch.load(paths_networkFiles['state_dict'], map_location=torch.device(self._device))) print(f'Loaded state_dict into network from {paths_networkFiles["state_dict"]}') if self._verbose else None self.net = self.net.to(self._device) print(f'Loaded network onto device {self._device}') if self._verbose else None
[docs] def generate_dataloader( self, ROI_images: List[np.ndarray], um_per_pixel: float = 1.0, nan_to_num: bool = True, nan_to_num_val: float = 0.0, pref_plot: bool = False, batchSize_dataloader: int = 8, pinMemory_dataloader: bool = True, numWorkers_dataloader: int = -1, persistentWorkers_dataloader: bool = True, prefetchFactor_dataloader: int = 2, transforms: Optional[Callable] = None, img_size_out: Tuple[int, int] = (224, 224), jit_script_transforms: bool = False, ): """ Generates a PyTorch DataLoader for a list of Region of Interest (ROI) images. Performs preprocessing such as rescaling, normalization, and resizing. Args: ROI_images (List[np.ndarray]): The ROI images to use for the dataloader. List of arrays, each array corresponds to a session and is of shape *(n_rois, height, width)*. um_per_pixel (float): The number of microns per pixel. Used to rescale the ROI images to the same size as the network input. (Default is *1.0*) nan_to_num (bool): Whether to replace NaNs with a specific value. (Default is ``True``) nan_to_num_val (float): The value to replace NaNs with. (Default is *0.0*) pref_plot (bool): If ``True``, plots the sizes of the ROI images before and after normalization. (Default is ``False``) batchSize_dataloader (int): The batch size to use for the DataLoader. (Default is *8*) pinMemory_dataloader (bool): If ``True``, pins the memory of the DataLoader, as per PyTorch's best practices. (Default is ``True``) numWorkers_dataloader (int): The number of worker processes for data loading. (Default is *-1*) persistentWorkers_dataloader (bool): If ``True``, uses persistent worker processes. (Default is ``True``) prefetchFactor_dataloader (int): The prefetch factor for data loading. (Default is *2*) transforms (Optional[Callable]): The transforms to use for the DataLoader. If ``None``, the function will only scale dynamic range (to 0-1), resize (to img_size_out dimensions), and tile channels (to 3) as a minimum to pass images through the network. (Default is ``None``) img_size_out (Tuple[int, int]): The image output dimensions of DataLoader if transforms is ``None``. (Default is *(224, 224)*) jit_script_transforms (bool): If ``True``, converts the transforms pipeline into a TorchScript pipeline, potentially improving calculation speed but can cause problems with multiprocessing. (Default is ``False``) Returns: (np.ndarray): ROI_images (np.ndarray): The ROI images after normalization and resizing. Shape is *(n_sessions, n_rois, n_channels, height, width)*. Example: .. highlight:: python .. code-block:: python dataloader = generate_dataloader(ROI_images) """ ## Remove NaNs ### Check if any NaNs if np.any([np.any(np.isnan(roi)) for roi in ROI_images]): warnings.warn('ROICaT WARNING: NaNs detected. You should consider removing remove these before passing to the network. Using nan_to_num arguments.') if np.any([np.any(np.isinf(roi)) for roi in ROI_images]): warnings.warn('ROICaT WARNING: Infs detected. You should consider removing these before passing to the network.') ## Check if any images in any of the sessions are all zeros if np.any([np.any(np.all(rois==0, axis=(1,2))) for rois in ROI_images]): warnings.warn('ROICaT WARNING: Image(s) with all zeros detected. These can pass through the network, but may give weird results.') if nan_to_num: ROI_images = [np.nan_to_num(rois, nan=nan_to_num_val) for rois in ROI_images] if numWorkers_dataloader == -1: numWorkers_dataloader = mp.cpu_count() print('Starting: resizing ROIs') if self._verbose else None sf_rs = [self.resize_ROIs(rois, um_per_pixel) for rois in ROI_images] ROI_images_cat = np.concatenate(ROI_images, axis=0) ROI_images_rs = np.concatenate(sf_rs, axis=0) print('Completed: resizing ROIs') if self._verbose else None if pref_plot: fig, axs = plt.subplots(2,1, figsize=(7,10)) axs[0].plot(np.mean(ROI_images_cat > 0, axis=(1,2))) axs[0].plot(scipy.signal.savgol_filter(np.mean(ROI_images_cat > 0, axis=(1,2)), 501, 3)) axs[0].set_xlabel('ROI number'); axs[0].set_ylabel('mean npix'); axs[0].set_title('ROI sizes raw') axs[1].plot(np.mean(ROI_images_rs > 0, axis=(1,2))) axs[1].plot(scipy.signal.savgol_filter(np.mean(ROI_images_rs > 0, axis=(1,2)), 501, 3)) axs[1].set_xlabel('ROI number'); axs[1].set_ylabel('mean npix'); axs[1].set_title('ROI sizes resized') if transforms is None: transforms = torch.nn.Sequential( ScaleDynamicRange(scaler_bounds=(0,1)), torchvision.transforms.Resize( size=img_size_out, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True, ), TileChannels(dim=0, n_channels=3), ) if jit_script_transforms: if numWorkers_dataloader > 0: warnings.warn("\n\nWarning: Converting transforms to a jit-based script has been known to cause issues on Windows when numWorkers_dataloader > 0. If self.generate_latents() raises an Exception similar to 'Tried to serialize object __torch__.torch.nn.modules.container.Sequential which does not have a __getstate__ method defined!' consider setting numWorkers_dataloader=0 or jit_script_transforms=False.\n") self.transforms = torch.jit.script(transforms) else: self.transforms = transforms print(f'Defined image transformations: {transforms}') if self._verbose else None self.dataset = dataset_simCLR( X=torch.as_tensor(ROI_images_rs, device='cpu', dtype=torch.float32), y=torch.as_tensor(torch.zeros(ROI_images_rs.shape[0]), device='cpu', dtype=torch.float32), n_transforms=1, transform=self.transforms, DEVICE='cpu', dtype_X=torch.float32, ) print(f'Defined dataset') if self._verbose else None self.dataloader = torch.utils.data.DataLoader( self.dataset, batch_size=batchSize_dataloader, shuffle=False, drop_last=False, pin_memory=pinMemory_dataloader, num_workers=numWorkers_dataloader, persistent_workers=persistentWorkers_dataloader, prefetch_factor=prefetchFactor_dataloader, ) print(f'Defined dataloader') if self._verbose else None self.ROI_images_rs = ROI_images_rs return ROI_images_rs
[docs] @classmethod def resize_ROIs( cls, ROI_images: np.ndarray, # Array of shape (n_rois, height, width) um_per_pixel: float, ) -> np.ndarray: """ Resizes the ROI (Region of Interest) images to prepare them for pass through network. Args: ROI_images (np.ndarray): The ROI images to resize. Array of shape *(n_rois, height, width)*. um_per_pixel (float): The number of microns per pixel. This value is used to rescale the ROI images so that they occupy a standard region of the image frame. Returns: (np.ndarray): ROI_images_rs (np.ndarray): The resized ROI images. """ scale_forRS = 0.7 * um_per_pixel ## hardcoded for now sorry return np.stack([resize_affine(img, scale=scale_forRS, clamp_range=True) for img in ROI_images], axis=0)
[docs] def generate_latents(self) -> torch.Tensor: """ Passes the data in the dataloader through the network and generates latents. Returns: (torch.Tensor): latents (torch.Tensor): Latents for each ROI (Region of Interest). """ if hasattr(self, 'dataloader') == False: raise Exception('dataloader not defined. Call generate_dataloader() first.') print(f'starting: running data through network') self.latents = torch.cat([self.net(data[0][0].to(self._device)).detach() for data in tqdm(self.dataloader, mininterval=5)], dim=0).cpu() print(f'completed: running data through network') gc.collect() torch.cuda.empty_cache() gc.collect() torch.cuda.empty_cache() return self.latents
################################### ########### RESIZING ############## ###################################
[docs] def resize_affine( img: np.ndarray, scale: float, clamp_range: bool = False, ) -> np.ndarray: """ Resizes an image using an affine transformation, scaled by a factor. Args: img (np.ndarray): The input image to resize. Shape: *(H, W)* scale (float): The scale factor to apply for resizing. clamp_range (bool): If ``True``, the image will be clamped to the range [min(img), max(img)] to prevent interpolation from extending outside of the image's range. (Default is ``False``) Returns: (np.ndarray): resized_image (np.ndarray): The resized image. """ img_rs = np.array(torchvision.transforms.functional.affine( img=PIL.Image.fromarray(img), angle=0, translate=[0,0], shear=0, scale=scale, interpolation=torchvision.transforms.InterpolationMode.BICUBIC )) if clamp_range: clamp_high = img.max() clamp_low = img.min() img_rs[img_rs>clamp_high] = clamp_high img_rs[img_rs<clamp_low] = clamp_low return img_rs
[docs] def resize_images( imgs: np.ndarray, scale: float, clamp_range: bool = False, ) -> np.ndarray: """ Resizes images using an affine transformation, scaled by a factor. Uses torch.nn.functional.grid_sample to perform the resizing. Args: imgs (np.ndarray): The input images to resize. Shape: *(N, H, W)* scale (float): The scale factor to apply for resizing. clamp_range (bool): If ``True``, the image will be clamped to the range [min(img), max(img)] to prevent interpolation from extending outside of the image's range. (Default is ``False``) Returns: (np.ndarray): resized_images (np.ndarray): The resized images. Shape: *(N, H, W)* """ imgs = imgs[None, ...] if imgs.ndim == 2 else imgs imgs_rs = img_size = imgs.shape[1:] meshgrid_out = torch.stack(torch.meshgrid(torch.linspace(-1, 1, img_size[0]), torch.linspace(-1, 1, img_size[1]), indexing='xy'), dim=-1) imgs_rs = torch.nn.functional.grid_sample( input=torch.as_tensor(imgs)[None, ...], grid=meshgrid_out[None, ...] / scale, mode='bicubic', padding_mode='zeros', align_corners=True, )[0].numpy() if clamp_range: imgs_rs = np.clip(imgs_rs, a_min=imgs.min(axis=(1,2), keepdims=True), a_max=imgs.max(axis=(1,2), keepdims=True)) return imgs_rs
[docs] def resize_affine2( imgs: np.ndarray, scale: float, clamp_range: bool = False, ) -> np.ndarray: """ Resizes an image using an affine transformation, scaled by a factor. Args: img (np.ndarray): The input images to resize. Shape: *(N, H, W)* scale (float): The scale factor to apply for resizing. clamp_range (bool): If ``True``, the image will be clamped to the range [min(img), max(img)] to prevent interpolation from extending outside of the image's range. (Default is ``False``) Returns: (np.ndarray): resized_image (np.ndarray): The resized image. """ img_rs = np.array(torchvision.transforms.functional.affine( img=PIL.Image.fromarray(imgs.transpose(1,2,0)), angle=0, translate=[0,0], shear=0, scale=scale, interpolation=torchvision.transforms.InterpolationMode.BICUBIC )).transpose(2,0,1) if clamp_range: imgs_rs = np.clip(imgs_rs, a_min=imgs.min(axis=(1,2), keepdims=True), a_max=imgs.max(axis=(1,2), keepdims=True)) return img_rs
################################### ########### FROM GRC ############## ###################################
[docs] class TileChannels(Module): """ Expand dimension dim in X_in and tile to be N channels. RH 2021 """ def __init__(self, dim=0, n_channels=3): """ Initializes the class. Args: dim (int): The dimension to tile. n_channels (int): The number of channels to tile to. """ super().__init__() self.dim = dim self.n_channels = n_channels
[docs] def forward(self, tensor): dims = [1]*len(tensor.shape) dims[self.dim] = self.n_channels return torch.tile(tensor, dims)
def __repr__(self): return f"TileChannels(dim={self.dim})"
[docs] class Unsqueeze(Module): """ Expand dimension dim in X_in and tile to be N channels. JZ 2023 """ def __init__(self, dim=0): """ Initializes the class. Args: dim (int): The dimension to tile. n_channels (int): The number of channels to tile to. """ super().__init__() self.dim = dim
[docs] def forward(self, tensor): return torch.unsqueeze(tensor, self.dim)
def __repr__(self): return f"Unsqueeze(dim={self.dim})"
[docs] class ScaleDynamicRange(Module): """ Min-max scaling of the input tensor. RH 2021 """ def __init__(self, scaler_bounds=(0,1), epsilon=1e-9): """ Initializes the class. Args: scaler_bounds (tuple): The bounds of how much to multiply the image by prior to adding the Poisson noise. epsilon (float): Value to add to the denominator when normalizing. """ super().__init__() self.bounds = scaler_bounds self.range = scaler_bounds[1] - scaler_bounds[0] self.epsilon = epsilon
[docs] def forward(self, tensor): tensor_minSub = tensor - tensor.min() return tensor_minSub * (self.range / (tensor_minSub.max()+self.epsilon))
def __repr__(self): return f"ScaleDynamicRange(scaler_bounds={self.bounds})"
[docs] class dataset_simCLR(Dataset): """ Args: X (Union[torch.Tensor, np.array, List[float]]): Images. Expected shape: *(n_samples, height, width)*. Currently expects no channel dimension. If/when it exists, then shape should be *(n_samples, n_channels, height, width)*. y (Union[torch.Tensor, np.array, List[int]]): Labels. Shape: *(n_samples)*. n_transforms (int): Number of transformations to apply to each image. Should be >= 1. (Default is ``2``) transform (Optional[Callable]): Optional transform to be applied on a sample. See torchvision.transforms for more information. Can use torch.nn.Sequential(a, bunch, of, transforms,) or other methods from torchvision.transforms. \n * If not ``None``: Transform(s) are applied to each image and the output shape of X_sample_transformed for __getitem__ will be *(n_samples, n_transforms, n_channels, height, width)*. * If ``None``: No transform is applied and output shape of X_sample_trasformed for __getitem__ will be *(n_samples, n_channels, height, width)* (which is missing the n_transforms dimension). \n (Default is ``None``) DEVICE (str): Device on which the data will be stored and transformed. Best to leave this as 'cpu' and do .to(DEVICE) on the data for the training loop. (Default is ``'cpu'``) dtype_X (torch.dtype): Data type of X. (Default is ``torch.float32``) dtype_y (torch.dtype): Data type of y. (Default is ``torch.int64``) temp_uncetainty (float): Temperture term applied to the CrossEntropyLoss input. (Default is ``1.0`` for no change) Example: .. highlight:: python .. code-block:: python transforms = torch.nn.Sequential( torchvision.transforms.RandomHorizontalFlip(p=0.5), torchvision.transforms.GaussianBlur( 5, sigma=(0.01, 1.) ), torchvision.transforms.RandomPerspective( distortion_scale=0.6, p=1, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, fill=0 ), torchvision.transforms.RandomAffine( degrees=(-180,180), translate=(0.4, 0.4), scale=(0.7, 1.7), shear=(-20, 20, -20, 20), interpolation=torchvision.transforms.InterpolationMode.BILINEAR, fill=0, fillcolor=None, resample=None ), ) scripted_transforms = torch.jit.script(transforms) dataset = dataset_simCLR( torch.tensor(images), labels, n_transforms=2, transform=scripted_transforms, DEVICE='cpu', dtype_X=torch.float32, dtype_y=torch.int64) dataloader = torch.utils.data.DataLoader( dataset, batch_size=64, shuffle=True, drop_last=True, pin_memory=False, num_workers=0) """ def __init__( self, X: Union[torch.Tensor, np.array, List[float]], y: Union[torch.Tensor, np.array, List[int]], n_transforms: int = 2, transform: Optional[Callable] = None, DEVICE: str = 'cpu', dtype_X: torch.dtype = torch.float32, dtype_y: torch.dtype = torch.int64, ): """ Initializes the dataset_simCLR object with the given images, labels, and optional settings. """ self.X = torch.as_tensor(X, dtype=dtype_X, device=DEVICE) # first dim will be subsampled from. Shape: (n_samples, n_channels, height, width) self.X = self.X[:,None,...] self.y = torch.as_tensor(y, dtype=dtype_y, device=DEVICE) # first dim will be subsampled from. self.idx = torch.arange(self.X.shape[0], device=DEVICE) self.n_samples = self.X.shape[0] self.transform = transform self.n_transforms = n_transforms if X.shape[0] != y.shape[0]: raise ValueError('RH Error: X and y must have same first dimension shape')
[docs] def tile_channels( self, X_in: Union[torch.Tensor, np.ndarray], dim: int = -3, ) -> Union[torch.Tensor, np.ndarray]: """ Expand dimension dim in X_in and tile to be 3 channels. Args: X_in (torch.Tensor or np.ndarray): Input image with shape: *(n_channels==1, height, width)* dim (int): Dimension to expand. (Default is ``-3``) Returns: (torch.Tensor or np.ndarray): X_out (torch.Tensor or np.ndarray): Output image with shape: *(n_channels==3, height, width)* """ dims = [1]*len(X_in.shape) dims[dim] = 3 return torch.tile(X_in, dims)
def __len__(self): """ Get the total number of samples in the dataset. Returns: (int): n_samples (int): The total number of samples. """ return self.n_samples def __getitem__( self, idx: int, ) -> Tuple[Union[torch.Tensor, np.ndarray], int, int, int]: """ Retrieves and transforms a sample. Args: idx (int): Index of the sample to retrieve. Returns: (Tuple): tuple containing: X_sample_transformed (torch.Tensor or np.ndarray): Transformed sample(s). Shape: * If transform is ``None``: *(batch_size, n_channels, height, width)* * If transform is not ``None``: *(n_transforms, batch_size, n_channels, height, width)* y_sample (int): Label of the sample. idx_sample (int): Index of the sample. sample_weight (int): Weight of the sample. Always 1. """ y_sample = self.y[idx] idx_sample = self.idx[idx] sample_weight = 1 X_sample_transformed = [] if self.transform is not None: for ii in range(self.n_transforms): X_transformed = self.transform(self.X[idx_sample]) X_sample_transformed.append(X_transformed) else: X_sample_transformed = self.tile_channels(self.X[idx_sample], dim=-3) return X_sample_transformed, y_sample, idx_sample, sample_weight