from typing import List, Tuple, Optional
import numpy as np
import sparse_convolution
from .. import helpers, util
class ROI_Blurrer(util.ROICaT_Module):
    """
    Blurs the Region of Interest (ROI) spatial footprints using 2D convolution
    to account for registration uncertainty across imaging sessions. Uses the
    ``sparse_convolution`` library for fast sparse convolution via the
    ``'direct'`` method (batch-parallel numba scatter).
    RH 2022

    Args:
        frame_shape (Tuple[int, int]):
            The shape of the frame/Field Of View (FOV). Product of
            ``frame_shape[0]`` and ``frame_shape[1]`` must equal the length of a
            single flattened/sparse spatialFootprint. (Default is *(512, 512)*)
        kernel_halfWidth (int):
            The half-width of the cosine kernel to use for convolutional
            blurring. A value of ``0`` disables blurring: ``blur_ROIs`` then
            returns its input unchanged. (Default is *2*)
        plot_kernel (bool):
            Whether to plot an image of the kernel. (Default is ``False``)
        verbose (bool):
            Whether to print the convolutional blurring operation progress.
            (Default is ``True``)

    Attributes:
        frame_shape (Tuple[int, int]):
            The shape of the frame/Field Of View (FOV). Product of
            ``frame_shape[0]`` and ``frame_shape[1]`` must equal the length of a
            single flattened/sparse spatialFootprint.
        kernel_halfWidth (int):
            The half-width of the cosine kernel to use for convolutional
            blurring.
        plot_kernel (bool):
            Whether to plot an image of the kernel.
        verbose (bool):
            Whether to print the convolutional blurring operation progress.
        kernel (np.ndarray):
            The square cosine kernel, normalized so that it sums to 1.
        ROIs_blurred (List[object]):
            The blurred spatial footprints. Only set once ``blur_ROIs`` has
            been called.

    The four constructor arguments are recorded in ``self.params['__init__']``.
    """
    def __init__(
        self,
        frame_shape: Tuple[int, int] = (512, 512),
        kernel_halfWidth: int = 2,
        plot_kernel: bool = False,
        verbose: bool = True,
    ):
        """
        Initializes the ROI_Blurrer.

        Builds the normalized cosine kernel and prepares the sparse 2D
        convolution operator for images of shape ``frame_shape``.
        """
        super().__init__()
        ## Store parameter (but not data) args as attributes
        self.params['__init__'] = self._locals_to_params(
            locals_dict=locals(),
            keys=[
                'frame_shape',
                'kernel_halfWidth',
                'plot_kernel',
                'verbose',
            ],
        )
        self._frame_shape = frame_shape
        self._verbose = verbose

        # ``_width`` is the FULL kernel width (twice the half-width).
        # ``blur_ROIs`` skips convolution entirely when it is 0.
        self._width = kernel_halfWidth * 2
        # Kernel image side length. For an integer half-width this is
        # ``_width - 1`` (e.g. 3x3 for kernel_halfWidth=2), never below 1.
        self._kernel_size = max(int((self._width // 2) * 2) - 1, 1)
        kernel_tmp = helpers.cosine_kernel_2D(
            center=(self._kernel_size // 2, self._kernel_size // 2),
            image_size=(self._kernel_size, self._kernel_size),
            width=self._width,
        )
        # Normalize to unit sum so blurring redistributes intensity rather
        # than scaling it.
        self.kernel = kernel_tmp / kernel_tmp.sum()

        if self._verbose:
            print('Preparing sparse convolution')
        self._conv = sparse_convolution.Toeplitz_convolution2d(
            x_shape=self._frame_shape,
            k=self.kernel,
            mode='same',
            dtype=np.float32,
            method='direct',
        )

        if plot_kernel:
            import matplotlib.pyplot as plt
            plt.figure()
            plt.imshow(self.kernel)

    def __repr__(self):
        """Return a short summary with the kernel half-width and blur state."""
        # ``_width`` stores ``kernel_halfWidth * 2``; halve it so the repr
        # reports the value the caller actually passed.
        half_width = self._width // 2 if hasattr(self, '_width') else '?'
        has_blurred = getattr(self, 'ROIs_blurred', None) is not None
        return f"ROI_Blurrer(kernel_halfWidth={half_width}, blurred={has_blurred})"

    def blur_ROIs(
        self,
        spatialFootprints: List[object],
    ) -> List[object]:
        """
        Blurs the Region of Interest (ROI) spatial footprints.

        Args:
            spatialFootprints (List[object]):
                A list of sparse matrices corresponding to spatial footprints
                from each session. Rows are ROIs; columns are the flattened
                ``frame_shape`` pixels.

        Returns:
            (List[object]):
                ROIs_blurred (List[object]):
                    A list of blurred ROI spatial footprints, one per session.
                    Also stored as ``self.ROIs_blurred``. When the kernel width
                    is 0, this is the input list itself (not a copy).
        """
        if self._verbose:
            print('Performing convolution for blurring')
        if self._width == 0:
            # No blurring requested: pass the footprints through untouched.
            self.ROIs_blurred = spatialFootprints
        else:
            self.ROIs_blurred = [
                self._conv(
                    x=sf,
                    batching=True,
                    mode='same',
                ) for sf in spatialFootprints
            ]
        return self.ROIs_blurred

    def get_ROIsBlurred_maxIntensityProjection(self) -> List[object]:
        """
        Calculates the maximum intensity projection of the blurred ROIs.

        For each session, every ROI is first divided by its own maximum value.
        That maximum is floored at ``util.SPARSE_NORMALIZATION_FLOOR`` before
        inverting. The per-pixel maximum across ROIs is then taken. Requires
        ``blur_ROIs`` to have been called first.

        Returns:
            (List[object]):
                ims (List[object]):
                    One 2D array of shape ``frame_shape`` per session. A
                    session containing no ROIs yields an all-zero image.
        """
        ims = [self._maxIntensityProjection(rois) for rois in self.ROIs_blurred]
        return ims

    def _maxIntensityProjection(self, rois) -> np.ndarray:
        """Max-normalize each ROI (row) of ``rois`` and return the per-pixel max as an image."""
        height, width = self._frame_shape[0], self._frame_shape[1]
        if rois.shape[0] == 0:
            # scipy.sparse raises ValueError when reducing over a zero-length
            # axis, so a session without ROIs is handled explicitly.
            return np.zeros((height, width), dtype=np.result_type(rois.dtype, np.float32))
        row_max = rois.max(axis=1).toarray().reshape(-1, 1)
        scale = 1.0 / np.maximum(row_max, util.SPARSE_NORMALIZATION_FLOOR)
        return rois.multiply(scale).max(axis=0).toarray().reshape(height, width)
# class ROI_Blurrer:
# """
# Class for blurring ROIs.
# Uses the sp_conv library for fast sparse convolutions.
# Repo here: https://github.com/traveller59/spconv
# RH 2022
# """
# def __init__(
# self,
# frame_shape=(512, 512),
# kernel_halfWidth=2,
# device='cpu',
# plot_kernel=False,
# ):
# """
# Initialize the class.
# Args:
# frame_shape (tuple):
# The shape of the frame/FOV.
# frame_shape[0] * frame_shape[1]
# must equal the length of a single flattened/
# sparse spatialFootprint.
# kernel_halfWidth (int):
# The half-width of the cosine kernel to use
# for convolutional blurring.
# device (str):
# The device to use for the convolution.
# plot_kernel (bool):
# Whether to plot an image of the kernel.
# """
# self._frame_shape = frame_shape
# self._device = device
# self._width = kernel_halfWidth * 2
# self._kernel_size = int((self._width//2)*2) + 3
# kernel_tmp = helpers.cosine_kernel_2D(
# center=(self._kernel_size//2, self._kernel_size//2),
# image_size=(self._kernel_size, self._kernel_size),
# width=self._width
# )
# self.kernel = kernel_tmp / kernel_tmp.sum()
# ## prepare kernel
# kernel_prep = torch.as_tensor(
# self.kernel[:,:,None,None],
# dtype=torch.float32,
# device=device
# ).contiguous()
# ## prepare convolution
# self._conv = spconv.SparseConv2d(
# in_channels=1,
# out_channels=1,
# kernel_size=self.kernel.shape,
# stride=1,
# padding=self.kernel.shape[0]//2,
# dilation=1,
# groups=1,
# bias=False
# )
# self._conv.weight = torch.nn.Parameter(data=kernel_prep, requires_grad=False)
# if plot_kernel:
# import matplotlib.pyplot as plt
# plt.figure()
# plt.imshow(self.kernel)
# def _sparse_conv2D(
# self,
# sf_sparseCOO,
# ):
# """
# Method to perform a 2D convolution on a sparse matrix.
# Args:
# sf_sparseCOO (sparse.COO):
# The sparse matrix to convolve.
# shape: (num_ROIs, frame_shape[0], frame_shape[1])
# """
# images_spconv = pydata_sparse_to_spconv(
# sf_sparseCOO,
# device=self._device
# )
# images_conv = self._conv(images_spconv)
# return sparse_convert_spconv_to_scipy(images_conv)
# def blur_ROIs(
# self,
# spatialFootprints,
# batch_size=None,
# num_batches=100,
# ):
# """
# Method to blur ROIs.
# Args:
# spatialFootprints (list of scipy.sparse.csr_array):
# The spatialFootprints to blur.
# shape of each element:
# (num_ROIs, frame_shape[0] * frame_shape[1])
# batch_size (int):
# The batch size to use for blurring.
# if None, then will use num_batches to determine size.
# num_batches (int):
# The number of batches to use for blurring.
# """
# sf_coo = [sparse.as_coo(sf).reshape((sf.shape[0], self._frame_shape[0], self._frame_shape[1])) for sf in spatialFootprints]
# self.ROIs_blurred = [scipy.sparse.vstack([self._sparse_conv2D(
# sf_sparseCOO=batch,
# ) for batch in helpers.make_batches(sf, batch_size=batch_size, num_batches=num_batches)]) for sf in sf_coo]
# return self.ROIs_blurred
# def get_ROIsBlurred_maxIntensityProjection(self):
# """
# Returns the max intensity projection of the ROIs.
# """
# return [rois.max(0).toarray().reshape(self._frame_shape[0], self._frame_shape[1]) for rois in self.ROIs_blurred]
# def pydata_sparse_to_spconv(sp_array, device='cpu'):
# coo = sparse.COO(sp_array)
# idx_raw = torch.as_tensor(coo.coords.T, dtype=torch.int32, device=device).contiguous()
# spconv_array = spconv.SparseConvTensor(
# features=torch.as_tensor(coo.reshape((-1)).T.data, dtype=torch.float32, device=device)[:,None].contiguous(),
# indices=idx_raw,
# spatial_shape=coo.shape[1:],
# batch_size=coo.shape[0]
# )
# return spconv_array
# def sparse_convert_spconv_to_scipy(sp_arr):
# coo = sparse.COO(
# coords=sp_arr.indices.T.to('cpu'),
# data=sp_arr.features.squeeze().to('cpu'),
# shape=[sp_arr.batch_size] + sp_arr.spatial_shape
# )
# return coo.reshape((coo.shape[0], -1)).to_scipy_sparse().tocsr()