import pathlib
from pathlib import Path
import copy
import warnings
from typing import List, Optional, Union, Tuple, Dict, Any, Callable, Iterable
import numpy as np
from tqdm.auto import tqdm
import scipy.sparse
import sparse
from . import helpers, util
"""
Classes for importing data into the roicat package.
Conventions:
- Data_roicat is the super class for all data objects.
- Data_roicat can be used to make a custom data object.
- Subclasses like Data_suite2p and Data_caiman should be used
to import data from files and convert it to a
Data_roicat-ingestible format.
- Subclass import methods should be functions that return
properly formatted data ready for the superclass to ingest.
- Avoid directly setting attributes. Try to always use a
.set_attribute() method.
- Subclass import methods should operate at the multi-session
level. That is, they should take in lists of objects
corresponding to multiple sessions.
- Subclasses should be able to initialize classification and
tracking independently. Minimize interdependent attributes.
- Users should have flexibility in the following switch-cases:
    - FOV_images:
        - From file
        - From object
        - Only specify FOV_height and FOV_width
- Only default to importing from file if the file derives
from a standardized format (i.e. suite2p or caiman). Do not
require standardization for custom data objects like class
labels.
"""
############################################################################################################################
####################################### SUPER CLASS FOR ALL DATA OBJECTS ###################################################
############################################################################################################################
[docs]
class Data_roicat(util.ROICaT_Module):
"""
Superclass for all data objects. Can be used as a template for creating
custom data objects. RH 2022
Args:
verbose (bool):
Determines whether to print status updates. (Default is ``True``)
Attributes:
type (object):
The type of the data object. Set by the subclass.
n_sessions (int):
The number of imaging sessions.
n_roi (int):
The number of ROIs in each session.
n_roi_total (int):
The total number of ROIs across all sessions.
FOV_height (int):
The height of the field of view in pixels.
FOV_width (int):
The width of the field of view in pixels.
FOV_images (List[np.ndarray]):
A list of numpy arrays, each with shape *(FOV_height, FOV_width)*.
Each element represents an imaging session.
ROI_images (List[np.ndarray]):
A list of numpy arrays, each with shape *(n_roi, height, width)*.
Each element represents an imaging session and each element of the
numpy array (first dimension) is an ROI.
spatialFootprints (List[object]):
A list of scipy.sparse.csr_array objects, each with shape *(n_roi,
FOV_height \* FOV_width)*. Each element represents an imaging session.
class_labels_raw (List[np.ndarray]):
A list of numpy arrays, each with shape *(n_roi,)*, where each
element is an integer. Each element of the list is an imaging
session and each element of the numpy array is a class label.
class_labels_index (List[np.ndarray]):
A list of numpy arrays, each with shape *(n_roi,)*, where each
element is an integer. Each element of the list is an imaging
session and each element of the numpy array is the index of the
class label obtained from passing the raw class label through
np.unique(\*, return_inverse=True).
um_per_pixel (Union[float, List[float]]):
The conversion factor from pixels to microns. This is used to scale
the ROI_images to a common size. Should either be a float or a list
of floats, one for each session.
session_bool (np.ndarray):
A boolean matrix with shape *(n_roi_total, n_sessions)*. Each
element is ``True`` if the ROI is present in the session.
"""
def __init__(
self,
verbose: bool = True,
) -> None:
"""
Initializes the Data_roicat object with the specified verbosity.
"""
## Imports
super().__init__()
self._verbose = verbose
#########################################################
################# CLASSIFICATION ########################
#########################################################
[docs]
def set_ROI_images(
self,
ROI_images: List[np.ndarray],
um_per_pixel: Optional[List[float]] = None,
) -> None:
"""
Imports ROI images into the class. Images are expected to be formatted
as a list of numpy arrays. Each element is an imaging session. Each
element is a numpy array of shape *(n_roi, FOV_height, FOV_width)*. This
method will set the attributes: self.ROI_images, self.n_roi,
self.n_roi_total, self.n_sessions. If any of these attributes are
already set, it will verify the new values match the existing ones.
Args:
ROI_images (List[np.ndarray]):
List of numpy arrays each of shape *(n_roi, FOV_height,
FOV_width)*.
um_per_pixel (Union[float, List[float]]):
The conversion factor from pixels to microns. This is used to scale
the ROI_images to a common size. Should either be a float or a list
of floats, one for each session.
"""
## Store parameter (but not data) args as attributes
self.params['set_ROI_images'] = self._locals_to_params(
locals_dict=locals(),
keys=[
'um_per_pixel',
],
)
print(f"Starting: Importing ROI images") if self._verbose else None
## Check the validity of the inputs
ROI_images = self._fix_ROI_images(ROI_images=ROI_images)
## Warn if no um_per_pixel is provided
if um_per_pixel is None:
## Check if it is already set
if hasattr(self, 'um_per_pixel'):
um_per_pixel = self.um_per_pixel
warnings.warn("RH WARNING: No um_per_pixel provided. We recommend making an educated guess. Assuming 1.0 um per pixel for each session. This will affect the embedding results.")
um_per_pixel = [1.0,] * len(ROI_images)
um_per_pixel = self._fix_um_per_pixel(um_per_pixel=um_per_pixel, n_sessions=len(ROI_images))
## Define some variables
n_sessions = len(ROI_images)
n_roi = [roi.shape[0] for roi in ROI_images]
n_roi_total = int(sum(n_roi))
## Check that attributes match if they already exist as an attribute
if hasattr(self, 'n_sessions'):
assert self.n_sessions == n_sessions, f"n_sessions is already set to {self.n_sessions} but new value is {n_sessions}"
if hasattr(self, 'n_roi'):
assert self.n_roi == n_roi, f"n_roi is already set to {self.n_roi} but new value is {n_roi}"
if hasattr(self, 'n_roi_total'):
assert self.n_roi_total == n_roi_total, f"n_roi_total is already set to {self.n_roi_total} but new value is {n_roi_total}"
## Set attributes
self.n_sessions = n_sessions
self.n_roi = n_roi
self.n_roi_total = n_roi_total
self.ROI_images = ROI_images
self.um_per_pixel = um_per_pixel
print(f"Completed: Imported {n_sessions} sessions. Each session has {n_roi} ROIs. Total number of ROIs is {n_roi_total}. The um_per_pixel is {um_per_pixel} um per pixel.") if self._verbose else None
[docs]
def set_class_labels(
self,
labels: Optional[List[Union[np.ndarray, List[Union[int, str, float]]]]] = None,
path_labels: Optional[Union[str, List[str]]] = None,
n_classes: Optional[int] = None,
) -> None:
"""
Imports class labels into the class.
* labels are expected to be formatted as a list of arrays. The outer
list should have length equal to the number of sessions (n_sessions).
Each element in the list is either a 1D array or list of integers or
strings and should have length equal to the number of ROIs in that
session (n_roi). Each element in the array or list is the class label
for the corresponding ROI and can be a number or a string.
List[Union[np.ndarray, List[Union[int, str, float]]]]. \n
Args:
labels (Optional[List[Union[np.ndarray, List[Union[int, str, float]]]]]): \n
* If ``None``: path_labels must be specified.
* Else: ``labels`` are expected to be a list of arrays. The
outer list should have length equal to the number of sessions
(``n_sessions``). Each element in the list is either a 1D
array or list of integers or strings and should have length
equal to the number of ROIs in that session (``n_roi``). Each
element in the array or list is the class label for the
corresponding ROI and can be a number or a string.
List[Union[np.ndarray, List[Union[int, str, float]]]]. \n
path_labels (Optional[Union[str, List[str]]]): \n
* If ``None``: labels must be specified.
* Else: ``path_labels`` is expected to be a list of strings. Each
element in the list is a path to a file containing the class
labels. The outer list should have length equal to
the number of sessions (``n_sessions``). Each file should be a
json file containing a list of integers or strings corresponding
to the class labels for each ROI in that session.
n_classes (Optional[int]):
Number of classes. If not provided, it will be inferred from the
class labels. (Default is ``None``)
"""
## Store parameter (but not data) args as attributes
self.params['set_class_labels'] = self._locals_to_params(
locals_dict=locals(),
keys=[
'path_labels',
'n_classes',
],
)
print(f"Starting: Importing class labels") if self._verbose else None
## Check inputs
if labels is not None:
assert isinstance(labels, list), f"labels should be a list. It is a {type(labels)}"
## make sure all elements are numpy arrays or lists
assert all([isinstance(l, (np.ndarray, list)) for l in labels]), f"labels should be a list of numpy arrays or lists of integers or strings. First element of list is of type {type(labels[0])}"
## convert lists to numpy arrays
labels_raw = [np.array(l, dtype=str) if isinstance(l, list) else l for l in labels]
## make sure all elements are 1D
assert all([l.ndim==1 for l in labels_raw]), f"labels should be a list of 1D numpy arrays or lists of integers or strings. First element of list is of shape {labels[0].shape}"
elif path_labels is not None:
## It should be a csv file (or list of files) with the first column as the index and second column as the class label
assert isinstance(path_labels, (str, list)), f"path_labels should be a string or a list of strings. It is a {type(path_labels)}"
if isinstance(path_labels, str):
path_labels = [path_labels]
## make sure all elements are strings
assert all([isinstance(l, str) for l in path_labels]), f"path_labels should be a list of strings. First element of list is of type {type(path_labels[0])}"
## make sure all files exist
assert all([Path(l).exists() for l in path_labels]), f"Files in path_labels do not exist. Please check the paths."
## make sure all files are json files containing lists of integers or strings or floats
assert all([Path(l).suffix in ['.json',] for l in path_labels]), f"Files in path_labels should be json files. Please check the file extensions."
## load the files
labels_raw = [np.array(helpers.json_load(l), dtype=str) for l in path_labels]
else:
raise ValueError(f"Either labels or path_labels must be specified.")
## convert lists to numpy arrays of unique integers
labels_cat = np.concatenate(labels_raw, axis=0)
unique_class_labels, labels_cat_squeezeInt = np.unique(labels_cat, return_inverse=True)
n_classes = len(unique_class_labels)
n_class_labels = [lbls.shape[0] for lbls in labels_raw]
n_class_labels_total = sum(n_class_labels)
n_sessions = len(labels_raw)
class_labels_squeezeInt = util.labels_to_labelsBySession(labels=labels_cat_squeezeInt, n_roi_bySession=n_class_labels)
## Set attributes
self.class_labels_raw = labels_raw
self.class_labels_index = class_labels_squeezeInt
self.n_classes = n_classes
self.n_class_labels = n_class_labels
self.n_class_labels_total = n_class_labels_total
self.unique_class_labels = unique_class_labels
## Check if label data shapes match ROI_image data shapes
self._checkValidity_classLabels_vs_ROIImages()
print(f"Completed: Imported labels for {n_sessions} sessions. Each session has {n_class_labels} class labels. Total number of class labels is {n_class_labels_total}.") if self._verbose else None
@classmethod
def _fix_um_per_pixel(
cls,
um_per_pixel: Union[float, List[float]],
n_sessions: int,
) -> List[float]:
if isinstance(um_per_pixel, float):
um_per_pixel = [um_per_pixel,] * n_sessions
elif isinstance(um_per_pixel, list):
assert all([isinstance(ump, float) for ump in um_per_pixel]), f"um_per_pixel should be a float or a list of floats. First element of list is of type {type(um_per_pixel[0])}"
um_per_pixel = [float(ump) for ump in um_per_pixel]
else:
raise ValueError(f"um_per_pixel should be a float or a list of floats. It is a {type(um_per_pixel)}")
assert len(um_per_pixel) == n_sessions, f"um_per_pixel should be a float or a list of floats of length equal to the number of sessions. It is of length {len(um_per_pixel)} and there are {n_sessions} sessions."
return um_per_pixel
@classmethod
def _fix_ROI_images(
cls,
ROI_images: List[np.ndarray],
) -> List[np.ndarray]:
### Check ROI_images
if isinstance(ROI_images, np.ndarray):
print("RH WARNING: ROI_images is a numpy array. Assuming n_sessions==1 and wrapping array in a list.")
ROI_images = [ROI_images,]
assert isinstance(ROI_images, list), f"ROI_images should be a list. It is a {type(ROI_images)}"
assert all([isinstance(roi, np.ndarray) for roi in ROI_images]), f"ROI_images should be a list of numpy arrays. First element of list is of type {type(ROI_images[0])}"
assert all([roi.ndim==3 for roi in ROI_images]), f"ROI_images should be a list of numpy arrays of shape (n_roi, height, width). First element of list is of shape {ROI_images[0].shape}"
### Assert that all the ROI heights and widths are the same
assert all([tuple(roi.shape[1:]) == tuple(ROI_images[0].shape[1:]) for roi in ROI_images]), f"ROI_images should be a list of numpy arrays of shape (n_roi, height, width). All elements should have the same height and width as the first element."
return ROI_images
def _checkValidity_classLabels_vs_ROIImages(
self,
verbose: Optional[bool] = None,
) -> None:
"""
Checks that the class labels and the ROI images have the same number of
sessions and the same number of ROIs in each session.
Args:
verbose (Optional[bool]):
If ``None``, the verbosity level set in the class is used.
(Default is ``None``)
"""
if verbose is None:
verbose = self._verbose
## Check if class_labels and ROI_images exist
if not (hasattr(self, 'class_labels_index') and hasattr(self, 'ROI_images')):
print("Cannot check validity of class_labels_index and ROI_images because one or both do not exist as attributes.") if verbose else None
return False
## Check num sessions
n_sessions_classLabels = len(self.class_labels_index)
n_sessions_ROIImages = len(self.ROI_images)
assert n_sessions_classLabels == n_sessions_ROIImages, f"RH ERROR: Number of sessions (list elements) in class_labels_index ({n_sessions_classLabels}) does not match number of sessions (list elements) in ROI_images ({n_sessions_ROIImages})."
## Check num ROIs
n_ROIs_classLabels = [lbls.shape[0] for lbls in self.class_labels_index]
n_ROIs_ROIImages = [img.shape[0] for img in self.ROI_images]
assert all([l == r for l, r in zip(n_ROIs_classLabels, n_ROIs_ROIImages)]), f"RH ERROR: Number of ROIs in each session in class_labels_index ({n_ROIs_classLabels}) does not match number of ROIs in each session in ROI_images ({n_ROIs_ROIImages})."
print(f"Labels and ROI Images match in shapes: Class labels and ROI images have the same number of sessions and the same number of ROIs in each session.") if verbose else None
return True
#########################################################
#################### TRACKING ###########################
#########################################################
[docs]
def set_FOV_images(
self,
FOV_images: List[np.ndarray],
):
"""
Sets the **FOV_images** attribute.
Args:
FOV_images (List[np.ndarray]):
List of 2D **numpy.ndarray** objects, one for each session. Each
array should have shape *(FOV_height, FOV_width)*.
"""
## Store parameter (but not data) args as attributes
### Nothing to store
if isinstance(FOV_images, np.ndarray):
assert FOV_images.ndim == 3, f"RH ERROR: FOV_images must be a list of 2D numpy arrays."
FOV_images = [fov for fov in FOV_images]
## Check inputs
assert isinstance(FOV_images, list), f"RH ERROR: FOV_images must be a list."
assert all([isinstance(img, np.ndarray) for img in FOV_images]), f"RH ERROR: All elements in FOV_images must be numpy arrays."
assert all([img.ndim == 2 for img in FOV_images]), f"RH ERROR: All elements in FOV_images must be 2D numpy arrays."
assert all([img.shape[0] == FOV_images[0].shape[0] for img in FOV_images]), f"RH ERROR: All elements in FOV_images must have the same height and width."
assert all([img.shape[1] == FOV_images[0].shape[1] for img in FOV_images]), f"RH ERROR: All elements in FOV_images must have the same height and width."
## Set attributes
print("Setting FOV_images...") if self._verbose else None
self.FOV_images = [np.array(f, dtype=np.float32) for f in FOV_images]
self.FOV_height = int(FOV_images[0].shape[0])
self.FOV_width = int(FOV_images[0].shape[1])
## Get some variables
n_sessions = len(FOV_images)
## Check that attributes match if they already exist as an attribute
if hasattr(self, 'n_sessions'):
if self.n_sessions != n_sessions:
warnings.warn(f"RH WARNING: n_sessions is already set to {self.n_sessions} but new value is {n_sessions}")
print(f"Completed: Set FOV_images for {len(FOV_images)} sessions successfully.") if self._verbose else None
[docs]
def set_FOVHeightWidth(
self,
FOV_height: int,
FOV_width: int,
):
"""
Sets the **FOV_height** and **FOV_width** attributes.
Args:
FOV_height (int):
The height of the field of view (FOV) in pixels.
FOV_width (int):
The width of the field of view (FOV) in pixels.
"""
## Store parameter (but not data) args as attributes
self.params['set_FOVHeightWidth'] = self._locals_to_params(
locals_dict=locals(),
keys=[
'FOV_height',
'FOV_width',
],
)
## Check inputs
assert isinstance(FOV_height, int), f"RH ERROR: FOV_height must be an integer."
assert isinstance(FOV_width, int), f"RH ERROR: FOV_width must be an integer."
## Set attributes
self.FOV_height = FOV_height
self.FOV_width = FOV_width
print(f"Completed: Set FOV_height and FOV_width successfully.") if self._verbose else None
def _checkValidity_spatialFootprints_and_FOVImages(
self,
verbose: Optional[bool] = None,
):
"""
Checks that **spatialFootprints** and **FOV_images** are compatible.
Args:
verbose (Optional[bool]):
If ``True``, outputs progress and error messages.
(Default is ``None``)
"""
if verbose is None:
verbose = self._verbose
if hasattr(self, 'spatialFootprints') and hasattr(self, 'FOV_images'):
assert len(self.spatialFootprints) == len(self.FOV_images), f"RH ERROR: spatialFootprints and FOV_images must have the same length."
assert all([sf.shape[1] == self.FOV_images[0].size for sf in self.spatialFootprints]), f"RH ERROR: spatialFootprints and FOV_images must have the same size."
print("Completed: spatialFootprints and FOV_images are compatible.") if verbose else None
return True
else:
print("Cannot check validity of spatialFootprints and FOV_images because one or both do not exist as attributes.") if verbose else None
return False
[docs]
def check_completeness(
self,
verbose: bool = True
) -> None:
"""
Checks which pipelines the data object is capable of running given the
attributes that have been set.
Args:
verbose (bool):
If ``True``, outputs progress and error messages. (Default is
``True``)
"""
completeness = {}
keys_classification_inference = ['ROI_images', 'um_per_pixel']
keys_classification_training = ['ROI_images', 'um_per_pixel', 'class_labels_index']
keys_tracking = ['ROI_images', 'um_per_pixel', 'spatialFootprints', 'FOV_images']
## Check classification inference:
### ROI_images, um_per_pixel
if all([hasattr(self, key) for key in keys_classification_inference]):
completeness['classification_inference'] = True
else:
print(f"RH WARNING: Classification-Inference incomplete because following attributes are missing: {[key for key in keys_classification_inference if not hasattr(self, key)]}") if verbose else None
completeness['classification_inference'] = False
## Check classification training:
### ROI_images, um_per_pixel, class_labels_index
if all([hasattr(self, key) for key in keys_classification_training]):
completeness['classification_training'] = True
else:
print(f"RH WARNING: Classification-Training incomplete because the following attributes are missing: {[key for key in keys_classification_training if not hasattr(self, key)]}") if verbose else None
completeness['classification_training'] = False
## Check tracking:
### um_per_pixel, spatialFootprints, FOV_images
if all([hasattr(self, key) for key in keys_tracking]):
completeness['tracking'] = True
else:
print(f"RH WARNING: Tracking incomplete because the following attributes are missing: {[key for key in keys_tracking if not hasattr(self, key)]}") if verbose else None
completeness['tracking'] = False
self._checkValidity_classLabels_vs_ROIImages(verbose=verbose)
self._checkValidity_spatialFootprints_and_FOVImages(verbose=verbose)
## Print completeness
print(f"Data_roicat object completeness: {completeness}") if verbose else None
return completeness
def _make_session_bool(self) -> np.ndarray:
"""
Creates a boolean array where each row is a boolean vector indicating
which session(s) the ROI was present in. Uses the ``self.n_roi``
attribute to determine which rows belong to which session.
Returns:
np.ndarray:
self.session_bool (np.ndarray):
A boolean array where each row is a boolean vector
indicating which session(s) the ROI was present in. Shape:
*(n_roi_total, n_sessions)*
"""
## Check that n_roi is set
assert hasattr(self, 'n_roi'), f"RH ERROR: n_roi must be set before session_bool can be created."
## Check that n_roi is the correct length
assert len(self.n_roi) == self.n_sessions, f"RH ERROR: n_roi must be the same length as n_sessions."
## Check that n_roi_total is correct
assert sum(self.n_roi) == self.n_roi_total, f"RH ERROR: n_roi must sum to n_roi_total."
## Create session_bool
self.session_bool = util.make_session_bool(self.n_roi)
print(f"Completed: Created session_bool.") if self._verbose else None
return self.session_bool
def _make_spatialFootprintCentroids(
self,
method: str = 'centerOfMass'
) -> np.ndarray:
"""
Calculates the centroids of a sparse array of flattened spatial
footprints. The centroid position is calculated as the center of mass of
the ROI.
JZ, RH 2022
Args:
method (str):
Method to use to calculate the centroid. Either \n
* ``'centerOfMass'``: Calculates the centroid position as the
mean center of mass of the ROI.
* ``'median'``: Calculates the centroid position as the median
center of mass of the ROI. \n
(Default is ``'centerOfMass'``)
Returns:
(np.ndarray):
centroids (np.ndarray):
Centroids of the ROIs with shape *(2, n_roi)*. Consists of
(y, x) coordinates.
"""
## Check that sf is a list of csr sparse arrays
assert isinstance(self.spatialFootprints, list), f"RH ERROR: spatialFootprints must be a list of scipy.sparse.csr_array."
assert all([scipy.sparse.issparse(sf) and sf.format == 'csr' for sf in self.spatialFootprints]), f"RH ERROR: spatialFootprints must be a list of scipy.sparse CSR arrays."
## Check that FOV_height and FOV_width are set
assert hasattr(self, 'FOV_height') and hasattr(self, 'FOV_width'), f"RH ERROR: FOV_height and FOV_width must be set before centroids can be calculated."
## Check that sf is the correct shape
assert all([sf.shape[1] == self.FOV_height*self.FOV_width for sf in self.spatialFootprints]), f"RH ERROR: spatialFootprints must have shape (n_roi, FOV_height*FOV_width)."
## Check that centroid_method is set
assert method in ['centerOfMass', 'median'], f"RH ERROR: centroid_method must be one of ['centerOfMass', 'median']."
## Calculate centroids
sf = self.spatialFootprints
FOV_height, FOV_width = self.FOV_height, self.FOV_width
## Reshape sf to (n_roi, FOV_height, FOV_width)
sf_rs = [sparse.COO(s).reshape((s.shape[0], FOV_height, FOV_width), order='C') for s in sf]
## Calculate the sum of the weights along each axis
y_w, x_w = [s.sum(axis=2) for s in sf_rs], [s.sum(axis=1) for s in sf_rs]
## Calculate the centroids
if method == 'centerOfMass':
y_cent = [(((w*np.arange(w.shape[1]).reshape(1,-1))).sum(1)/(w.sum(1)+1e-12)).todense() for w in y_w]
x_cent = [(((w*np.arange(w.shape[1]).reshape(1,-1))).sum(1)/(w.sum(1)+1e-12)).todense() for w in x_w]
elif method == 'median':
y_cent = [((((w!=0)*np.arange(w.shape[1]).reshape(1,-1, order='C'))).todense()).astype(np.float32) for w in y_w]
y_cent = [np.ma.masked_array(w, mask=(w==0)).filled(np.nan) for w in y_cent]
y_cent = [np.nanmedian(w, axis=1) for w in y_cent]
x_cent = [((((w!=0)*np.arange(w.shape[1]).reshape(1,-1, order='C'))).todense()).astype(np.float32) for w in x_w]
x_cent = [np.ma.masked_array(w, mask=(w==0)).filled(np.nan) for w in x_cent]
x_cent = [np.nanmedian(w, axis=1) for w in x_cent]
## Round to nearest integer
y_cent = [np.round(h) for h in y_cent]
x_cent = [np.round(w) for w in x_cent]
## Concatenate and store
self.centroids = [np.stack([y, x], axis=1).astype(np.int64) for y, x in zip(y_cent, x_cent)]
print(f"Completed: Created centroids.") if self._verbose else None
[docs]
def remove_rois_by_classLabel(
self,
classLabel_to_keep: Optional[Union[int, List[int]]] = None,
classLabel_to_remove: Optional[Union[int, List[int]]] = None,
in_place: bool = True,
verbose: Optional[bool] = None,
):
"""
Removes ROIs based on their class label. Remakes all attributes that are
affected by the removal of
ROIs. This includes:
* spatialFootprints
* ROI_images
* centroids
* class_labels_raw
* class_labels_index
* session_bool
* all attributes related to above attributes
Args:
classLabel_to_keep (Optional[Union[int, List[int]]]):
Class label(s) to keep. If ``None``, ``classLabel_to_remove``
must be provided. The values should correspond to the values
seen in the ``self.class_labels_raw`` attribute.
classLabel_to_remove (Optional[Union[int, List[int]]]):
Class label(s) to remove. If ``None``, ``classLabel_to_keep`` must
be provided. The values should correspond to the values seen in
the ``self.class_labels_raw`` attribute.
in_place (bool):
If ``True``, the object is modified in place. A new object is
returned with the ROIs removed either way.
verbose (Optional[bool]):
Whether to print progress messages. If ``None``, the verbosity
level set in the class is used.
Returns:
self (Data_roicat):
The object with the ROIs removed.
"""
if verbose is None:
verbose = self._verbose
## Check that class labels are set
assert hasattr(self, 'class_labels_raw'), f"RH ERROR: class_labels_raw must be set before ROIs can be removed by class label. Use set_class_labels() to set class_labels_raw."
assert hasattr(self, 'unique_class_labels'), f"RH ERROR: unique_class_labels must be set before ROIs can be removed by class label. Use set_class_labels() to set unique_class_labels."
## Check that classLabel_to_keep or classLabel_to_remove is provided
assert (classLabel_to_keep is not None) or (classLabel_to_remove is not None), f"RH ERROR: Either classLabel_to_keep or classLabel_to_remove must be provided."
## Check that only one of classLabel_to_keep or classLabel_to_remove is provided
assert (classLabel_to_keep is None) or (classLabel_to_remove is None), f"RH ERROR: Only one of classLabel_to_keep or classLabel_to_remove can be provided."
## Check that classLabel_to_keep or classLabel_to_remove is an int or a list of ints
if isinstance(classLabel_to_keep, np.ndarray):
classLabel_to_keep = classLabel_to_keep.tolist()
if isinstance(classLabel_to_remove, np.ndarray):
classLabel_to_remove = classLabel_to_remove.tolist()
if isinstance(classLabel_to_keep, int):
classLabel_to_keep = [classLabel_to_keep,]
if isinstance(classLabel_to_remove, int):
classLabel_to_remove = [classLabel_to_remove,]
assert (classLabel_to_keep is None) or (isinstance(classLabel_to_keep, int) or (isinstance(classLabel_to_keep, list) and all([isinstance(cl, int) for cl in classLabel_to_keep]))), f"RH ERROR: classLabel_to_keep must be an int or a list of ints."
assert (classLabel_to_remove is None) or (isinstance(classLabel_to_remove, int) or (isinstance(classLabel_to_remove, list) and all([isinstance(cl, int) for cl in classLabel_to_remove]))), f"RH ERROR: classLabel_to_remove must be an int or a list of ints."
print(f"Starting: Removing ROIs based on class labels...") if verbose else None
## Get the class labels to keep
if classLabel_to_remove is not None:
classLabel_to_keep = [cl for cl in self.unique_class_labels if cl not in classLabel_to_remove]
print(f"Converted classLabel_to_remove: {classLabel_to_remove} to classLabel_to_keep: {classLabel_to_keep}.") if verbose else None
## Get the indices of the class labels to keep for each session
idx_keep = [np.where(np.isin(cl, classLabel_to_keep))[0] for cl in self.class_labels_raw]
## Make new object
print(f"Making new Data_roicat object with ROIs removed based on class labels...") if verbose else None
data_new = Data_roicat(verbose=verbose)
## Remove the ROIs
if hasattr(self, 'spatialFootprints'):
print(f"Removing ROIs from spatialFootprints...") if verbose else None
sf = [sf[idx_keep[ii]] for ii, sf in enumerate(self.spatialFootprints)]
data_new.set_spatialFootprints(
spatialFootprints=sf,
um_per_pixel=self.um_per_pixel,
)
if hasattr(self, 'centroids'):
print(f"Removing ROIs from centroids...") if verbose else None
c = [c[idx_keep[ii]] for ii, c in enumerate(self.centroids)]
data_new.centroids = c
print(f"Centroids removed.") if verbose else None
if hasattr(self, 'ROI_images'):
print(f"Removing ROIs from ROI_images...") if verbose else None
ri = [ri[idx_keep[ii]] for ii, ri in enumerate(self.ROI_images)]
data_new.set_ROI_images(
ROI_images=ri,
um_per_pixel=self.um_per_pixel,
)
if hasattr(self, 'class_labels_raw'):
print(f"Removing ROIs from class_labels...") if verbose else None
cl = [cl[idx_keep[ii]] for ii, cl in enumerate(self.class_labels_raw)]
data_new.set_class_labels(
labels=cl,
n_classes=self.n_classes,
)
if hasattr(self, 'session_bool'):
print(f"Recomputing session_bool...") if verbose else None
data_new._make_session_bool()
print(f"Completed: Removed ROIs based on class labels. New object created. Old n_roi_total={self.n_roi_total}, new n_roi_total={data_new.n_roi_total}. Old unique_class_labels={self.unique_class_labels}, new unique_class_labels={data_new.unique_class_labels}.") if verbose else None
if in_place:
## Replace self with new object
### Check to see if there are any attributes in the old object that are not in the new object
keys_old = set(self.__dict__.keys())
keys_new = set(data_new.__dict__.keys())
keys_missing = keys_old - keys_new
keys_extra = keys_new - keys_old
keys_in_both = keys_old.intersection(keys_new)
## Print old keys
print(f"Existing attributes that will persist in new data object: {keys_missing}.") if verbose else None
## Print intersection keys
print(f"Existing attributes that will be replaced in new data object: {keys_in_both}.") if verbose else None
if len(keys_extra) > 0:
warnings.warn(f"RH WARNING: The following attributes are in the new data object but not in the old data object: {keys_extra}. This is unexpected.")
self.__dict__.update(data_new.__dict__)
print(f"Performed in-place replacement of self with new data object.") if verbose else None
return self
return data_new
def __repr__(self):
n_sessions = len(self.spatialFootprints) if hasattr(self, 'spatialFootprints') and self.spatialFootprints is not None else 0
n_roi = self.n_roi if hasattr(self, 'n_roi') and self.n_roi is not None else []
total_roi = sum(n_roi) if n_roi else 0
return (
f"{self.__class__.__name__}("
f"n_sessions={n_sessions}, "
f"n_roi_total={total_roi}, "
f"n_roi_per_session={n_roi})"
)
[docs]
def import_from_dict(
    self,
    dict_load: Dict[str, Any],
) -> None:
    """
    Restores object state from a (serializable) dictionary.

    Keys that are consumed by one of the dedicated setter methods
    (``set_ROI_images``, ``set_spatialFootprints``, ``set_FOV_images``,
    ``set_class_labels``) are routed through that setter; every other key is
    assigned directly as an attribute with ``setattr``.

    Args:
        dict_load (Dict[str, Any]):
            Mapping of attribute / argument names to values.

    Note:
        Returns nothing; mutates ``self``. A setter is only called when ALL
        of its required keys are present in **dict_load**; otherwise a
        warning is printed (if ``self._verbose``) and that setter is skipped.
    """
    ## (setter, {setter_kwarg_name: key_in_dict_load}), applied in this order
    setter_specs = [
        (self.set_ROI_images, {'ROI_images': 'ROI_images', 'um_per_pixel': 'um_per_pixel'}),
        (self.set_spatialFootprints, {'spatialFootprints': 'spatialFootprints', 'um_per_pixel': 'um_per_pixel'}),
        (self.set_FOV_images, {'FOV_images': 'FOV_images'}),
        (self.set_class_labels, {'labels': 'class_labels_raw'}),
    ]
    keys_for_setters = {
        key
        for _, kwarg_to_key in setter_specs
        for key in kwarg_to_key.values()
    }

    ## Plain attributes: everything the setters do not handle
    for name, value in dict_load.items():
        if name in keys_for_setters:
            continue
        setattr(self, name, value)

    ## Setter-managed attributes
    for setter, kwarg_to_key in setter_specs:
        has_all_keys = all(key in dict_load for key in kwarg_to_key.values())
        if not has_all_keys:
            if self._verbose:
                print(f"RH WARNING: Could not load attribute using method {setter.__name__}. Keys {kwarg_to_key.values()} not found in dict_load.")
            continue
        kwargs = {kwarg: dict_load[key] for kwarg, key in kwarg_to_key.items()}
        setter(**kwargs)
############################################################################################################################
############################## CUSTOM CLASSES FOR SUITE2P AND CAIMAN OUTPUT FILES ##########################################
############################################################################################################################
#########################################################
#################### DATA S2P ###########################
#########################################################
[docs]
class Data_suite2p(Data_roicat):
    """
    Class for handling suite2p output files and data. In particular stat.npy and
    ops.npy files. Imports FOV images and spatial footprints, and prepares ROI
    images.
    RH 2022

    Args:
        paths_statFiles (str, pathlib.Path, or list of str / pathlib.Path):
            Path(s) to the stat.npy files, one per session.
        paths_opsFiles (str, pathlib.Path, or list of str / pathlib.Path, optional):
            Path(s) to the ops.npy files, one per session. Used to get
            FOV_images, FOV_height, FOV_width, and shifts (if old matlab ops
            file). Exactly one of **paths_opsFiles**, **FOV_images** or
            **FOV_height_width** must be provided.
        um_per_pixel (Union[float, List[float]]):
            Resolution in micrometers per pixel of the imaging field of view.
            The conversion factor from pixels to microns. This is used to scale
            the ROI_images to a common size. Should either be a float or a list
            of floats, one for each session.
        new_or_old_suite2p (str):
            Type of suite2p output files. Matlab=old, Python=new. Should be:
            ``'new'`` or ``'old'``. ``'old'`` requires **paths_opsFiles** to
            compute the index offsets; without them zero offsets are used and a
            warning is issued.
        out_height_width (tuple of int):
            Height and width of output ROI images. These are the little images
            of centered ROIs that are typically used for passing through the
            neural net. Unless your ROIs are larger than the default size, it's
            best to just leave it as default. Should be: *(int, int)* *(y, x)*.
        type_meanImg (str):
            Type of mean image to use. Should be: ``'meanImgE'`` or
            ``'meanImg'``.
        FOV_images (np.ndarray, optional):
            FOV images. Array of shape *(n_sessions, FOV_height, FOV_width)*.
            Optional.
        centroid_method (str):
            Method for calculating the centroid of an ROI. Should be:
            ``'centerOfMass'`` or ``'median'``.
        class_labels ((list of np.ndarray) or (list of str to paths) or None):
            Optional. If ``None``, class labels are not set. If list of
            np.ndarray, each element should be a 1D integer array of length
            n_roi specifying the class label for each ROI. If list of str, each
            element should be a path to a .npy file containing an array of
            length n_roi specifying the class label for each ROI.
        paths_iscell (str or pathlib.Path or list of str or list of pathlib.Path):
            Optional. Paths to the iscell.npy files. Only used when
            **class_labels** is ``None``. If provided, column 0 (the iscell
            boolean) of each iscell.npy file is used as the class labels. An
            iscell.npy file is assumed to be a 2D numpy array of shape
            *(n_ROIs, (iscell boolean, probability float))*.
        FOV_height_width (tuple of int, optional):
            FOV height and width as *(height, width)*. A list of two integers
            (python or numpy) is also accepted.
        verbose (bool):
            If ``True``, prints results from each function.
    """
    def __init__(
        self,
        paths_statFiles: Union[str, pathlib.Path, List[Union[str, pathlib.Path]]],
        paths_opsFiles: Optional[Union[str, pathlib.Path, List[Union[str, pathlib.Path]]]] = None,
        um_per_pixel: Union[float, List[float]] = 1.0,
        new_or_old_suite2p: str = 'new',
        out_height_width: Tuple[int, int] = (36, 36),
        type_meanImg: str = 'meanImgE',
        FOV_images: Optional[np.ndarray] = None,
        centroid_method: str = 'centerOfMass',
        class_labels: Optional[Union[List[np.ndarray], List[str], None]] = None,
        paths_iscell: Optional[Union[str, pathlib.Path, List[Union[str, pathlib.Path]]]] = None,
        FOV_height_width: Optional[Tuple[int, int]] = None,
        verbose: bool = True,
    ):
        """
        Initialize the Data_suite2p object. See the class docstring for the
        description of each argument.
        """
        ## Inherit from Data_roicat
        super().__init__()

        self.paths_stat = fix_paths(paths_statFiles)
        self.paths_ops = fix_paths(paths_opsFiles) if paths_opsFiles is not None else None
        self.n_sessions = len(self.paths_stat)

        ## Store parameter (but not data) args as attributes
        self.params['__init__'] = self._locals_to_params(
            locals_dict=locals(),
            keys=[
                'paths_statFiles',
                'paths_opsFiles',
                'um_per_pixel',
                'new_or_old_suite2p',
                'out_height_width',
                'type_meanImg',
                'centroid_method',
                'paths_iscell',
                'FOV_height_width',
                'verbose',
            ],
        )

        self._verbose = verbose

        ## shifts are applied to convert the 'old' matlab version of suite2p indexing (where there is an offset and its 1-indexed)
        self.shifts = self._make_shifts(paths_ops=self.paths_ops, new_or_old_suite2p=new_or_old_suite2p)

        ## Import FOV images
        ### Exactly one source of FOV geometry is allowed
        assert sum([self.paths_ops is not None, FOV_images is not None, FOV_height_width is not None]) == 1, "RH ERROR: One (and only one) of self.paths_ops, FOV_images, or FOV_height_width must be provided."
        if self.paths_ops is not None:
            ### import_FOV_images also sets FOV height and width
            FOV_images = self.import_FOV_images(type_meanImg=type_meanImg)
        elif FOV_height_width is not None:
            assert isinstance(FOV_height_width, (tuple, list)), "RH ERROR: FOV_height_width must be a tuple of length 2."
            assert len(FOV_height_width) == 2, "RH ERROR: FOV_height_width must be a tuple of length 2."
            assert all(isinstance(x, (int, np.integer)) for x in FOV_height_width), "RH ERROR: FOV_height_width must be a tuple of length 2 of integers."
            self.set_FOVHeightWidth(FOV_height=int(FOV_height_width[0]), FOV_width=int(FOV_height_width[1]))
        if FOV_images is not None:
            self.set_FOV_images(FOV_images=FOV_images)

        ## Import spatial footprints
        # NOTE(review): import_spatialFootprints is not defined in this class
        # body; presumably provided elsewhere (e.g. the superclass) — confirm.
        spatialFootprints = self.import_spatialFootprints()
        self.set_spatialFootprints(spatialFootprints=spatialFootprints, um_per_pixel=um_per_pixel)

        ## Make spatial footprint centroids
        self._make_spatialFootprintCentroids(method=centroid_method)

        ## Transform spatial footprints to ROI images
        self.transform_spatialFootprints_to_ROIImages(out_height_width=out_height_width)

        ## Make class labels (explicit class_labels take precedence over iscell files)
        if class_labels is not None:
            self.set_class_labels(labels=class_labels)
        elif paths_iscell is not None:
            ### Column 0 of iscell.npy is the iscell boolean
            self.set_class_labels(labels=[np.load(path)[:, 0].astype(np.int64) for path in fix_paths(paths_iscell)])

    def import_FOV_images(
        self,
        type_meanImg: str = 'meanImgE',
    ) -> np.ndarray:
        """
        Imports the FOV images from the ops files in ``self.paths_ops`` and
        sets ``FOV_height`` / ``FOV_width`` from them.

        Args:
            type_meanImg (str):
                Type of the mean image. References the key in the ops.npy file.
                Options are: \n
                    * ``'meanImgE'``: Enhanced mean image.
                    * ``'meanImg'``: Mean image.

        Returns:
            (np.ndarray):
                FOV_images (np.ndarray):
                    Stacked float32 array of shape *(n_sessions, height,
                    width)*, one image per path in ``self.paths_ops``.

        Raises:
            AssertionError: if ``self.paths_ops`` is missing/empty, a path
                does not exist, or the images differ in height or width.
        """
        print(f"Starting: Importing FOV images from ops files") if self._verbose else None

        assert self.paths_ops is not None, "RH ERROR: paths_ops is None. Please set paths_ops before calling this function."
        assert len(self.paths_ops) > 0, "RH ERROR: paths_ops is empty. Please set paths_ops before calling this function."
        assert all([Path(path).exists() for path in self.paths_ops]), f"RH ERROR: One or more paths in paths_ops do not exist: {[path for path in self.paths_ops if not Path(path).exists()]}"

        ## ops.npy is a pickled dict, so allow_pickle is required; only load trusted files.
        FOV_images = [np.load(path, allow_pickle=True)[()][type_meanImg] for path in self.paths_ops]

        height_0, width_0 = FOV_images[0].shape[0], FOV_images[0].shape[1]
        assert all(im.shape[0] == height_0 for im in FOV_images[1:]), f"RH ERROR: FOV images are not all the same height. Shapes: {[FOV_image.shape for FOV_image in FOV_images]}"
        assert all(im.shape[1] == width_0 for im in FOV_images[1:]), f"RH ERROR: FOV images are not all the same width. Shapes: {[FOV_image.shape for FOV_image in FOV_images]}"

        FOV_images = np.stack(FOV_images, axis=0).astype(np.float32)

        self.set_FOVHeightWidth(FOV_height=FOV_images[0].shape[0], FOV_width=FOV_images[0].shape[1])

        print(f"Completed: Imported {len(FOV_images)} FOV images.") if self._verbose else None

        return FOV_images

    def import_neuropil_masks(
        self,
        frame_height_width: Optional[Union[List[int], Tuple[int, int]]] = None,
    ) -> List[scipy.sparse.csr_array]:
        """
        Imports and converts the neuropil masks of the ROIs in the stat files
        into images in sparse arrays. The result is also stored as
        ``self.neuropilMasks``.

        Args:
            frame_height_width (Optional[Union[List[int], Tuple[int, int]]]):
                The *height* and *width* of the frame in the form *[height,
                width]*. If ``None``, ``self.FOV_height`` and
                ``self.FOV_width`` are used. (Default is ``None``)

        Returns:
            (List[scipy.sparse.csr_array]):
                neuropilMasks (List[scipy.sparse.csr_array]):
                    One boolean sparse array per session (same order as
                    ``self.paths_stat``), of shape *(n_roi, frame_height *
                    frame_width)* — each row is a flattened (C-order) mask.
        """
        print("Importing neuropil masks from stat files.") if self._verbose else None

        ## Check and fix inputs
        if frame_height_width is None:
            frame_height_width = [self.FOV_height, self.FOV_width]
        assert self.paths_stat is not None, "RH ERROR: paths_stat is None. Please set paths_stat before calling this function."
        assert len(self.paths_stat) > 0, "RH ERROR: paths_stat is empty. Please set paths_stat before calling this function."
        assert all([Path(path).exists() for path in self.paths_stat]), f"RH ERROR: One or more paths in paths_stat do not exist: {[path for path in self.paths_stat if not Path(path).exists()]}"
        assert hasattr(self, 'shifts'), "RH ERROR: shifts is not defined. Please call ._make_shifts before calling this function."

        ## stat.npy is a pickled object array, so allow_pickle is required; only load trusted files.
        statFiles = [np.load(path, allow_pickle=True) for path in self.paths_stat]

        neuropilMasks = [
            _transform_statFile_to_neuropilMasks(
                frame_height_width=frame_height_width,
                stat=statFiles[ii],
                shifts=self.shifts[ii],
            ) for ii in tqdm(range(self.n_sessions), disable=not self._verbose)
        ]

        if self._verbose:
            print(f"Imported {len(neuropilMasks)} sessions of neuropil masks into sparse arrays.")
        self.neuropilMasks = neuropilMasks
        return neuropilMasks

    def _make_shifts(
        self,
        paths_ops: Optional[List[str]] = None,
        new_or_old_suite2p: str = 'new',
    ) -> List[np.ndarray]:
        """
        Helper function to make the (y, x) index offsets for the old (matlab)
        suite2p indexing.

        Args:
            paths_ops (list of str, optional):
                List of paths to the ops.npy files. Default is ``None``.
            new_or_old_suite2p (str):
                Type of suite2p output files. Should be: ``'new'`` or ``'old'``.
                Default is ``'new'``.

        Returns:
            (List[np.ndarray]):
                shifts (List[np.ndarray]):
                    One independent uint64 array of shape *(2,)* per session.
                    For ``'old'`` files it is ``[yrange.min()-1,
                    xrange.min()-1]`` from each ops file; otherwise zeros.

        Raises:
            ValueError: if **new_or_old_suite2p** is not ``'new'``/``'old'``.
        """
        if new_or_old_suite2p not in ('new', 'old'):
            raise ValueError(f"RH ERROR: new_or_old_suite2p should be 'new' or 'old'. Got {new_or_old_suite2p}")

        if paths_ops is None:
            if new_or_old_suite2p == 'old':
                warnings.warn("RH WARNING: new_or_old_suite2p='old' but no ops files were provided. The 'old' suite2p index offsets cannot be computed; using zero shifts.")
            ## Independent arrays so that mutating one session's shift does not affect the others
            return [np.zeros(2, dtype=np.uint64) for _ in range(self.n_sessions)]

        if new_or_old_suite2p == 'old':
            ops_all = [np.load(path, allow_pickle=True)[()] for path in paths_ops]
            ## Old (matlab) suite2p is 1-indexed and cropped to yrange/xrange
            return [np.array([op['yrange'].min() - 1, op['xrange'].min() - 1], dtype=np.uint64) for op in ops_all]

        return [np.zeros(2, dtype=np.uint64) for _ in paths_ops]
def _transform_statFile_to_spatialFootprints(
    frame_height_width: Tuple[int, int],
    stat: np.ndarray,
    shifts: Tuple[int, int] = (0, 0),
    dtype: Optional[np.dtype] = None,
    normalize_mask: bool = True,
) -> scipy.sparse.csr_array:
    """
    Populates a sparse array with the spatial footprints from ROIs in a stat
    file.

    Args:
        frame_height_width (Tuple[int, int]):
            Height and width of the frame.
        stat (np.ndarray):
            Iterable of ROI records, each indexable by ``'lam'`` (weights),
            ``'ypix'`` and ``'xpix'`` (pixel coordinates).
        shifts (Tuple[int, int]):
            Offsets added to the y and x coordinates respectively. Default is
            (0, 0).
        dtype (Optional[np.dtype]):
            Data type of the array elements (numpy scalar type, ``np.dtype``
            or string). If ``None``, it is taken from the first ROI's
            (normalized, if **normalize_mask**) weights, so integer weights
            that get normalized yield a float array instead of being truncated.
            Default is ``None``.
        normalize_mask (bool):
            If ``True``, each ROI's weights are divided by their sum. For an
            integer **dtype**, the normalized weights are additionally scaled
            to the dtype's maximum value before casting. Default is ``True``.

    Returns:
        (scipy.sparse.csr_array):
            spatialFootprints (scipy.sparse.csr_array):
                Sparse array of shape *(n_roi, frame_height * frame_width)*
                containing the flattened (C-order) spatial footprints of the
                ROIs. An empty **stat** yields shape *(0, frame_height *
                frame_width)*.
    """
    frame_shape = (int(frame_height_width[0]), int(frame_height_width[1]))
    shift_y, shift_x = int(shifts[0]), int(shifts[1])

    ## Decide on the integer path from the *requested* dtype only; an inferred
    ## dtype is taken from the already-normalized weights below.
    is_int = dtype is not None and np.issubdtype(dtype, np.integer)

    rois_to_stack = []
    for roi in stat:
        lam = np.array(roi['lam'], ndmin=1)
        if is_int:
            if normalize_mask:
                lam = lam / lam.sum() * np.iinfo(dtype).max
            lam = lam.astype(dtype)
        elif normalize_mask:
            lam = lam / lam.sum()
        ## Fix the output dtype from the first ROI so all rows share one dtype
        if dtype is None:
            dtype = lam.dtype

        ypix = np.array(roi['ypix'], dtype=np.uint64, ndmin=1) + shift_y
        xpix = np.array(roi['xpix'], dtype=np.uint64, ndmin=1) + shift_x
        tmp_roi = scipy.sparse.csr_array((lam, (ypix, xpix)), shape=frame_shape, dtype=dtype)
        rois_to_stack.append(tmp_roi.reshape(1, -1))

    if len(rois_to_stack) == 0:
        return scipy.sparse.csr_array(
            (0, frame_shape[0] * frame_shape[1]),
            dtype=dtype if dtype is not None else np.float64,
        )
    return scipy.sparse.vstack(rois_to_stack).tocsr()
def _transform_statFile_to_neuropilMasks(
    frame_height_width: Tuple[int, int],
    stat: np.ndarray,
    shifts: Tuple[int, int] = (0, 0)
) -> scipy.sparse.csr_array:
    """
    Populates a sparse array with the neuropil masks from ROIs in a stat
    file.

    Args:
        frame_height_width (Tuple[int, int]):
            Height and width of the frame.
        stat (np.ndarray):
            Iterable of ROI records, each indexable by ``'neuropil_mask'``:
            flat (C-order) pixel indices into the frame.
        shifts (Tuple[int, int]):
            Offsets added to the y and x coordinates respectively. Default is
            (0, 0).

    Returns:
        (scipy.sparse.csr_array):
            neuropilMasks (scipy.sparse.csr_array):
                Boolean sparse array of shape *(n_roi, frame_height *
                frame_width)* containing the flattened neuropil masks of the
                ROIs. An empty **stat** yields shape *(0, frame_height *
                frame_width)*.
    """
    frame_shape = (int(frame_height_width[0]), int(frame_height_width[1]))
    ## Python ints keep the coordinates integral (intp + uint64 would promote to float64)
    shift_y, shift_x = int(shifts[0]), int(shifts[1])

    rois_to_stack = []
    for roi in stat:
        lam = np.ones(len(roi['neuropil_mask']), dtype=bool)
        ypix, xpix = np.unravel_index(roi['neuropil_mask'], shape=frame_shape, order='C')
        ypix = ypix + shift_y
        xpix = xpix + shift_x
        tmp_roi = scipy.sparse.csr_array((lam, (ypix, xpix)), shape=frame_shape, dtype=bool)
        rois_to_stack.append(tmp_roi.reshape(1, -1))

    if len(rois_to_stack) == 0:
        return scipy.sparse.csr_array((0, frame_shape[0] * frame_shape[1]), dtype=bool)
    return scipy.sparse.vstack(rois_to_stack).tocsr()
#########################################################
################## DATA CAIMAN ##########################
#########################################################
[docs]
class Data_caiman(Data_roicat):
    """
    Class for importing data from CaImAn output files, specifically hdf5 results
    files.

    Args:
        paths_resultsFiles (List[str]):
            List of paths to the results files, one per session.
        include_discarded (bool):
            If ``True``, include ROIs that were discarded by CaImAn. Default is
            ``True``.
        um_per_pixel (Union[float, List[float]]):
            Resolution in micrometers per pixel of the imaging field of view.
            The conversion factor from pixels to microns. This is used to scale
            the ROI_images to a common size. Should either be a float or a list
            of floats, one for each session.
        out_height_width (List[int]):
            Output height and width of the ROI images. Default is [36, 36].
        centroid_method (str):
            Method for calculating the centroid of an ROI. Should be:
            ``'centerOfMass'`` or ``'median'``. Default is ``'median'``.
        verbose (bool):
            If ``True``, print statements will be printed. Default is ``True``.
        class_labels (optional):
            Class labels, passed unchanged to ``self.set_class_labels``.
            Default is ``None`` (class labels are not set).
    """
    def __init__(
        self,
        paths_resultsFiles: List[str],
        include_discarded: bool = True,
        um_per_pixel: Union[float, List[float]] = 1.0,
        out_height_width: List[int] = [36,36],
        centroid_method: str = 'median',
        verbose: bool = True,
        class_labels: Optional[Union[List[np.ndarray], List[str]]] = None,
    ) -> None:
        ## Inherit from Data_roicat
        super().__init__()

        self.paths_resultsFiles = fix_paths(paths_resultsFiles)
        self.n_sessions = len(self.paths_resultsFiles)
        self._verbose = verbose

        ## Store parameter (but not data) args as attributes
        self.params['__init__'] = self._locals_to_params(
            locals_dict=locals(),
            keys=[
                'paths_resultsFiles',
                'include_discarded',
                'um_per_pixel',
                'out_height_width',
                'centroid_method',
                'verbose',
            ],
        )

        ## Spatial footprints (sets n_roi, used by the label/pred setters below)
        # NOTE(review): import_spatialFootprints is not defined in this class
        # body; presumably provided elsewhere (e.g. the superclass) — confirm.
        spatialFootprints = [self.import_spatialFootprints(path, include_discarded=include_discarded) for path in self.paths_resultsFiles]
        self.set_spatialFootprints(spatialFootprints=spatialFootprints, um_per_pixel=um_per_pixel)

        ## Accepted (1) / discarded (0) labels
        overall_caimanLabels = [self.import_overall_caiman_labels(path, include_discarded=include_discarded) for path in self.paths_resultsFiles]
        self.set_caimanLabels(overall_caimanLabels=overall_caimanLabels)

        ## CNN predictions are optional in CaImAn results; only set them if every session has them
        cnn_caimanPreds = [self.import_cnn_caiman_preds(path, include_discarded=include_discarded) for path in self.paths_resultsFiles]
        n_missing_preds = sum(p is None for p in cnn_caimanPreds)
        if len(cnn_caimanPreds) > 0 and n_missing_preds == 0:
            self.set_caimanPreds(cnn_caimanPreds=cnn_caimanPreds)
        elif 0 < n_missing_preds < len(cnn_caimanPreds):
            warnings.warn(f"RH WARNING: CNN preds are missing for {n_missing_preds} of {len(cnn_caimanPreds)} sessions. cnn_caimanPreds will not be set.")

        FOV_images = self.import_FOV_images(self.paths_resultsFiles)
        self.set_FOV_images(FOV_images=FOV_images)

        self._make_spatialFootprintCentroids(method=centroid_method)
        self.transform_spatialFootprints_to_ROIImages(out_height_width=out_height_width)

        if class_labels is not None:
            self.set_class_labels(labels=class_labels)

    def set_caimanLabels(self, overall_caimanLabels: List[List[bool]]) -> None:
        """
        Sets the CaImAn labels (stored as ``self.cnn_caimanLabels``).

        Args:
            overall_caimanLabels (List[List[bool]]):
                List of lists of CaImAn labels.
                The outer list corresponds to sessions, and the inner list corresponds to ROIs.

        Raises:
            AssertionError: if the number of sessions or per-session ROI
                counts do not match ``self.n_sessions`` / ``self.n_roi``.
        """
        assert len(overall_caimanLabels) == self.n_sessions, f"{len(overall_caimanLabels)} != {self.n_sessions}"
        ## Count per session and then add, so sessions with different ROI counts work
        n_kept = int(sum(np.sum(labels) for labels in overall_caimanLabels))
        n_total = int(sum(len(labels) for labels in overall_caimanLabels))
        print('kept labels', n_kept, n_total - n_kept) if self._verbose else None
        assert all([len(overall_caimanLabels[i]) == self.n_roi[i] for i in range(self.n_sessions)]), f"{[len(labels) for labels in overall_caimanLabels]} != {[self.n_roi[i] for i in range(self.n_sessions)]}"
        self.cnn_caimanLabels = overall_caimanLabels

    def set_caimanPreds(self, cnn_caimanPreds: List[List[bool]]) -> None:
        """
        Sets the CNN-CaImAn predictions.

        Args:
            cnn_caimanPreds (List[List[bool]]):
                List of lists of CNN-CaImAn predictions. The outer list
                corresponds to sessions, and the inner list corresponds to ROIs.
        """
        assert len(cnn_caimanPreds) == self.n_sessions, f"{len(cnn_caimanPreds)} != {self.n_sessions}"
        assert all([len(cnn_caimanPreds[i]) == self.n_roi[i] for i in range(self.n_sessions)]), f"{[len(cnn_caimanPreds[i]) for i in range(self.n_sessions)]} != {[self.n_roi[i] for i in range(self.n_sessions)]}"
        self.cnn_caimanPreds = cnn_caimanPreds

    def import_overall_caiman_labels(
        self,
        path_resultsFile: Union[str, pathlib.Path],
        include_discarded: bool = True
    ) -> np.ndarray:
        """
        Imports the overall CaImAn labels from the results file: 1.0 for each
        accepted component, followed by 0.0 for each discarded component (when
        **include_discarded** and discarded components can be read).

        Args:
            path_resultsFile (Union[str, pathlib.Path]):
                Path to a single results file.
            include_discarded (bool):
                If ``True``, include ROIs that were discarded by CaImAn. Default
                is ``True``.

        Returns:
            (np.ndarray):
                labels (np.ndarray):
                    Overall CaImAn labels (float array).
        """
        with helpers.h5_load(path_resultsFile, return_dict=False) as data:
            ## Number of components = len(indptr) - 1 of the sparse A matrix
            labels = np.ones(data['estimates']['A']['indptr'][()].shape[0] - 1)
            if include_discarded:
                try:
                    discarded = data['estimates']['discarded_components'][()]
                    labels_discarded = np.zeros(discarded['A']['indptr'].shape[0] - 1)
                except Exception:
                    ## Best effort: files without discarded components keep only accepted labels
                    print('no discarded components for labels') if self._verbose else None
                else:
                    labels = np.hstack([labels, labels_discarded])
        return labels

    def import_cnn_caiman_preds(
        self,
        path_resultsFile: Union[str, pathlib.Path],
        include_discarded: bool = True,
    ) -> Union[np.ndarray, None]:
        """
        Imports the CNN-based CaImAn prediction probabilities from the given
        file.

        Args:
            path_resultsFile (Union[str, pathlib.Path]):
                Path to a single results file. Can be either a string or a
                pathlib.Path object.
            include_discarded (bool):
                If set to True, the function will include ROIs that were
                discarded by CaImAn. By default, this is set to True.

        Returns:
            (Union[np.ndarray, None]):
                preds (np.ndarray):
                    CNN-based CaImAn prediction probabilities, or ``None``
                    (with a warning) if the file stores no CNN predictions.
        """
        with helpers.h5_load(path_resultsFile, return_dict=False) as data:
            preds_included = data['estimates']['cnn_preds'][()]
            ## The 'missing' sentinel is a bytes scalar. Check the type first: comparing a
            ## numeric array to bytes yields an array on modern NumPy, whose truth value is ambiguous.
            if isinstance(preds_included, bytes) and preds_included == b'NoneType':
                warnings.warn('No CNN preds found in results file')
                return None
            preds = preds_included
            if include_discarded:
                try:
                    discarded = data['estimates']['discarded_components'][()]
                    preds_discarded = discarded['cnn_preds']
                    preds_all = np.hstack([preds_included, preds_discarded])
                except Exception:
                    ## Best effort: files without discarded components keep only accepted preds
                    print('no discarded components for cnn_preds') if self._verbose else None
                else:
                    preds = preds_all
        return preds

    def import_ROI_centeredImages(self, out_height_width: List[int] = [36,36]) -> List[np.ndarray]:
        """
        Computes ROI-centered images from ``self.spatialFootprints`` and
        ``self.centroids``.

        Args:
            out_height_width (List[int]):
                Height and width of the output images. A single int is used
                for both. Default is *[36,36]*.

        Returns:
            (List[np.ndarray]):
                ROI_images (List[np.ndarray]):
                    One dense array per session of shape *(n_roi,
                    out_height_width[0], out_height_width[1])*. Pixels that
                    fall outside the output window are dropped.
        """
        out_hw = np.asarray(out_height_width, dtype=np.int64).reshape(-1)
        if out_hw.size == 1:
            out_hw = np.repeat(out_hw, 2)
        half_widths = np.ceil(out_hw / 2).astype(np.int64)

        def _sf_to_centeredROIs(sf, centroids_yx):
            """Shift every ROI so its centroid lands at the output window's center."""
            n_roi = sf.shape[0]
            sf_rs = sparse.COO(sf).reshape((n_roi, self.FOV_height, self.FOV_width))
            idx_roi, yy, xx = sf_rs.coords
            ## NOTE(review): centroids are assumed to be integer pixel coordinates; cast to int64.
            cy = np.asarray(centroids_yx[0]).astype(np.int64)
            cx = np.asarray(centroids_yx[1]).astype(np.int64)
            yy_c = yy - cy[idx_roi] + half_widths[0]
            xx_c = xx - cx[idx_roi] + half_widths[1]
            in_window = (yy_c >= 0) & (yy_c < out_hw[0]) & (xx_c >= 0) & (xx_c < out_hw[1])
            sf_centered = sparse.COO(
                coords=np.stack([idx_roi[in_window], yy_c[in_window], xx_c[in_window]]),
                data=sf_rs.data[in_window],
                shape=(n_roi, int(out_hw[0]), int(out_hw[1])),
            )
            return sf_centered.todense()

        print(f"Computing ROI centered images from spatial footprints") if self._verbose else None
        ROI_images = [_sf_to_centeredROIs(sf, centroids.T) for sf, centroids in zip(self.spatialFootprints, self.centroids)]
        return ROI_images

    def import_FOV_images(
        self,
        paths_resultsFiles: Optional[List] = None,
        images: Optional[List] = None,
    ) -> List[np.ndarray]:
        """
        Imports the FOV images from the CaImAn results files.

        Args:
            paths_resultsFiles (Optional[List]):
                List of paths to CaImAn results files. If not provided, will use
                the paths stored in the class instance.
            images (Optional[List]):
                List of FOV images. If provided, they are returned unchanged
                and no files are read.

        Returns:
            (np.ndarray or List):
                FOV_images:
                    **images** if provided. Otherwise a float32 array of shape
                    *(n_sessions, FOV_height, FOV_width)* built from the first
                    column of ``estimates.b`` of each file, shifted so each
                    image's minimum is 0 and divided by its mean.
        """
        def _import_FOV_image(path_resultsFile):
            with helpers.h5_load(path_resultsFile, return_dict=False) as data:
                FOV_height, FOV_width = data['estimates']['dims'][()]
                ## CaImAn stores images in Fortran (column-major) order
                FOV_image = data['estimates']['b'][()][:, 0].reshape(FOV_height, FOV_width, order='F')
            return FOV_image.astype(np.float32)

        if images is not None:
            if self._verbose:
                print("Using provided images for FOV_images.")
            return images

        if paths_resultsFiles is None:
            paths_resultsFiles = self.paths_resultsFiles
        FOV_images = np.stack([_import_FOV_image(p) for p in paths_resultsFiles])
        FOV_images = FOV_images - FOV_images.min(axis=(1, 2), keepdims=True)
        FOV_images = FOV_images / FOV_images.mean(axis=(1, 2), keepdims=True)
        return FOV_images
############################################
############ DATA ROIEXTRACTORS ############
############################################
####################################
######### HELPER FUNCTIONS #########
####################################
[docs]
def fix_paths(paths: Union[List[Union[str, pathlib.Path]], str, pathlib.Path]) -> List[str]:
    """
    Normalizes one or several paths into a list of resolved, absolute path
    strings.

    Args:
        paths (Union[List[Union[str, pathlib.Path]], str, pathlib.Path]):
            Either a single str / pathlib.Path, or an iterable (list, tuple,
            array of str, ...) of str / pathlib.Path objects.

    Returns:
        List[str]:
            The resolved paths as strings, in input order. An empty iterable
            yields an empty list.

    Raises:
        TypeError:
            If the input is neither a str / pathlib.Path nor an iterable whose
            every element is a str / pathlib.Path.
    """
    error_message = "path_files must be a list of str or list of pathlib.Path or a str or pathlib.Path"
    if isinstance(paths, (str, pathlib.Path)):
        paths_list = [paths]
    else:
        try:
            paths_list = list(paths)
        except TypeError as err:
            raise TypeError(error_message) from err
        ## Validate every element, not just the first one
        if not all(isinstance(p, (str, pathlib.Path)) for p in paths_list):
            raise TypeError(error_message)
    return [str(Path(p).resolve()) for p in paths_list]
[docs]
def make_smaller_data(
    data: Data_roicat,
    n_ROIs: Optional[int] = 300,
    n_sessions: Optional[int] = 10,
    bounds_x: Tuple[int, int] = (200,400),
    bounds_y: Tuple[int, int] = (200,400),
) -> Data_roicat:
    """
    Reduces the size of a Data_roicat object by limiting the number of regions
    of interest (ROIs) and sessions, and cropping the field of view to the
    given x and y bounds. This function is useful for making test datasets.
    The input object is not modified (a deep copy is used).

    Args:
        data (Data_roicat):
            The input data object of the ``Data_roicat`` type.
        n_ROIs (Optional[int]):
            Maximum number of ROIs taken (from the start) of each session
            before cropping. ROIs with no pixels inside the crop are then
            dropped, so fewer may remain. If ``None``, all ROIs are considered.
        n_sessions (Optional[int]):
            The number of sessions (from the start) to include in the output
            data. If ``None``, all sessions will be included.
        bounds_x (Tuple[int, int]):
            *(start, stop)* column bounds of the crop.
        bounds_y (Tuple[int, int]):
            *(start, stop)* row bounds of the crop.

    Returns:
        (Data_roicat):
            data_out (Data_roicat):
                The output data, which is a reduced version of the input data
                according to the specified parameters. FOV images, spatial
                footprints, centroids and ROI images are recomputed.
    """
    import sparse

    data_out = copy.deepcopy(data)

    n_sessions_total = len(data_out.spatialFootprints)
    n_sessions = n_sessions_total if n_sessions is None else min(n_sessions, n_sessions_total)

    y0, y1 = bounds_y
    x0, x1 = bounds_x
    d_height = data.FOV_height
    d_width = data.FOV_width

    data_out.set_FOV_images(FOV_images=[im[y0:y1, x0:x1] for im in data.FOV_images[:n_sessions]])

    ## Boolean mask of the crop window in flattened (C-order) frame coordinates
    frame = np.zeros((d_height, d_width), dtype=bool)
    frame[y0:y1, x0:x1] = True
    frame_flat = frame.reshape(-1)

    sf_cropped = []
    for sf in data_out.spatialFootprints[:n_sessions]:
        n_keep = sf.shape[0] if n_ROIs is None else min(n_ROIs, sf.shape[0])
        sf = sf[:n_keep]
        ## reshape(-1) (not squeeze) keeps the mask 1-D even for a single ROI
        has_pixels_in_crop = np.asarray(sf.multiply(frame_flat[None, :]).sum(1) > 0).reshape(-1)
        sf = sf[has_pixels_in_crop, :]
        n_roi = sf.shape[0]
        sf_cropped.append(
            sparse.COO(sf)
            .reshape(shape=(n_roi, d_height, d_width))[:, y0:y1, x0:x1]
            .reshape(shape=(n_roi, -1))
            .tocsr()
        )

    ## A per-session um_per_pixel must be truncated to the kept sessions
    um_per_pixel = data_out.um_per_pixel
    if (
        isinstance(um_per_pixel, (list, tuple, np.ndarray))
        and np.ndim(um_per_pixel) == 1
        and len(um_per_pixel) == n_sessions_total
        and n_sessions < n_sessions_total
    ):
        um_per_pixel = um_per_pixel[:n_sessions]

    data_out.set_spatialFootprints(
        spatialFootprints=sf_cropped,
        um_per_pixel=um_per_pixel,
    )
    data_out._make_spatialFootprintCentroids()
    data_out.transform_spatialFootprints_to_ROIImages()
    return data_out