Source code for roicat.tracking.scatteringWaveletTransformer
import gc
from typing import Any, Dict, Tuple
import torch
import numpy as np
from tqdm.auto import tqdm
from .. import helpers, util
[docs]
class SWT(util.ROICaT_Module):
    """
    Performs scattering wavelet transform using the kymatio library.
    RH 2022

    Args:
        kwargs_Scattering2D (Dict[str, Any]):
            The keyword arguments to pass to the Scattering2D class.
            (Default is ``{'J': 2, 'L': 8}``)
        image_shape (Tuple[int, int]):
            The shape of the images to be transformed.
            (Default is ``(36,36)``)
        device (str):
            The device to use for the transformation.
            (Default is ``'cpu'``)
        verbose (bool):
            If ``True``, print statements will be outputted.
            (Default is ``True``)

    Attributes:
        swt:
            The kymatio ``Scattering2D`` object wrapped in
            ``util.Model_SWT`` and moved to ``device``.
        latents (torch.Tensor):
            Set by :meth:`transform`. The flattened transform output,
            *(n_ROIs, latent_size)*.

    Example:
        .. highlight:: python
        .. code-block:: python

            swt = SWT(kwargs_Scattering2D={'J': 2, 'L': 8}, image_shape=(36,36), device='cpu', verbose=True)
            transformed_images = swt.transform(ROI_images, batch_size=100)
    """
    def __init__(
        self,
        kwargs_Scattering2D: Dict[str, Any] = {'J': 2, 'L': 8},
        image_shape: Tuple[int, int] = (36,36),
        device: str = 'cpu',
        verbose: bool = True,
    ):
        """
        Initializes the SWT with the given settings.
        """
        super().__init__()

        ## Shallow-copy the kwargs so that the shared default dict in the
        ##  signature (and any dict passed in by the caller) is never the
        ##  same object as the one stored in self.params.
        kwargs_Scattering2D = dict(kwargs_Scattering2D)

        ## Store parameter (but not data) args as attributes
        self.params['__init__'] = self._locals_to_params(
            locals_dict=locals(),
            keys=[
                'kwargs_Scattering2D',
                'image_shape',
                'device',
                'verbose',
            ],
        )

        ## Monkey-patch for scipy >= 1.17 compatibility (sph_harm removed).
        ## sph_harm(m, n, theta, phi) was replaced by sph_harm_y(n, m, theta, phi)
        ##  with swapped argument order for both (m,n) and (theta,phi).
        ## The patch must be applied before kymatio is imported below.
        import scipy.special
        if not hasattr(scipy.special, 'sph_harm'):
            scipy.special.sph_harm = lambda m, n, theta, phi: scipy.special.sph_harm_y(n, m, phi, theta)

        ## Imports (deferred so that kymatio is only required when SWT is used)
        from kymatio.torch import Scattering2D

        self._verbose = verbose
        self._device = device

        ## Build the scattering object, wrap it, and move it to the device
        self.swt = Scattering2D(shape=image_shape, **kwargs_Scattering2D)
        self.swt = util.Model_SWT(self.swt)
        self.swt.to(device)

        if self._verbose:
            print('SWT initialized')

    def transform(self, ROI_images: np.ndarray, batch_size: int = 100) -> torch.Tensor:
        """
        Transforms the ROI images.

        Args:
            ROI_images (np.ndarray):
                The ROI images to transform.
                One should probably concatenate ROI images across sessions for passing through here.
                *(n_ROIs, height, width)*
            batch_size (int):
                The batch size to use for the transformation.
                (Default is *100*)

        Returns:
            (torch.Tensor):
                latents (torch.Tensor):
                    The transformed ROI images, on the CPU and flattened
                    to two dimensions. *(n_ROIs, latent_size)* \n
                    Also stored as ``self.latents``.
        """
        ## Store parameter (but not data) args as attributes
        self.params['transform'] = self._locals_to_params(
            locals_dict=locals(),
            keys=['batch_size',],)

        if self._verbose:
            print('Starting: SWT transform on ROIs')

        def helper_swt(ims_batch):
            ## Two leading singleton dims are added before the model call and
            ##  every size-1 dim is squeezed away from the output.
            sfs = torch.as_tensor(np.ascontiguousarray(ims_batch[None,...]), device=self._device, dtype=torch.float32)
            out = self.swt(sfs[None,...]).squeeze().cpu()
            if out.ndim == 3:  ## if there is only one ROI in the batch, append a dimension to the front
                out = out[None,...]
            return out

        ## Number of batches = ceil(n_ROIs / batch_size). Integer arithmetic
        ##  keeps the progress bar total a whole number even when the last
        ##  batch is only partially filled.
        n_rois = ROI_images.shape[0]
        n_batches = (n_rois + batch_size - 1) // batch_size

        batches = tqdm(
            helpers.make_batches(ROI_images, batch_size=batch_size),
            total=n_batches,
            mininterval=5,
        )
        self.latents = torch.cat([helper_swt(ims_batch) for ims_batch in batches], dim=0)
        ## Flatten everything after the ROI dimension
        self.latents = self.latents.reshape(self.latents.shape[0], -1)

        if self._verbose:
            print('Completed: SWT transform on ROIs')

        ## Release Python garbage and cached CUDA memory; two passes are run
        ##  back-to-back.
        for _ in range(2):
            gc.collect()
            torch.cuda.empty_cache()

        return self.latents