from typing import List, Dict, Tuple, Union, Optional, Any, Callable, Iterable, Sequence, Type, Any, MutableMapping
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import multiprocessing as mp
import os
from pathlib import Path
import copy
import pickle
import re
import zipfile
import gc
from functools import partial
import tkinter as tk
from tkinter import ttk
import PIL
from PIL import ImageTk
import csv
import warnings
import time
import datetime
import numpy as np
import torch
import torchvision
import scipy.sparse
import scipy.interpolate
import sparse
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.widgets import LassoSelector
from matplotlib.path import Path as mplPath
import yaml
import optuna
"""
All of these are from basic_neural_processing_modules
"""
######################################################################################################################################
####################################################### TORCH HELPERS ################################################################
######################################################################################################################################
[docs]
def set_device(
use_GPU: bool = True,
device_num: int = 0,
device_types: List[str] = ['cuda', 'mps', 'xpu', 'cpu'],
verbose: bool = True
) -> str:
"""
Sets the device for PyTorch. If a GPU is available and **use_GPU** is
``True``, it will be set as the device. Otherwise, the CPU will be set as
the device.
RH 2022
Args:
use_GPU (bool):
Determines if the GPU should be utilized: \n
* ``True``: the function will attempt to use the GPU if a GPU is
not available.
* ``False``: the function will use the CPU. \n
(Default is ``True``)
device_num (int):
Specifies the index of the GPU to use. (Default is ``0``)
device_types (List[str]):
The types and order of devices to attempt to use. The first device
type that is available will be used. Options are ``'cuda'``,
``'mps'``, ``'xpu'``, and ``'cpu'``.
verbose (bool):
Determines whether to print the device information. \n
* ``True``: the function will print out the device information.
\n
(Default is ``True``)
Returns:
(str):
device (str):
A string specifying the device, either *"cpu"* or
*"cuda:<device_num>"*.
"""
devices = list_available_devices()
if not use_GPU:
device = 'cpu'
else:
device = None
for device_type in device_types:
if len(devices[device_type]) > 0:
device = devices[device_type][device_num]
break
if verbose:
print(f'Using device: {device}')
return device
[docs]
def clear_gpu_cache(gc_collect: bool = True):
"""
Clear GPU memory cache and optionally run garbage collection.
Releases unreferenced GPU tensors back to the OS. Works with CUDA,
MPS (Apple Silicon), and is a no-op on CPU-only systems. Call
between pipeline steps to prevent memory accumulation.
Args:
gc_collect (bool):
If ``True`` (default), run ``gc.collect()`` before clearing
the GPU cache. Set to ``False`` for lightweight clearing
inside tight loops where GC overhead or side effects are
undesirable.
RH 2025
"""
if gc_collect:
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if hasattr(torch, 'mps') and torch.backends.mps.is_available():
torch.mps.empty_cache()
if gc_collect:
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if hasattr(torch, 'mps') and torch.backends.mps.is_available():
torch.mps.empty_cache()
if gc_collect:
gc.collect()
[docs]
def list_available_devices() -> dict:
"""
Lists all available PyTorch devices on the system.
RH 2024
Returns:
(dict):
A dictionary with device types as keys and lists of available devices as values.
"""
devices = {}
# Check for CPU devices
if torch.cpu.is_available():
devices['cpu'] = ['cpu']
else:
devices['cpu'] = []
# Check for CUDA devices
if torch.cuda.is_available():
devices['cuda'] = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
else:
devices['cuda'] = []
# Check for MPS devices
if torch.backends.mps.is_available():
devices['mps'] = ['mps:0']
else:
devices['mps'] = []
# Check for XPU devices
if hasattr(torch, 'xpu'):
if torch.xpu.is_available():
devices['xpu'] = [f'xpu:{i}' for i in range(torch.xpu.device_count())]
else:
devices['xpu'] = []
else:
devices['xpu'] = []
return devices
######################################################################################################################################
###################################################### CONTAINER HELPERS #############################################################
######################################################################################################################################
[docs]
def merge_dicts(
dicts: List[dict]
) -> dict:
"""
Merges a list of dictionaries into a single dictionary.
RH 2022
Args:
dicts (List[dict]):
List of dictionaries to merge.
Returns:
(dict):
result_dict (dict):
A single dictionary that contains all keys and values from the
dictionaries in the input list.
"""
out = {}
[out.update(d) for d in dicts]
return out
[docs]
def deep_update_dict(
dictionary: dict,
key: List[str],
val: Any,
in_place: bool = False
) -> Union[dict, None]:
"""
Updates a nested dictionary with a new value.
RH 2023
Args:
dictionary (dict):
The original dictionary to update.
key (List[str]):
List of keys representing the hierarchical path to the nested value
to update. Each element should be a string that represents a level
in the hierarchy. For example, to change a value in the dictionary
`params` at key 'dataloader_kwargs' and subkey 'prefetch_factor', you would
pass `['dataloader_kwargs', 'prefetch_factor']`.
val (Any):
The new value to set in the dictionary.
in_place (bool):
* ``True``: the original dictionary will be updated in-place and no
value will be returned.
* ``False``, a new dictionary will be created and returned. (Default
is ``False``)
Returns:
(Union[dict, None]):
updated_dict (dict):
The updated dictionary. Only returned if ``in_place`` is ``False``.
Example:
.. highlight:: python
.. code-block:: python
original_dict = {"level1": {"level2": "old value"}}
updated_dict = deep_update_dict(original_dict, ["level1", "level2"], "new value", in_place=False)
# Now updated_dict is {"level1": {"level2": "new value"}}
"""
def helper_deep_update_dict(d, key, val):
if type(key) is str:
key = [key]
assert key[0] in d, f"RH ERROR, key: '{key[0]}' is not found"
if type(key) is list:
if len(key) > 1:
helper_deep_update_dict(d[key[0]], key[1:], val)
elif len(key) == 1:
key = key[0]
d.update({key:val})
if in_place:
helper_deep_update_dict(dictionary, key, val)
else:
d = copy.deepcopy(dictionary)
helper_deep_update_dict(d, key, val)
return d
[docs]
def flatten_dict(d: MutableMapping, parent_key: str = '', sep: str ='.') -> MutableMapping:
"""
Flattens a dictionary of dictionaries into a single dictionary. NOTE: Turns
all keys into strings. Stolen from https://stackoverflow.com/a/6027615.
RH 2022
Args:
d (Dict):
Dictionary to flatten
parent_key (str):
Key to prepend to flattened keys IGNORE: USED INTERNALLY FOR
RECURSION
sep (str):
Separator to use between keys IGNORE: USED INTERNALLY FOR RECURSION
Returns:
(Dict):
flattened dictionary (dict):
Flat dictionary with the keys to deeper dictionaries joined by
the separator.
"""
items = []
for k, v in d.items():
new_key = str(parent_key) + str(sep) + str(k) if parent_key else str(k)
if isinstance(v, MutableMapping):
items.extend(flatten_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
## parameter dictionary helpers ##
[docs]
def fill_in_dict(
d: Dict,
defaults: Dict,
verbose: bool = True,
hierarchy: List[str] = ['dict'],
):
"""
In-place. Fills in dictionary ``d`` with values from ``defaults`` if they
are missing. Works hierachically.
RH 2023
Args:
d (Dict):
Dictionary to fill in.
In-place.
defaults (Dict):
Dictionary of defaults.
verbose (bool):
Whether to print messages.
hierarchy (List[str]):
Used internally for recursion.
Hierarchy of keys to d.
"""
from copy import deepcopy
for key in defaults:
if key not in d:
print(f"Key '{key}' not found in params dictionary: {' > '.join([f'{str(h)}' for h in hierarchy])}. Using default value: {defaults[key]}") if verbose else None
d.update({key: deepcopy(defaults[key])})
elif isinstance(defaults[key], dict):
assert isinstance(d[key], dict), f"Key '{key}' is a dict in defaults, but not in params. {' > '.join([f'{str(h)}' for h in hierarchy])}."
fill_in_dict(d[key], defaults[key], hierarchy=hierarchy+[key], verbose=verbose)
[docs]
def check_keys_subset(d, default_dict, hierarchy=['defaults']):
"""
Checks that the keys in d are all in default_dict. Raises an error if not.
RH 2023
Args:
d (Dict):
Dictionary to check.
default_dict (Dict):
Dictionary containing the keys to check against.
hierarchy (List[str]):
Used internally for recursion.
Hierarchy of keys to d.
"""
default_keys = list(default_dict.keys())
for key in d.keys():
assert key in default_keys, f"Parameter '{key}' not found in defaults dictionary: {' > '.join([f'{str(h)}' for h in hierarchy])}."
if isinstance(default_dict[key], dict) and isinstance(d[key], dict):
check_keys_subset(d[key], default_dict[key], hierarchy=hierarchy+[key])
[docs]
def prepare_params(params, defaults, verbose=True):
"""
Does the following:
* Checks that all keys in ``params`` are in ``defaults``.
* Fills in any missing keys in ``params`` with values from ``defaults``.
* Returns a deepcopy of the filled-in ``params``.
Args:
params (Dict):
Dictionary of parameters.
defaults (Dict):
Dictionary of defaults.
verbose (bool):
Whether to print messages.
"""
from copy import deepcopy
## Check inputs
assert isinstance(params, dict), f"p must be a dict. Got {type(params)} instead."
## Make sure all the keys in p are valid
check_keys_subset(params, defaults)
## Fill in any missing keys with defaults
params_out = deepcopy(params)
fill_in_dict(params_out, defaults, verbose=verbose)
return params_out
######################################################################################################################################
####################################################### MATH FUNCTIONS ###############################################################
######################################################################################################################################
[docs]
def generalised_logistic_function(
x: Union[np.ndarray, torch.Tensor],
a: float = 0,
k: float = 1,
b: float = 1,
v: float = 1,
q: float = 1,
c: float = 1,
mu: float = 0,
) -> Union[np.ndarray, torch.Tensor]:
"""
Calculates the **generalized logistic function**.
Refer to `Generalised logistic function
<https://en.wikipedia.org/wiki/Generalised_logistic_function>`_ for detailed
information on the parameters.
RH 2021
Args:
x (Union[np.ndarray, torch.Tensor]):
The input to the logistic function.
a (float):
The lower asymptote. (Default is *0*)
k (float):
The upper asymptote when ``c=1``. (Default is *1*)
b (float):
The growth rate. (Default is *1*)
v (float):
Should be greater than *0*, it affects near which asymptote maximum growth
occurs. (Default is *1*)
q (float):
Related to the value Y(0). Center positions. (Default is *1*)
c (float):
Typically takes a value of *1*. (Default is *1*)
mu (float):
The center position of the function. (Default is *0*)
Returns:
(Union[np.ndarray, torch.Tensor]):
out (Union[np.ndarray, torch.Tensor]):
The value of the logistic function for the input ``x``.
"""
if type(x) is np.ndarray:
exp = np.exp
elif type(x) is torch.Tensor:
exp = torch.exp
return a + (k-a) / (c + q*exp(-b*(x-mu)))**(1/v)
[docs]
def bounded_logspace(
start: float,
stop: float,
num: int,
) -> np.ndarray:
"""
Creates a **logarithmically spaced array**, similar to ``np.logspace``, but
with a defined start and stop.
RH 2022
Args:
start (float):
The first value in the output array.
stop (float):
The last value in the output array.
num (int):
The number of values in the output array.
Returns:
(np.ndarray):
out (np.ndarray):
An array of logarithmically spaced values between ``start`` and
``stop``.
"""
exp = 2 ## doesn't matter what this is, just needs to be > 1
return exp ** np.linspace(np.log(start)/np.log(exp), np.log(stop)/np.log(exp), num, endpoint=True)
[docs]
def make_odd(n, mode='up'):
"""
Make a number odd.
RH 2023
Args:
n (int):
Number to make odd
mode (str):
'up' or 'down'
Whether to round up or down to the nearest odd number
Returns:
output (int):
Odd number
"""
if n % 2 == 0:
if mode == 'up':
return n + 1
elif mode == 'down':
return n - 1
else:
raise ValueError("mode must be 'up' or 'down'")
else:
return n
[docs]
def make_even(n, mode='up'):
"""
Make a number even.
RH 2023
Args:
n (int):
Number to make even
mode (str):
'up' or 'down'
Whether to round up or down to the nearest even number
Returns:
output (int):
Even number
"""
if n % 2 != 0:
if mode == 'up':
return n + 1
elif mode == 'down':
return n - 1
else:
raise ValueError("mode must be 'up' or 'down'")
else:
return n
######################################################################################################################################
####################################################### CLASSIFICATION ###############################################################
######################################################################################################################################
[docs]
def confusion_matrix(
y_hat: np.ndarray,
y_true: np.ndarray,
counts: bool = False,
) -> np.ndarray:
"""
Computes the confusion matrix from ``y_hat`` and ``y_true``. ``y_hat``
should be either predictions or probabilities.
RH 2022
Args:
y_hat (np.ndarray):
Numpy array of predictions or probabilities. Either \n
* 1D array of predictions *(n_samples,)*. Values should be integers.
* 2D array of probabilities *(n_samples, n_classes)*. Values should
be floats. \n
(Default is 1D array of predictions)
y_true (np.ndarray):
Numpy array of ground truth labels. Either \n
* 1D array of labels *(n_samples,)*. Values should be integers.
* 2D array of one-hot labels *(n_samples, n_classes)*. Values should
be integers. \n
(Default is 1D array of labels)
counts (bool):
If ``False``, the output confusion matrix is normalized. If
``True``, the output contains counts. (Default is ``False``)
Returns:
(np.ndarray):
cmat (np.ndarray):
The computed confusion matrix.
"""
n_classes = max(np.max(y_true)+1, np.max(y_hat)+1)
if y_hat.ndim == 1:
y_hat = idx_to_oneHot(y_hat, n_classes).astype('int')
cmat = y_hat.T @ idx_to_oneHot(y_true, n_classes)
if not counts:
cmat = cmat / np.sum(cmat, axis=0)[None,:]
return cmat
[docs]
def squeeze_integers(
intVec: Union[list, np.ndarray, torch.Tensor]
) -> Union[np.ndarray, torch.Tensor]:
"""
Makes integers in an array consecutive numbers starting from the smallest
value. For example, [7,2,7,4,-1,0] -> [3,2,3,1,-1,0]. This is useful for
removing unused class IDs.
RH 2023
Args:
intVec (Union[list, np.ndarray, torch.Tensor]):
1-D array of integers.
Returns:
(Union[np.ndarray, torch.Tensor]):
squeezed_integers (Union[np.ndarray, torch.Tensor]):
1-D array of integers with consecutive numbers starting from the
smallest value.
"""
if isinstance(intVec, list):
intVec = np.array(intVec, dtype=np.int64)
if isinstance(intVec, np.ndarray):
unique, arange = np.unique, np.arange
elif isinstance(intVec, torch.Tensor):
unique, arange = torch.unique, torch.arange
u, inv = unique(intVec, return_inverse=True) ## get unique values and their indices
u_min = u.min() ## get the smallest value
u_s = arange(u_min, u_min + u.shape[0], dtype=u.dtype) ## make consecutive numbers starting from the smallest value
return u_s[inv] ## return the indexed consecutive unique values
######################################################################################################################################
######################################################### OPTIMIZATION ###############################################################
######################################################################################################################################
[docs]
class Convergence_checker_optuna:
"""
Checks if the optuna optimization has converged.
RH 2023
Args:
n_patience (int):
Number of trials to look back to check for convergence.
Also the minimum number of trials that must be completed
before starting to check for convergence.
(Default is *10*)
tol_frac (float):
Fractional tolerance for convergence.
The best output value must change by less than this
fractional amount to be considered converged.
(Default is *0.05*)
max_trials (int):
Maximum number of trials to run before stopping.
(Default is *350*)
max_duration (float):
Maximum number of seconds to run before stopping.
(Default is *600*)
value_stop (Optional[float]):
Value at which to stop the optimization. If the best value is equal
to or less than this value, the optimization will stop.
(Default is *None*)
verbose (bool):
If ``True``, print messages.
(Default is ``True``)
Attributes:
bests (List[float]):
List to hold the best values obtained in the trials.
best (float):
Best value obtained among the trials. Initialized with infinity.
Example:
.. highlight:: python
.. code-block:: python
# Create a ConvergenceChecker instance
convergence_checker = ConvergenceChecker(
n_patience=15,
tol_frac=0.01,
max_trials=500,
max_duration=60*20,
verbose=True
)
# Assume we have a study and trial objects from optuna
# Use the check method in the callback
study.optimize(objective, n_trials=100, callbacks=[convergence_checker.check])
"""
def __init__(
self,
n_patience: int = 10,
tol_frac: float = 0.05,
max_trials: int = 350,
max_duration: float = 60*10,
value_stop: Optional[float] = None,
verbose: bool = True,
):
"""
Initializes the ConvergenceChecker with the given parameters.
"""
self.bests = []
self.best = np.inf
self.n_patience = n_patience
self.tol_frac = tol_frac
self.max_trials = max_trials
self.max_duration = max_duration
self.value_stop = value_stop
self.num_trial = 0
self.verbose = verbose
[docs]
def check(
self,
study: object,
trial: object,
):
"""
Checks if the optuna optimization has converged. This function should be
used as the callback function for the optuna study.
Args:
study (optuna.study.Study):
Optuna study object.
trial (optuna.trial.FrozenTrial):
Optuna trial object.
"""
dur_first, dur_last = study.trials[0].datetime_complete, trial.datetime_complete
if (dur_first is not None) and (dur_last is not None):
duration = (dur_last - dur_first).total_seconds()
else:
duration = 0
if trial.value < self.best:
self.best = trial.value
self.bests.append(self.best)
bests_recent = np.unique(self.bests[-self.n_patience:])
if self.best == 0:
print(f'Stopping. Best value is 0.') if self.verbose else None
study.stop()
elif self.num_trial > self.n_patience and ((np.abs(bests_recent.max() - bests_recent.min()) / np.abs(self.best)) < self.tol_frac):
print(f'Stopping. Convergence reached. Best value ({self.best*10000}) over last ({self.n_patience}) trials fractionally changed less than ({self.tol_frac})') if self.verbose else None
study.stop()
elif self.num_trial >= self.max_trials:
print(f'Stopping. Trial number limit reached. num_trial={self.num_trial}, max_trials={self.max_trials}.') if self.verbose else None
study.stop()
elif duration > self.max_duration:
print(f'Stopping. Duration limit reached. study.duration={duration}, max_duration={self.max_duration}.') if self.verbose else None
study.stop()
if self.value_stop is not None:
if self.best <= self.value_stop:
print(f'Stopping. Best value ({self.best}) is less than or equal to value_stop ({self.value_stop}).') if self.verbose else None
study.stop()
if self.verbose:
print(f'Trial num: {self.num_trial}. Duration: {duration:.3f}s. Best value: {self.best:3e}. Current value:{trial.value:3e}') if self.verbose else None
self.num_trial += 1
[docs]
class OptunaProgressBar:
"""
A customizable progress bar for Optuna's study.optimize().
Args:
n_trials (int, optional):
The number of trials. Exactly one of n_trials or timeout must be
set.
timeout (float, optional):
The maximum time to run in seconds. Exactly one of n_trials or
timeout must be set.
tqdm_kwargs (dict, optional):
Additional keyword arguments to pass to tqdm.
These will override the default values EXCEPT for the following
kwargs, which will defer to the environment variables: \n
* ``disable``
* ``dynamic_ncols``
"""
def __init__(
self,
n_trials: Optional[int] = None,
timeout: Optional[float] = None,
**tqdm_kwargs: Any,
):
if (n_trials is None) and (timeout is None):
raise ValueError("Either n_trials or timeout must be set.")
## Exclusivity between n_trials and timeout
elif (n_trials is not None) and (timeout is not None):
raise ValueError("Only one of n_trials or timeout should be set.")
## Store user-provided values
self._n_trials = n_trials
self._timeout = timeout
self._tqdm_kwargs = self._get_default_kwargs()
## Overwrite default values with user-provided values
self._tqdm_kwargs.update(tqdm_kwargs)
## Initialize params for bar
self.bar = None
self._last_elapsed_seconds = 0.0
## Initialize global variables for time diff
global _time_diff_last_call
_time_diff_last_call = time.time()
global _times_calls
_times_calls = []
global _time_start
_time_start = time.time()
## Initialize progress bar
if self._n_trials is not None:
self.bar = tqdm(
total=self._n_trials,
**self._tqdm_kwargs,
)
elif self._timeout is not None:
total = tqdm.format_interval(self._timeout)
fmt = "{desc} {percentage:3.0f}%|{bar}| {elapsed}/" + total
self.bar = tqdm(total=self._timeout, bar_format=fmt)
def __call__(self, study: optuna.study.Study, trial: optuna.trial.FrozenTrial):
# Initialize progress bar on first call
if self.bar is None:
# Disable progress bar based on environment variables and context
if self._get_env_variables().get('disable', False):
self._tqdm_kwargs['disable'] = True
# Update description based on mininterval
current_time = time.time()
# Drop current time into _times_calls list
global _times_calls
min_interval = self._tqdm_kwargs.get('mininterval', 0.1)
global _time_diff_last_call
if (current_time - _time_diff_last_call >= min_interval) and all([current_time - t >= min_interval for t in _times_calls]):
_times_calls.append(current_time)
# print(_time_diff_last_call, _times_calls, current_time)
self._update(study)
_time_diff_last_call = current_time
## Remove the first element of the list
_times_calls.pop(0)
def _update(self, study: optuna.study.Study):
# Update progress bar
if self._n_trials is not None:
## Get the current trial number
i_trial = len(study.trials)
self.bar.update(i_trial - self.bar.n)
elif self._timeout is not None:
# Get the total elapsed time for the study (last trial - first trial)
t_start = study.trials[0].datetime_start
t_last = study.trials[-1].datetime_complete
t_last = datetime.datetime.now()
elapsed_seconds = t_last - t_start
self.bar.update(elapsed_seconds.total_seconds() - self.bar.n)
try:
best_value = study.best_value
best_trial = study.best_trial
self.bar.set_description(f"Best trial: {best_trial.number}, value: {best_value:.6g}")
except ValueError:
pass # No trials completed yet
def _get_env_variables(self):
# Load TQDM environment variables
env_vars = {
'disable': bool( os.environ.get('TQDM_DISABLE', '').lower() == 'true'),
'mininterval': float(os.environ.get('TQDM_MININTERVAL', 0.1)),
'maxinterval': float(os.environ.get('TQDM_MAXINTERVAL', 10.0)),
'miniters': int( os.environ.get('TQDM_MINITERS', 1)),
'smoothing': float(os.environ.get('TQDM_SMOOTHING', 0.3)),
}
return env_vars
def _get_default_kwargs(self):
kwargs_default = self._get_env_variables()
# Set default values for better handling in environments
kwargs_default.setdefault('dynamic_ncols', True)
kwargs_default.setdefault('leave', False)
return kwargs_default
######################################################################################################################################
######################################################## FEATURIZATION ###############################################################
######################################################################################################################################
[docs]
def idx_to_oneHot(
arr: Union[np.ndarray, torch.Tensor],
n_classes: Optional[int] = None,
dtype: Optional[Type] = None,
) -> Union[np.ndarray, torch.Tensor]:
"""
Converts an array of class labels to a matrix of one-hot vectors.
RH 2021
Args:
arr (Union[np.ndarray, torch.Tensor]):
A 1-D array of class labels. Values should be integers >= 0. Values
will be used as indices in the output array.
n_classes (Optional[int]):
The number of classes. If ``None``, it will be derived from ``arr``.
(Default is ``None``)
dtype (Optional[Type]):
The data type of the output array. If ``None``, it defaults to bool
for numpy array and torch.bool for Torch tensor. (Default is
``None``)
Returns:
(Union[np.ndarray, torch.Tensor]):
oneHot (Union[np.ndarray, torch.Tensor]):
A 2-D array of one-hot vectors.
"""
if type(arr) is np.ndarray:
max = np.max
zeros = np.zeros
arange = np.arange
dtype = bool if dtype is None else dtype
elif type(arr) is torch.Tensor:
max = torch.max
zeros = torch.zeros
arange = torch.arange
dtype = torch.bool if dtype is None else dtype
assert arr.ndim == 1
if n_classes is None:
n_classes = max(arr)+1
oneHot = zeros((len(arr), n_classes), dtype=dtype)
oneHot[arange(len(arr)), arr] = True
return oneHot
[docs]
def cosine_kernel_2D(
center: Tuple[int, int] = (5, 5),
image_size: Tuple[int, int] = (11, 11),
width: int = 5,
) -> np.ndarray:
"""
Generates a 2D cosine kernel. RH 2021
Args:
center (Tuple[int, int]):
The mean position (X, Y) where high value is expected. It is 0-indexed.
Make the second value 0 to make it 1D. (Default is *(5, 5)*)
image_size (Tuple[int, int]):
The total image size (width, height). Make the second value 0 to
make it 1D. (Default is *(11, 11)*)
width (int):
The full width of one cycle of the cosine. (Default is *5*)
Returns:
(np.ndarray):
k_cos (np.ndarray):
2D or 1D array of the cosine kernel.
"""
x, y = np.meshgrid(range(image_size[1]), range(image_size[0])) # note dim 1:X and dim 2:Y
dist = np.sqrt((y - int(center[1])) ** 2 + (x - int(center[0])) ** 2)
dist_scaled = (dist/(width/2))*np.pi
dist_scaled[dist_scaled > np.pi] = np.pi
k_cos = (np.cos(dist_scaled) + 1)/2
return k_cos
######################################################################################################################################
########################################################## INDEXING ##################################################################
######################################################################################################################################
[docs]
def idx2bool(
idx: np.ndarray,
length: Optional[int] = None,
) -> np.ndarray:
"""
Converts a vector of indices to a boolean vector.
RH 2021
Args:
idx (np.ndarray):
1-D array of indices.
length (Optional[int]):
Length of boolean vector. If ``None``, the length will be set to the
maximum index in ``idx`` + 1. (Default is ``None``)
Returns:
(np.ndarray):
bool_vec (np.ndarray):
1-D boolean array.
"""
if length is None:
length = np.uint64(np.max(idx) + 1)
out = np.zeros(length, dtype=bool)
out[idx] = True
return out
[docs]
def make_batches(
iterable: Iterable,
batch_size: Optional[int] = None,
num_batches: Optional[int] = None,
min_batch_size: int = 0,
return_idx: bool = False,
length: Optional[int] = None
) -> Iterable:
"""
Creates batches from an iterable.
RH 2021
Args:
iterable (Iterable):
The iterable to be batched.
batch_size (Optional[int]):
The size of each batch. If ``None``, then the batch size is based on
``num_batches``. (Default is ``None``)
num_batches (Optional[int]):
The number of batches to create. (Default is ``None``)
min_batch_size (int):
The minimum size of each batch. (Default is ``0``)
return_idx (bool):
If ``True``, return the indices of the batches. Output will be
[start, end] idx. (Default is ``False``)
length (Optional[int]):
The length of the iterable. If ``None``, then the length is
len(iterable). This is useful if you want to make batches of
something that doesn't have a __len__ method. (Default is ``None``)
Returns:
(Iterable):
output (Iterable):
Batches of the iterable.
"""
if length is None:
l = len(iterable)
else:
l = length
if batch_size is None:
batch_size = np.int64(np.ceil(l / num_batches))
for start in range(0, l, batch_size):
end = min(start + batch_size, l)
if (end-start) < min_batch_size:
break
else:
if return_idx:
yield iterable[start:end], [start, end]
else:
yield iterable[start:end]
[docs]
def sparse_to_dense_fill(
arr_s: sparse.COO,
fill_val: float = 0.
) -> np.ndarray:
"""
Converts a **sparse** array to a **dense** array and fills in sparse entries with a specified fill value.
RH 2023
Args:
arr_s (sparse.COO):
Sparse array to be converted to dense.
fill_val (float):
Value to fill the sparse entries. (Default is ``0.0``)
Returns:
(np.ndarray):
dense_arr (np.ndarray):
Dense version of the input sparse array.
"""
import sparse
s = sparse.COO(arr_s)
s.fill_value = fill_val
return s.todense()
[docs]
def sparse_mask(
x: scipy.sparse.csr_array,
mask_sparse: scipy.sparse.csr_array,
do_safety_steps: bool = True
) -> scipy.sparse.csr_array:
"""
Masks a **sparse matrix** with the non-zero elements of another sparse
matrix.
RH 2022
Args:
x (scipy.sparse.csr_array):
Sparse matrix to mask.
mask_sparse (scipy.sparse.csr_array):
Sparse matrix to mask with.
do_safety_steps (bool):
Whether to do safety steps to ensure that things are working as
expected. (Default is ``True``)
Returns:
(scipy.sparse.csr_array):
output (scipy.sparse.csr_array):
Masked sparse matrix.
"""
if do_safety_steps:
m = mask_sparse.copy()
m.eliminate_zeros()
else:
m = mask_sparse
return (m!=0).multiply(x)
[docs]
class scipy_sparse_csr_with_length(scipy.sparse.csr_array):
"""
A scipy sparse array with a **length** attribute.
RH 2023
Attributes:
length (int):
The length of the array (shape[0])
Args:
*args (object):
Arbitrary arguments passed to scipy.sparse.csr_array.
**kwargs (object):
Arbitrary keyword arguments passed to scipy.sparse.csr_array.
"""
def __init__(
self,
*args: object,
**kwargs: object
):
"""
Initializes the scipy_sparse_csr_with_length with the given arguments and keyword arguments.
"""
super().__init__(*args, **kwargs)
self.length = self.shape[0]
def __len__(self):
return self.length
def __getitem__(self, key):
return self.__class__(super().__getitem__(key))
[docs]
class lazy_repeat_obj():
"""
Makes a lazy iterator that repeats an object.
RH 2021
Args:
obj (Any):
Object to repeat.
pseudo_length (Optional[int]):
Length of the iterator. (Default is ``None``).
"""
def __init__(
self,
obj: Any,
pseudo_length: Optional[int] = None,
):
"""
Initializes the lazy iterator.
"""
self.obj = obj
self.pseudo_length = pseudo_length
def __getitem__(self, i: int) -> Any:
"""
Get item at index `i`. Always returns the repeated object, unless index
is out of bounds.
Args:
i (int):
Index of item to return. Ignored if pseudo_length is None.
Returns:
Any: The repeated object.
Raises:
IndexError: If `i` is out of bounds.
"""
if self.pseudo_length is None:
return self.obj
elif i < self.pseudo_length:
return self.obj
else:
raise IndexError('Index out of bounds')
def __len__(self) -> int:
"""
Get the length of the iterator.
Returns:
int or None:
The length of the iterator.
"""
return self.pseudo_length
def __repr__(self):
return repr(self.item)
[docs]
def find_nonredundant_idx(
s: scipy.sparse.coo_array,
) -> np.ndarray:
"""
Finds the indices of the nonredundant entries in a sparse matrix. Useful
when manually populating a sparse matrix and you want to know which entries
have already been populated.
RH 2022
Args:
s (scipy.sparse.coo_array):
Sparse matrix. Should be in COO format.
Returns:
(np.ndarray):
idx_unique (np.ndarray):
Indices of the nonredundant entries.
"""
if s.format != 'coo':
s = s.tocoo()
idx_rowCol = np.vstack((s.row, s.col)).T
u, idx_u = np.unique(idx_rowCol, axis=0, return_index=True)
return idx_u
[docs]
def remove_redundant_elements(
s: scipy.sparse.coo_array,
inPlace: bool = False,
) -> scipy.sparse.coo_array:
"""
Removes redundant entries from a sparse matrix. Useful when manually
populating a sparse matrix and you want to remove redundant entries.
RH 2022
Args:
s (scipy.sparse.coo_array):
Sparse matrix. Should be in COO format.
inPlace (bool):
* If ``True``, the input matrix is modified in place.
* If ``False``, a new matrix is returned. \n
(Default is ``False``)
Returns:
(scipy.sparse.coo_array):
s (scipy.sparse.coo_array):
Sparse matrix with redundant entries removed.
"""
idx_nonRed = find_nonredundant_idx(s)
if inPlace:
s_out = s
else:
s_out = scipy.sparse.coo_array(s.shape)
s_out.row = s.row[idx_nonRed]
s_out.col = s.col[idx_nonRed]
s_out.data = s.data[idx_nonRed]
return s_out
[docs]
def merge_sparse_arrays(
s_list: List[scipy.sparse.csr_array],
idx_list: List[np.ndarray],
shape_full: Tuple[int, int],
remove_redundant: bool = True,
elim_zeros: bool = True
) -> scipy.sparse.csr_array:
"""
Merges a list of square sparse arrays into a single square sparse array.
Redundant entries are not selected; only entries chosen by np.unique are kept.
Args:
s_list (List[scipy.sparse.csr_array]):
List of sparse arrays to merge. Each array can be any shape.
idx_list (List[np.ndarray]):
List of integer arrays. Each array should be the same length as its
corresponding array in s_list and contain integers in the range [0,
shape_full[0]). These integers represent the row/column indices in
the full array.
shape_full (Tuple[int, int]):
Shape of the full array.
remove_redundant (bool):
* ``True``: Removes redundant entries from the output array.
* ``False``: Keeps redundant entries.
elim_zeros (bool):
* ``True``: Eliminate zeros in the sparse matrix.
* ``False``: Keeps zeros.
Returns:
scipy.sparse.csr_array:
s_full (scipy.sparse.csr_array):
Full sparse matrix merged from the input list.
"""
row, col, data = np.array([]), np.array([]), np.array([])
for s, idx in zip(s_list, idx_list):
s_i = s.tocsr() if s.format != 'csr' else s
s_i.eliminate_zeros() if elim_zeros else s_i
idx_grid = np.meshgrid(idx, idx)
row = np.concatenate([row, (s_i != 0).multiply(idx_grid[0]).data])
col = np.concatenate([col, (s_i != 0).multiply(idx_grid[1]).data])
data = np.concatenate([data, s_i.data])
s_full = scipy.sparse.coo_array((data, (row, col)), shape=shape_full)
if remove_redundant:
remove_redundant_elements(s_full, inPlace=True)
return s_full
[docs]
def scipy_sparse_to_torch_coo(
sp_array: scipy.sparse.coo_array,
dtype: Optional[type] = None
) -> torch.sparse_coo_tensor:
"""
Converts a Scipy sparse array to a PyTorch sparse COO tensor.
Args:
sp_array (scipy.sparse.coo_array):
Scipy sparse array to be converted to a PyTorch sparse COO tensor.
dtype (Optional[type]):
Data type to which the values of the input sparse array are to be
converted before creating the PyTorch sparse tensor. If ``None``,
the data type of the input array's values is retained.
(Default is ``None``).
Returns:
coo_tensor (torch.sparse_coo_tensor):
PyTorch sparse COO tensor converted from the input Scipy sparse array.
"""
import torch
coo = scipy.sparse.coo_array(sp_array)
values = coo.data
# print(values.dtype)
indices = np.vstack((coo.row, coo.col))
i = torch.as_tensor(indices, dtype=torch.int64)
# v = torch.FloatTensor(values)
v = torch.as_tensor(values, dtype=dtype) if dtype is not None else values
shape = coo.shape
return torch.sparse_coo_tensor(i, v, torch.Size(shape))
[docs]
def pydata_sparse_to_torch_coo(
sp_array: object,
) -> object:
"""
Converts a PyData Sparse array to a PyTorch sparse COO tensor.
This function extracts the coordinates and data from the sparse PyData array
and uses them to create a new sparse COO tensor in PyTorch.
Args:
sp_array (object):
The PyData Sparse array to convert. It should be a COO sparse matrix
representation.
Returns:
(object):
coo_tensor (object):
The converted PyTorch sparse COO tensor.
Example:
.. highlight:: python
.. code-block:: python
sp_array = sparse.COO(np.random.rand(1000, 1000))
coo_tensor = pydata_sparse_to_torch_coo(sp_array)
"""
coo = sparse.COO(sp_array)
values = coo.data
# indices = np.vstack((coo.row, coo.col))
indices = coo.coords
i = torch.as_tensor(indices, dtype=torch.int64)
v = torch.as_tensor(values, dtype=torch.float32)
shape = coo.shape
return torch.sparse_coo_tensor(i, v, torch.Size(shape))
[docs]
def index_with_nans(values, indices):
"""
Indexes an array with a list of indices, allowing for NaNs in the indices.
RH 2022
Args:
values (np.ndarray):
Array to be indexed.
indices (Union[List[int], np.ndarray]):
1D list or array of indices to use for indexing. Can contain NaNs.
Datatype should be floating point. NaNs will be removed and values
will be cast to int.
Returns:
np.ndarray:
Indexed array. Positions where `indices` was NaN will be filled with
NaNs.
"""
## Warn if input dtype is not NaN compatible
if not np.issubdtype(indices.dtype, np.floating):
raise ValueError('Input indices should be floating point because NaNs are used for masking. Convert to float if necessary.')
indices = np.array(indices, dtype=float) if not isinstance(indices, np.ndarray) else indices
values = np.concatenate((np.full(shape=values.shape[1:], fill_value=np.nan, dtype=values.dtype)[None,...], values), axis=0)
idx = indices.copy() + 1
idx[np.isnan(idx)] = 0
return values[idx.astype(np.int64)]
######################################################################################################################################
######################################################## FILE HELPERS ################################################################
######################################################################################################################################
[docs]
def get_nums_from_string(string_with_nums):
"""
Return the numbers from a string as an int
RH 2021-2022
Args:
string_with_nums (str):
String with numbers in it
Returns:
nums (int):
The numbers from the string
If there are no numbers, return None.
"""
_digits = set('0123456789')
idx_nums = [ch in _digits for ch in string_with_nums]
nums = []
for jj, val in enumerate(idx_nums):
if val:
nums.append(string_with_nums[jj])
if not nums:
return None
nums = int(''.join(nums))
return nums
[docs]
def find_paths(
dir_outer: Union[str, List[str]],
reMatch: str = 'filename',
reMatch_in_path: Optional[str] = None,
find_files: bool = True,
find_folders: bool = False,
depth: int = 0,
natsorted: bool = True,
alg_ns: Optional[str] = None,
verbose: bool = False,
) -> List[str]:
"""
Searches for files and/or folders recursively in a directory using a regex
match.
RH 2022
Args:
dir_outer (Union[str, List[str]]):
Path(s) to the directory(ies) to search. If a list of directories,
then all directories will be searched.
reMatch (str):
Regular expression to match. Each file or folder name encountered
will be compared using ``re.search(reMatch, filename)``. If the
output is not ``None``, the file will be included in the output.
reMatch_in_path (Optional[str]):
Additional regular expression to match anywhere in the path. Useful
for finding files/folders in specific subdirectories. If ``None``, then
no additional matching is done. \n
(Default is ``None``)
find_files (bool):
Whether to find files. (Default is ``True``)
find_folders (bool):
Whether to find folders. (Default is ``False``)
depth (int):
Maximum folder depth to search. (Default is *0*). \n
* depth=0 means only search the outer directory.
* depth=2 means search the outer directory and two levels of
subdirectories below it
natsorted (bool):
Whether to sort the output using natural sorting with the natsort
package. (Default is ``True``)
alg_ns (str):
Algorithm to use for natural sorting. See ``natsort.ns`` or
https://natsort.readthedocs.io/en/4.0.4/ns_class.html/ for options.
Default is PATH. Other commons are INT, FLOAT, VERSION. (Default is
``None``)
verbose (bool):
Whether to print the paths found. (Default is ``False``)
Returns:
(List[str]):
paths (List[str]):
Paths to matched files and/or folders in the directory.
"""
import natsort
if alg_ns is None:
alg_ns = natsort.ns.PATH
def fn_match(path, reMatch, reMatch_in_path):
# returns true if reMatch is basename and reMatch_in_path in full dirname
if reMatch is not None:
if re.search(reMatch, os.path.basename(path)) is None:
return False
if reMatch_in_path is not None:
if re.search(reMatch_in_path, os.path.dirname(path)) is None:
return False
return True
def get_paths_recursive_inner(dir_inner, depth_end, depth=0):
paths = []
for path in os.listdir(dir_inner):
path = os.path.join(dir_inner, path)
if os.path.isdir(path):
if find_folders:
if fn_match(path, reMatch, reMatch_in_path):
print(f'Found folder: {path}') if verbose else None
paths.append(path)
if depth < depth_end:
paths += get_paths_recursive_inner(path, depth_end, depth=depth+1)
else:
if find_files:
if fn_match(path, reMatch, reMatch_in_path):
print(f'Found file: {path}') if verbose else None
paths.append(path)
return paths
def fn_check_pathLike(obj):
if isinstance(obj, (
str,
Path,
os.PathLike,
np.str_,
bytes,
memoryview,
np.bytes_,
re.Pattern,
re.Match,
)):
return True
else:
return False
dir_outer = [dir_outer] if fn_check_pathLike(dir_outer) else dir_outer
paths = list(set(sum([get_paths_recursive_inner(str(d), depth, depth=0) for d in dir_outer], start=[])))
if natsorted:
paths = natsort.natsorted(paths, alg=alg_ns)
return paths
[docs]
def prepare_path(
path: str,
mkdir: bool = False,
exist_ok: bool = True,
) -> str:
"""
Checks if a directory or file path is valid for different purposes:
saving, loading, etc.
RH 2023
* If exists:
* If exist_ok=True: all good
* If exist_ok=False: raises error
* If doesn't exist:
* If file:
* If parent directory exists:
* All good
* If parent directory doesn't exist:
* If mkdir=True: creates parent directory
* If mkdir=False: raises error
* If directory:
* If mkdir=True: creates directory
* If mkdir=False: raises error
RH 2023
Args:
path (str):
Path to be checked.
mkdir (bool):
If ``True``, creates parent directory if it does not exist.
(Default is ``False``)
exist_ok (bool):
If ``True``, allows overwriting of existing file.
(Default is ``True``)
Returns:
(str):
path (str):
Resolved path.
"""
## check if path is valid
try:
path_obj = Path(path).resolve()
except FileNotFoundError as e:
print(f'Invalid path: {path}')
raise e
## check if path object exists
flag_exists = path_obj.exists()
## determine if path is a directory or file
if flag_exists:
flag_dirFileNeither = 'dir' if path_obj.is_dir() else 'file' if path_obj.is_file() else 'neither' ## 'neither' should never happen since path.is_file() or path.is_dir() should be True if path.exists()
assert flag_dirFileNeither != 'neither', f'Path: {path} is neither a file nor a directory.'
assert exist_ok, f'{path} already exists and exist_ok=False.'
else:
flag_dirFileNeither = 'dir' if path_obj.suffix == '' else 'file' ## rely on suffix to determine if path is a file or directory
## if path exists and is a file or directory
# all good. If exist_ok=False, then this should have already been caught above.
## if path doesn't exist and is a file
### if parent directory exists
# all good
### if parent directory doesn't exist
#### mkdir if mkdir=True and raise error if mkdir=False
if not flag_exists and flag_dirFileNeither == 'file':
if Path(path).parent.exists():
pass ## all good
elif mkdir:
Path(path).parent.mkdir(parents=True, exist_ok=True)
else:
assert False, f'File: {path} does not exist, Parent directory: {Path(path).parent} does not exist, and mkdir=False.'
## if path doesn't exist and is a directory
### mkdir if mkdir=True and raise error if mkdir=False
if not flag_exists and flag_dirFileNeither == 'dir':
if mkdir:
Path(path).mkdir(parents=True, exist_ok=True)
else:
assert False, f'{path} does not exist and mkdir=False.'
## if path is neither a file nor a directory
### raise error
if flag_dirFileNeither == 'neither':
assert False, f'{path} is neither a file nor a directory. This should never happen. Check this function for bugs.'
return str(path_obj)
[docs]
def prepare_filepath_for_saving(
filepath: str,
mkdir: bool = False,
allow_overwrite: bool = True
) -> str:
"""
Prepares a file path for saving a file. Ensures the file path is valid and has the necessary permissions.
Args:
filepath (str):
The file path to be prepared for saving.
mkdir (bool):
If set to ``True``, creates parent directory if it does not exist. (Default is ``False``)
allow_overwrite (bool):
If set to ``True``, allows overwriting of existing file. (Default is ``True``)
Returns:
(str):
path (str):
The prepared file path for saving.
"""
return prepare_path(filepath, mkdir=mkdir, exist_ok=allow_overwrite)
[docs]
def prepare_filepath_for_loading(
filepath: str,
must_exist: bool = True
) -> str:
"""
Prepares a file path for loading a file. Ensures the file path is valid and has the necessary permissions.
Args:
filepath (str):
The file path to be prepared for loading.
must_exist (bool):
If set to ``True``, the file at the specified path must exist. (Default is ``True``)
Returns:
(str):
path (str):
The prepared file path for loading.
"""
path = prepare_path(filepath, mkdir=False, exist_ok=must_exist)
if must_exist:
assert Path(path).is_file(), f'{path} is not a file.'
return path
[docs]
def prepare_directory_for_saving(
directory: str,
mkdir: bool = False,
exist_ok: bool = True
) -> str:
"""
Prepares a directory path for saving a file. This function is rarely used.
Args:
directory (str):
The directory path to be prepared for saving.
mkdir (bool):
If set to ``True``, creates parent directory if it does not exist. (Default is ``False``)
exist_ok (bool):
If set to ``True``, allows overwriting of existing directory. (Default is ``True``)
Returns:
(str):
path (str):
The prepared directory path for saving.
"""
return prepare_path(directory, mkdir=mkdir, exist_ok=exist_ok)
[docs]
def prepare_directory_for_loading(
directory: str,
must_exist: bool = True
) -> str:
"""
Prepares a directory path for loading a file. This function is rarely used.
Args:
directory (str):
The directory path to be prepared for loading.
must_exist (bool):
If set to ``True``, the directory at the specified path must exist. (Default is ``True``)
Returns:
(str):
path (str):
The prepared directory path for loading.
"""
path = prepare_path(directory, mkdir=False, exist_ok=must_exist)
if must_exist:
assert Path(path).is_dir(), f'{path} is not a directory.'
return path
[docs]
def pickle_save(
obj: Any,
filepath: str,
mode: str = 'wb',
zipCompress: bool = False,
mkdir: bool = False,
allow_overwrite: bool = True,
**kwargs_zipfile: Dict[str, Any],
) -> None:
"""
Saves an object to a pickle file using `pickle.dump`.
Allows for zipping of the file.
RH 2022
Args:
obj (Any):
The object to save.
filepath (str):
The path to save the object to.
mode (str):
The mode to open the file in. Options are: \n
* ``'wb'``: Write binary.
* ``'ab'``: Append binary.
* ``'xb'``: Exclusive write binary. Raises FileExistsError if the
file already exists. \n
(Default is ``'wb'``)
zipCompress (bool):
If ``True``, compresses pickle file using zipfileCompressionMethod,
which is similar to ``savez_compressed`` in numpy (with
``zipfile.ZIP_DEFLATED``). Useful for saving redundant and/or sparse
arrays objects. (Default is ``False``)
mkdir (bool):
If ``True``, creates parent directory if it does not exist. (Default
is ``False``)
allow_overwrite (bool):
If ``True``, allows overwriting of existing file. (Default is
``True``)
kwargs_zipfile (Dict[str, Any]):
Keyword arguments that will be passed into `zipfile.ZipFile`.
compression=``zipfile.ZIP_DEFLATED`` by default.
See https://docs.python.org/3/library/zipfile.html#zipfile-objects.
Other options for 'compression' are (input can be either int or object): \n
* ``0``: zipfile.ZIP_STORED (no compression)
* ``8``: zipfile.ZIP_DEFLATED (usual zip compression)
* ``12``: zipfile.ZIP_BZIP2 (bzip2 compression) (usually not as
good as ZIP_DEFLATED)
* ``14``: zipfile.ZIP_LZMA (lzma compression) (usually better
than ZIP_DEFLATED but slower)
"""
path = prepare_filepath_for_saving(filepath, mkdir=mkdir, allow_overwrite=allow_overwrite)
if len(kwargs_zipfile)==0:
kwargs_zipfile = {
'compression': zipfile.ZIP_DEFLATED,
}
if zipCompress:
with zipfile.ZipFile(path, 'w', **kwargs_zipfile) as f:
f.writestr('data', pickle.dumps(obj))
else:
with open(path, mode) as f:
pickle.dump(obj, f)
[docs]
def pickle_load(
filepath: str,
zipCompressed: bool = False,
mode: str = 'rb',
) -> Any:
"""
Loads an object from a pickle file.
RH 2022
Args:
filepath (str):
Path to the pickle file.
zipCompressed (bool):
If ``True``, the file is assumed to be a .zip file. The function
will first unzip the file, then load the object from the unzipped
file.
(Default is ``False``)
mode (str):
The mode to open the file in. (Default is ``'rb'``)
Returns:
(Any):
obj (Any):
The object loaded from the pickle file.
"""
path = prepare_filepath_for_loading(filepath, must_exist=True)
if zipCompressed:
with zipfile.ZipFile(path, 'r') as f:
return pickle.loads(f.read('data'))
else:
with open(path, mode) as f:
return pickle.load(f)
[docs]
def json_save(
obj: Any,
filepath: str,
indent: int = 4,
mode: str = 'w',
mkdir: bool = False,
allow_overwrite: bool = True,
) -> None:
"""
Saves an object to a json file using `json.dump`.
RH 2022
Args:
obj (Any):
The object to save.
filepath (str):
The path to save the object to.
indent (int):
Number of spaces for indentation in the output json file. (Default
is *4*)
mode (str):
The mode to open the file in. Options are: \n
* ``'wb'``: Write binary.
* ``'ab'``: Append binary.
* ``'xb'``: Exclusive write binary. Raises FileExistsError if the
file already exists. \n
(Default is ``'w'``)
mkdir (bool):
If ``True``, creates parent directory if it does not exist. (Default
is ``False``)
allow_overwrite (bool):
If ``True``, allows overwriting of existing file. (Default is
``True``)
"""
import json
path = prepare_filepath_for_saving(filepath, mkdir=mkdir, allow_overwrite=allow_overwrite)
with open(path, mode) as f:
json.dump(obj, f, indent=indent)
[docs]
def json_load(
filepath: str,
mode: str = 'r',
) -> Any:
"""
Loads an object from a json file.
RH 2022
Args:
filepath (str):
Path to the json file.
mode (str):
The mode to open the file in. (Default is ``'r'``)
Returns:
(Any):
obj (Any):
The object loaded from the json file.
"""
import json
path = prepare_filepath_for_loading(filepath, must_exist=True)
with open(path, mode) as f:
return json.load(f)
[docs]
def yaml_save(
obj: object,
filepath: str,
indent: int = 4,
mode: str = 'w',
mkdir: bool = False,
allow_overwrite: bool = True,
) -> None:
"""
Saves an object to a YAML file using the ``yaml.dump`` method.
RH 2022
Args:
obj (object):
The object to be saved.
filepath (str):
Path to save the object to.
indent (int):
The number of spaces for indentation in the saved YAML file.
(Default is *4*)
mode (str):
Mode to open the file in. \n
* ``'w'``: write (default)
* ``'wb'``: write binary
* ``'ab'``: append binary
* ``'xb'``: exclusive write binary. Raises ``FileExistsError`` if
file already exists. \n
(Default is ``'w'``)
mkdir (bool):
If ``True``, creates the parent directory if it does not exist.
(Default is ``False``)
allow_overwrite (bool):
If ``True``, allows overwriting of existing files. (Default is
``True``)
"""
path = prepare_filepath_for_saving(filepath, mkdir=mkdir, allow_overwrite=allow_overwrite)
with open(path, mode) as f:
yaml.dump(obj, f, indent=indent, sort_keys=False)
[docs]
def yaml_load(
filepath: str,
mode: str = 'r',
loader: object = yaml.FullLoader,
) -> object:
"""
Loads a YAML file.
RH 2022
Args:
filepath (str):
Path to the YAML file to load.
mode (str):
Mode to open the file in. (Default is ``'r'``)
loader (object):
The YAML loader to use. \n
* ``yaml.FullLoader``: Loads the full YAML language. Avoids
arbitrary code execution. (Default for PyYAML 5.1+)
* ``yaml.SafeLoader``: Loads a subset of the YAML language, safely.
This is recommended for loading untrusted input.
* ``yaml.UnsafeLoader``: The original Loader code that could be
easily exploitable by untrusted data input.
* ``yaml.BaseLoader``: Only loads the most basic YAML. All scalars
are loaded as strings. \n
(Default is ``yaml.FullLoader``)
Returns:
(object):
loaded_obj (object):
The object loaded from the YAML file.
"""
path = prepare_filepath_for_loading(filepath, must_exist=True)
with open(path, mode) as f:
return yaml.load(f, Loader=loader)
[docs]
def matlab_load(
filepath: str,
simplify_cells: bool = True,
kwargs_scipy: Dict = {},
kwargs_mat73: Dict = {},
verbose: bool = False
) -> Dict:
"""
Loads a matlab file. If the .mat file is not version 7.3, it uses
``scipy.io.loadmat``. If the .mat file is version 7.3, it uses
``mat73.loadmat``. RH 2023
Args:
filepath (str):
Path to the matlab file.
simplify_cells (bool):
If set to ``True`` and file is not version 7.3, it simplifies cells
to numpy arrays. (Default is ``True``)
kwargs_scipy (Dict):
Keyword arguments to pass to ``scipy.io.loadmat``. (Default is
``{}``)
kwargs_mat73 (Dict):
Keyword arguments to pass to ``mat73.loadmat``. (Default is ``{}``)
verbose (bool):
If set to ``True``, it prints information about the file. (Default
is ``False``)
Returns:
(Dict):
out (Dict):
The loaded matlab file content in a dictionary format.
"""
path = prepare_filepath_for_loading(filepath, must_exist=True)
assert path.endswith('.mat'), 'File must be .mat file.'
try:
import scipy.io
out = scipy.io.loadmat(path, simplify_cells=simplify_cells, **kwargs_scipy)
except NotImplementedError as e:
print(f'File {path} is version 7.3. Loading with mat73.') if verbose else None
import mat73
out = mat73.loadmat(path, **kwargs_mat73)
print(f'Loaded {path} with mat73.') if verbose else None
return out
[docs]
def matlab_save(
obj: Dict,
filepath: str,
mkdir: bool = False,
allow_overwrite: bool = True,
clean_string: bool = True,
list_to_objArray: bool = True,
none_to_nan: bool = True,
kwargs_scipy_savemat: Dict = {
'appendmat': True,
'format': '5',
'long_field_names': False,
'do_compression': False,
'oned_as': 'row',
}
):
"""
Saves data to a matlab file. It uses ``scipy.io.savemat`` and provides
additional functionality such as cleaning strings, converting lists to
object arrays, and converting None to np.nan.
RH 2023
Args:
obj (Dict):
The data to save. This must be in dictionary format.
filepath (str):
The path to save the file to.
mkdir (bool):
If set to ``True``, creates parent directory if it does not exist.
(Default is ``False``)
allow_overwrite (bool):
If set to ``True``, allows overwriting of existing file. (Default is
``True``)
clean_string (bool):
If set to ``True``, converts strings to bytes. (Default is ``True``)
list_to_objArray (bool):
If set to ``True``, converts lists to object arrays. (Default is
``True``)
none_to_nan (bool):
If set to ``True``, converts None to np.nan. (Default is ``True``)
kwargs_scipy_savemat (Dict):
Keyword arguments to pass to ``scipy.io.savemat``. \n
* ``'appendmat'``: Whether to append .mat to the end of the given
filename, if it isn't already there.
* ``'format'``: The format of the .mat file. '4' for Matlab 4 .mat
files, '5' for Matlab 5 and above.
* ``'long_field_names'``: Whether to allow field names of up to 63
characters instead of the standard 31.
* ``'do_compression'``: Whether to compress matrices on write.
* ``'oned_as'``: Whether to save 1-D numpy arrays as row or column
vectors in the .mat file. 'row' or 'column'. \n
(Default is ``{'appendmat': True, 'format': '5', 'long_field_names':
False, 'do_compression': False, 'oned_as': 'row'}``)
"""
import numpy as np
prepare_filepath_for_saving(filepath, mkdir=mkdir, allow_overwrite=allow_overwrite)
def walk(d, fn):
return {key: fn(val) if isinstance(val, dict)==False else walk(val, fn) for key, val in d.items()}
fn_clean_string = (lambda x: x.encode('utf-8') if isinstance(x, str) and clean_string else x) if clean_string else (lambda x: x)
fn_list_to_objArray = (lambda x: np.array(x, dtype=object) if isinstance(x, list) and list_to_objArray else x) if list_to_objArray else (lambda x: x)
fn_none_to_nan = (lambda x: np.nan if x is None and none_to_nan else x) if none_to_nan else (lambda x: x)
data_cleaned = walk(walk(walk(obj, fn_clean_string), fn_list_to_objArray), fn_none_to_nan)
import scipy.io
scipy.io.savemat(filepath, data_cleaned, **kwargs_scipy_savemat)
[docs]
def download_file(
url: Optional[str],
path_save: str,
check_local_first: bool = True,
check_hash: bool = False,
hash_type: str = 'MD5',
hash_hex: Optional[str] = None,
mkdir: bool = False,
allow_overwrite: bool = True,
write_mode: str = 'wb',
verbose: bool = True,
chunk_size: int = 1024,
) -> None:
"""
Downloads a file from a URL to a local path using requests. Checks if file
already exists locally and verifies the hash of the downloaded file against
a provided hash if required.
RH 2023
Args:
url (Optional[str]):
URL of the file to download. If ``None``, then no download is
attempted. (Default is ``None``)
path_save (str):
Path to save the file to.
check_local_first (bool):
Whether to check if the file already exists locally. If ``True`` and
the file exists locally, the download is skipped. If ``True`` and
``check_hash`` is also ``True``, the hash of the local file is
checked. If the hash matches, the download is skipped. If the hash
does not match, the file is downloaded. (Default is ``True``)
check_hash (bool):
Whether to check the hash of the local or downloaded file against
``hash_hex``. (Default is ``False``)
hash_type (str):
Type of hash to use. Options are: ``'MD5'``, ``'SHA1'``,
``'SHA256'``, ``'SHA512'``. (Default is ``'MD5'``)
hash_hex (Optional[str]):
Hash to compare to, in hexadecimal format (e.g., 'a1b2c3d4e5f6...').
Can be generated using ``hash_file()`` or ``hashlib.hexdigest()``.
If ``check_hash`` is ``True``, ``hash_hex`` must be provided.
(Default is ``None``)
mkdir (bool):
If ``True``, creates the parent directory of ``path_save`` if it
does not exist. (Default is ``False``)
write_mode (str):
Write mode for saving the file. Options include: ``'wb'`` (write
binary), ``'ab'`` (append binary), ``'xb'`` (write binary, fail if
file exists). (Default is ``'wb'``)
verbose (bool):
If ``True``, prints status messages. (Default is ``True``)
chunk_size (int):
Size of chunks in which to download the file. (Default is *1024*)
"""
import os
import requests
# Check if file already exists locally
if check_local_first:
if os.path.isfile(path_save):
print(f'File already exists locally: {path_save}') if verbose else None
# Check hash of local file
if check_hash:
hash_local = hash_file(path_save, type_hash=hash_type)
if hash_local == hash_hex:
print('Hash of local file matches provided hash_hex.') if verbose else None
return True
else:
print('Hash of local file does not match provided hash_hex.') if verbose else None
print(f'Hash of local file: {hash_local}') if verbose else None
print(f'Hash provided in hash_hex: {hash_hex}') if verbose else None
print('Downloading file...') if verbose else None
else:
return True
else:
print(f'File does not exist locally: {path_save}. Will attempt download from {url}') if verbose else None
# Download file
if url is None:
print('No URL provided. No download attempted.') if verbose else None
return None
## Retry with exponential backoff for transient HTTP errors (403 rate
## limits from OSF, 429, 5xx). Max ~30s total wait.
import time
max_retries = 4
retryable_codes = {403, 429, 500, 502, 503, 504}
response = None
for attempt in range(max_retries + 1):
try:
response = requests.get(url, stream=True)
except requests.exceptions.RequestException as e:
print(f'Error downloading file (attempt {attempt + 1}/{max_retries + 1}): {e}') if verbose else None
if attempt < max_retries:
wait = 2 ** attempt
print(f'Retrying in {wait}s...') if verbose else None
time.sleep(wait)
continue
return False
if response.status_code == 200:
break
if response.status_code in retryable_codes and attempt < max_retries:
wait = 2 ** attempt
print(f'Error downloading file. Response status code: {response.status_code}. Retrying in {wait}s (attempt {attempt + 1}/{max_retries + 1})...') if verbose else None
time.sleep(wait)
continue
## Non-retryable error or out of retries
print(f'Error downloading file. Response status code: {response.status_code}') if verbose else None
return False
# Create parent directory if it does not exist
prepare_filepath_for_saving(path_save, mkdir=mkdir, allow_overwrite=allow_overwrite)
# Download file with progress bar
total_size = int(response.headers.get('content-length', 0))
wrote = 0
with open(path_save, write_mode) as f:
with tqdm(total=total_size, disable=(verbose==False), unit='B', unit_scale=True, unit_divisor=1024) as pbar:
for data in response.iter_content(chunk_size):
wrote = wrote + len(data)
f.write(data)
pbar.update(len(data))
if total_size != 0 and wrote != total_size:
print("ERROR, something went wrong")
return False
# Check hash
hash_local = hash_file(path_save, type_hash=hash_type)
if check_hash:
if hash_local == hash_hex:
print('Hash of downloaded file matches hash_hex.') if verbose else None
return True
else:
print('Hash of downloaded file does not match hash_hex.') if verbose else None
print(f'Hash of downloaded file: {hash_local}') if verbose else None
print(f'Hash provided in hash_hex: {hash_hex}') if verbose else None
return False
else:
print(f'Hash of downloaded file: {hash_local}') if verbose else None
return True
[docs]
def hash_file(
path: str,
type_hash: str = 'MD5',
buffer_size: int = 65536,
) -> str:
"""
Computes the hash of a file using the specified hash type and buffer size.
RH 2022
Args:
path (str):
Path to the file to be hashed.
type_hash (str):
Type of hash to use. (Default is ``'MD5'``). Either \n
* ``'MD5'``: MD5 hash algorithm.
* ``'SHA1'``: SHA1 hash algorithm.
* ``'SHA256'``: SHA256 hash algorithm.
* ``'SHA512'``: SHA512 hash algorithm.
buffer_size (int):
Buffer size (in bytes) for reading the file.
65536 corresponds to 64KB. (Default is *65536*)
Returns:
(str):
hash_val (str):
The computed hash of the file.
"""
import hashlib
if type_hash == 'MD5':
hasher = hashlib.md5()
elif type_hash == 'SHA1':
hasher = hashlib.sha1()
elif type_hash == 'SHA256':
hasher = hashlib.sha256()
elif type_hash == 'SHA512':
hasher = hashlib.sha512()
else:
raise ValueError(f'{type_hash} is not a valid hash type.')
with open(path, 'rb') as f:
while True:
data = f.read(buffer_size)
if not data:
break
hasher.update(data)
hash_val = hasher.hexdigest()
return hash_val
[docs]
def get_dir_contents(
directory: str,
) -> Tuple[List[str], List[str]]:
"""
Retrieves the names of the folders and files in a directory (does not
include subdirectories).
RH 2021
Args:
directory (str):
The path to the directory.
Returns:
(tuple): tuple containing:
folders (List[str]):
A list of folder names.
files (List[str]):
A list of file names.
"""
walk = os.walk(directory, followlinks=False)
folders = []
files = []
for ii,level in enumerate(walk):
folders, files = level[1:]
if ii==0:
break
return folders, files
[docs]
def compare_file_hashes(
hash_dict_true: Dict[str, Tuple[str, str]],
dir_files_test: Optional[str] = None,
paths_files_test: Optional[List[str]] = None,
verbose: bool = True,
) -> Tuple[bool, Dict[str, bool], Dict[str, str]]:
"""
Compares hashes of files in a directory or list of paths to provided hashes.
RH 2022
Args:
hash_dict_true (Dict[str, Tuple[str, str]]):
Dictionary of hashes to compare. Each entry should be in the format:
*{'key': ('filename', 'hash')}*.
dir_files_test (str):
Path to directory containing the files to compare hashes.
Unused if paths_files_test is not ``None``. (Optional)
paths_files_test (List[str]):
List of paths to files to compare hashes.
dir_files_test is used if ``None``. (Optional)
verbose (bool):
If ``True``, failed comparisons are printed out. (Default is ``True``)
Returns:
(tuple): tuple containing:
total_result (bool):
``True`` if all hashes match, ``False`` otherwise.
individual_results (Dict[str, bool]):
Dictionary indicating whether each hash matched.
paths_matching (Dict[str, str]):
Dictionary of paths that matched. Each entry is in the format:
*{'key': 'path'}*.
"""
if paths_files_test is None:
if dir_files_test is None:
raise ValueError('Must provide either dir_files_test or path_files_test.')
## make a dict of {filename: path} for each file in dir_files_test
files_test = {filename: (Path(dir_files_test).resolve() / filename).as_posix() for filename in get_dir_contents(dir_files_test)[1]}
else:
files_test = {Path(path).name: path for path in paths_files_test}
paths_matching = {}
results_matching = {}
for key, (filename, hash_true) in hash_dict_true.items():
match = True
if filename not in files_test:
print(f'{filename} not found in test directory: {dir_files_test}.') if verbose else None
match = False
elif hash_true != hash_file(files_test[filename]):
print(f'{filename} hash mismatch with {key, filename}.') if verbose else None
match = False
if match:
paths_matching[key] = files_test[filename]
results_matching[key] = match
return all(results_matching.values()), results_matching, paths_matching
######################################################################################################################################
###################################################### PLOTTING HELPERS ##############################################################
######################################################################################################################################
[docs]
def plot_image_grid(
images: Union[List[np.ndarray], np.ndarray],
labels: Optional[List[str]] = None,
grid_shape: Tuple[int, int] = (10,10),
show_axis: str = 'off',
cmap: Optional[str] = None,
kwargs_subplots: Dict = {},
kwargs_imshow: Dict = {},
) -> Tuple[plt.Figure, Union[np.ndarray, plt.Axes]]:
"""
Plots a grid of images.
RH 2021
Args:
images (Union[List[np.ndarray], np.ndarray]):
A list of images or a 3D array of images, where the first dimension is the number of images.
labels (Optional[List[str]]):
A list of labels to be displayed in the grid. (Default is ``None``)
grid_shape (Tuple[int, int]):
Shape of the grid. (Default is *(10,10)*)
show_axis (str):
Whether to show axes or not. (Default is 'off')
cmap (Optional[str]):
Colormap to use. (Default is ``None``)
kwargs_subplots (Dict):
Keyword arguments for subplots. (Default is {})
kwargs_imshow (Dict):
Keyword arguments for imshow. (Default is {})
Returns:
(Tuple[plt.Figure, Union[np.ndarray, plt.Axes]]): tuple containing:
fig (plt.Figure):
Figure object.
axs (Union[np.ndarray, plt.Axes]):
Axes object.
"""
if cmap is None:
cmap = 'viridis'
fig, axs = plt.subplots(nrows=grid_shape[0], ncols=grid_shape[1], **kwargs_subplots)
axs_flat = axs.flatten(order='F') if isinstance(axs, np.ndarray) else [axs]
for ii, ax in enumerate(axs_flat[:len(images)]):
ax.imshow(images[ii], cmap=cmap, **kwargs_imshow);
if labels is not None:
ax.set_title(labels[ii]);
ax.axis(show_axis);
return fig, axs
[docs]
def rand_cmap(
nlabels: int,
first_color_black: bool = False,
last_color_black: bool = False,
verbose: bool = True,
under: List[float] = [0,0,0],
over: List[float] = [0.5,0.5,0.5],
bad: List[float] = [0.9,0.9,0.9],
) -> object:
"""
Creates a random colormap to be used with matplotlib. Useful for segmentation tasks.
Args:
nlabels (int):
Number of labels (size of colormap).
first_color_black (bool):
Option to use the first color as black. (Default is ``False``)
last_color_black (bool):
Option to use the last color as black. (Default is ``False``)
verbose (bool):
Prints the number of labels and shows the colormap if ``True``.
(Default is ``True``)
under (List[float]):
RGB values to use for the 'under' threshold in the colormap.
(Default is ``[0, 0, 0]``)
over (List[float]):
RGB values to use for the 'over' threshold in the colormap.
(Default is ``[0.5, 0.5, 0.5]``)
bad (List[float]):
RGB values to use for 'bad' values in the colormap.
(Default is ``[0.9, 0.9, 0.9]``)
Returns:
(LinearSegmentedColormap):
colormap (LinearSegmentedColormap):
Colormap for matplotlib.
"""
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.pyplot as plt
import numpy as np
assert nlabels > 0, 'Number of labels must be greater than 0'
if verbose:
print('Number of labels: ' + str(nlabels))
randRGBcolors = np.random.rand(nlabels, 3)
randRGBcolors = randRGBcolors / np.max(randRGBcolors, axis=1, keepdims=True)
if first_color_black:
randRGBcolors[0] = [0, 0, 0]
if last_color_black:
randRGBcolors[-1] = [0, 0, 0]
random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels)
# Display colorbar
if verbose:
from matplotlib import colors, colorbar
fig, ax = plt.subplots(1, 1, figsize=(6, 0.5))
bounds = np.linspace(0, nlabels, nlabels + 1)
norm = colors.BoundaryNorm(bounds, nlabels)
cb = colorbar.ColorbarBase(ax, cmap=random_colormap, norm=norm, spacing='proportional', ticks=None,
boundaries=bounds, format='%1i', orientation=u'horizontal')
random_colormap.set_bad(bad)
random_colormap.set_over(over)
random_colormap.set_under(under)
return random_colormap
[docs]
def simple_cmap(
colors: List[List[float]] = [
[1,0,0],
[1,0.6,0],
[0.9,0.9,0],
[0.6,1,0],
[0,1,0],
[0,1,0.6],
[0,0.8,0.8],
[0,0.6,1],
[0,0,1],
[0.6,0,1],
[0.8,0,0.8],
[1,0,0.6],
],
under: List[float] = [0,0,0],
over: List[float] = [0.5,0.5,0.5],
bad: List[float] = [0.9,0.9,0.9],
name: str = 'none',
) -> object:
"""
Creates a colormap from a sequence of RGB values.
Borrowed with permission from Alex
(https://gist.github.com/ahwillia/3e022cdd1fe82627cbf1f2e9e2ad80a7ex)
Args:
colors (List[List[float]]):
List of RGB values. Each sub-list contains three float numbers
representing an RGB color. (Default is list of RGB colors ranging from red to purple)
under (List[float]):
RGB values for the colormap under range.
(Default is ``[0,0,0]`` (black))
over (List[float]):
RGB values for the colormap over range.
(Default is ``[0.5,0.5,0.5]`` (grey))
bad (List[float]):
RGB values for the colormap bad range.
(Default is ``[0.9,0.9,0.9]`` (light grey))
name (str):
Name of the colormap. (Default is 'none')
Returns:
(LinearSegmentedColormap):
cmap (LinearSegmentedColormap):
The generated colormap.
Example:
.. highlight:: python
.. code-block:: python
cmap = simple_cmap([(1,1,1), (1,0,0)]) # white to red colormap
cmap = simple_cmap(['w', 'r']) # white to red colormap
cmap = simple_cmap(['r', 'b', 'r']) # red to blue to red
"""
from matplotlib.colors import LinearSegmentedColormap, colorConverter
# check inputs
n_colors = len(colors)
if n_colors <= 1:
raise ValueError('Must specify at least two colors')
# convert colors to rgb
colors = [colorConverter.to_rgb(c) for c in colors]
# set up colormap
r, g, b = colors[0]
cdict = {'red': [(0.0, r, r)], 'green': [(0.0, g, g)], 'blue': [(0.0, b, b)]}
for i, (r, g, b) in enumerate(colors[1:]):
idx = (i+1) / (n_colors-1)
cdict['red'].append((idx, r, r))
cdict['green'].append((idx, g, g))
cdict['blue'].append((idx, b, b))
cmap = LinearSegmentedColormap(name, {k: tuple(v) for k, v in cdict.items()})
cmap.set_bad(bad)
cmap.set_over(over)
cmap.set_under(under)
return cmap
[docs]
class ImageLabeler:
"""
A simple graphical interface for labeling image classes. Use this class with
a context manager to ensure the window is closed properly. The class
provides a tkinter window which displays images from a provided numpy array
one by one and lets you classify each image by pressing a key. The title of
the window is the image index. The classification label and image index are
stored as the ``self.labels_`` attribute and saved to a CSV file in
self.path_csv.
RH 2023
Args:
image_array (np.ndarray):
A numpy array of images. Either 3D: *(n_images, height, width)* or
4D: *(n_images, height, width, n_channels)*. Images should be scaled
between 0 and 255 and will be converted to uint8.
start_index (int):
The index of the first image to display. (Default is *0*)
path_csv (str):
Path to the CSV file for saving results. If ``None``, results will
not be saved.
save_csv (bool):
Whether to save the results to a CSV. (Default is ``True``)
resize_factor (float):
A scaling factor for the fractional change in image size. (Default
is *1.0*)
normalize_images (bool):
Whether to normalize the images between min and max values. (Default
is ``True``)
verbose (bool):
Whether to print status updates. (Default is ``True``)
key_end (str):
Key to press to end the session. (Default is ``'Escape'``)
key_prev (str):
Key to press to go back to the previous image. (Default is
``'Left'``)
key_next (str):
Key to press to go to the next image. (Default is ``'Right'``)
Example:
.. highlight:: python
.. code-block:: python
with ImageLabeler(images, start_index=0, resize_factor=4.0,
key_end='Escape') as labeler:
labeler.run()
path_csv, labels = labeler.path_csv, labeler.labels_
Attributes:
image_array (np.ndarray):
A numpy array of images. Either 3D: *(n_images, height, width)* or
4D: *(n_images, height, width, n_channels)*. Images should be scaled
between 0 and 255 and will be converted to uint8.
start_index (int):
The index of the first image to display. (Default is *0*)
path_csv (str):
Path to the CSV file for saving results. If ``None``, results will
not be saved.
save_csv (bool):
Whether to save the results to a CSV. (Default is ``True``)
resize_factor (float):
A scaling factor for the fractional change in image size. (Default
is *1.0*)
normalize_images (bool):
Whether to normalize the images between min and max values. (Default
is ``True``)
verbose (bool):
Whether to print status updates. (Default is ``True``)
key_end (str):
Key to press to end the session. (Default is ``'Escape'``)
key_prev (str):
Key to press to go back to the previous image. (Default is
``'Left'``)
key_next (str):
Key to press to go to the next image. (Default is ``'Right'``)
labels_ (list):
A list of tuples containing the image index and classification label
for each image. The list is saved to a CSV file in self.path_csv.
"""
def __init__(
self,
image_array: np.ndarray,
start_index: int=0,
path_csv: Optional[str] = None,
save_csv: bool = True,
resize_factor: float = 10.0,
normalize_images: bool = True,
verbose: bool = True,
key_end: str = 'Escape',
key_prev: str = 'Left',
key_next: str = 'Right',
) -> None:
"""
Initializes the ImageLabeler with the given image array, csv path, and UI
elements. Binds keys for classifying images and ending the session.
"""
import tempfile
import datetime
## Set attributes
self.images = image_array
self._resize_factor = resize_factor
self._index = start_index - 1 ## -1 because we increment before displaying
self.path_csv = path_csv if path_csv is not None else str(Path(tempfile.gettempdir()) / ('roicat_labels_' + datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + '.csv'))
self._save_csv = save_csv
self.labels_ = {}
self._img_tk = None
self._key_end = key_end if key_end is not None else None
self._key_prev = key_prev if key_prev is not None else None
self._key_next = key_next if key_next is not None else None
self._normalize_images = normalize_images
self._verbose = verbose
self.__call__ = self.run
[docs]
def run(self):
"""
Runs the image labeler. Opens a tkinter window and displays the first
image.
"""
try:
self._root = tk.Tk()
self._img_label = tk.Label(self._root)
self._img_label.pack()
## Bind keys
self._root.bind("<Key>", self.classify)
self._root.bind('<Key-' + self._key_end + '>', self.end_session) if self._key_end is not None else None
self._root.bind('<Key-' + self._key_prev + '>', self.prev_img) if self._key_prev is not None else None
self._root.bind('<Key-' + self._key_next + '>', self.next_img) if self._key_next is not None else None
self._root.protocol("WM_DELETE_WINDOW", self._on_closing)
## Start the session
self.next_img()
self._root.mainloop()
except Exception as e:
warnings.warn('Error initializing image labeler: ' + str(e))
def _on_closing(self):
from tkinter import messagebox
if messagebox.askokcancel("Quit", "Do you want to quit?"):
self.end_session(None)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.end_session(None)
[docs]
def next_img(self, event=None):
"""Displays the next image in the array, and resizes the image."""
## Display the image
### End the session if there are no more images
self._index += 1
if self._index < len(self.images):
im = self.images[self._index]
im = (im / np.max(im)) * 255 if self._normalize_images else im
pil_img = PIL.Image.fromarray(np.uint8(im)) ## Convert to uint8 and PIL image
## Resize image
width, height = pil_img.size
new_width = int(width * self._resize_factor)
new_height = int(height * self._resize_factor)
pil_img = pil_img.resize((new_width, new_height), resample=PIL.Image.Resampling.LANCZOS)
## Display image
self._img_tk = ImageTk.PhotoImage(pil_img)
self._img_label.image = self._img_tk # keep a reference to the PhotoImage object
self._img_label.config(image=self._img_label.image)
else:
self.end_session(None)
self._root.title(str(self._index)) # update the window title to the current image index
[docs]
def prev_img(self, event=None):
"""
Displays the previous image in the array.
"""
self._index -= 2
self.next_img()
[docs]
def classify(self, event):
"""
Adds the current image index and pressed key as a label.
Then saves the results and moves to the next image.
Args:
event (tkinter.Event):
A tkinter event object.
"""
label = event.char
if label != '':
print(f'Image {self._index}: {label}') if self._verbose else None
self.labels_.update({self._index: str(label)}) ## Store the label
self.save_classification() if self._save_csv else None ## Save the results
self.next_img() ## Move to the next image
[docs]
def end_session(self, event):
"""
Ends the classification session by destroying the tkinter window.
"""
self._img_tk = None
self._root.destroy() if self._root is not None else None
self._root = None
import gc
gc.collect()
gc.collect()
[docs]
def save_classification(self):
"""
Saves the classification results to a CSV file.
This function does not append, it overwrites the entire file.
The file contains two columns: 'image_index' and 'label'.
"""
## make directory if it doesn't exist
Path(self.path_csv).parent.mkdir(parents=True, exist_ok=True)
## Save the results
with open(self.path_csv, 'w', newline='') as f:
writer = csv.writer(f)
writer.writerow(('image_index', 'label'))
writer.writerows(self.labels_.items())
[docs]
def get_labels(self, kind: str = 'dict') -> Union[dict, List[Tuple[int, str]], dict]:
"""
Returns the labels. The format of the output is determined by the ``kind`` parameter.
If the labels dictionary is empty, returns ``None``. RH 2023
Args:
kind (str):
The type of object to return. (Default is ``'dict'``) \n
* ``'dict'``: {idx: label, idx: label, ...}
* ``'list'``: [label, label, ...] where the index is the image
index and unlabeled images are represented as ``'None'``.
* ``'dataframe'``: {'index': [idx, idx, ...], 'label': [label, label, ...]}
This can be converted to a pandas dataframe with:
pd.DataFrame(self.get_labels('dataframe'))
Returns:
(Union[dict, List[Tuple[int, str]], dict]):
Depending on the ``kind`` parameter, it returns either: \n
* dict:
A dictionary where keys are the image indices and values are
the labels.
* List[Tuple[int, str]]:
A list of tuples, where each tuple contains an image index
and a label.
* dict:
A dictionary with keys 'index' and 'label' where values are
lists of indices and labels respectively.
"""
## if the dict is empty, return None
if len(self.labels_) == 0:
return None
if kind == 'dict':
return self.labels_
elif kind == 'list':
out = ['None',] * len(self.images)
for idx, label in self.labels_.items():
out[idx] = label
return out
elif kind == 'dataframe':
import pandas as pd
return pd.DataFrame(index=list(self.labels_.keys()), data={'label': list(self.labels_.values())})
[docs]
class IntegratedLabeler:
"""
A graphical interface for labeling image classes. The class displays a
sequence of images in the left panel which can be labelled by pressing keys
and the right panel is a scatterplot of an embedding of each image with the
option to overlay images on the scatterplot. The user can use a lasso tool
to select points on the scatterplot and these points will be shown to the
user on the left panel for labelling.
The title of the window is the current image index. The overlays can be
toggled by pressing Control-Shift-T. The classification label and image
index are stored as the ``self.labels_`` attribute and saved to a CSV file
in self.path_csv.
Args:
images (np.ndarray):
A numpy array of images. Either 3D: *(n_images, height, width)* or
4D: *(n_images, height, width, n_channels)*. Images should be scaled
between 0 and 255 and will be converted to uint8.
embeddings (np.ndarray):
A numpy array of embeddings for each image. Should be shape
*(n_images, 2)*.
idx_images_overlay (np.ndarray):
A numpy array of indices of images to overlay on the scatterplot.
idx_selection (List[int]):
A list of indices to select from the image array. If ``None``, all
images will be selected. (Default is ``None``)
figsize (float):
The size of each panel in the figure (width and height). (Default is *5*)
size_images_overlay (Optional[float]):
The size of the images to overlay. If ``None``, the size is calculated
based on nearest neighbors. (Default is ``None``)
crop_images_overlay (Optional[float]):
The fraction of the image to crop on each side. (Default is *0.35*)
frac_overlap_allowed (float):
The fraction of overlap allowed between images. (Default is *0.5*)
image_overlay_raster_size (Tuple[int, int]):
The size of the raster for the composite overlay.
(Default is*(1000, 1000)*)
alpha_points (float):
The transparency of the scatterplot points. (Default is *0.5*)
size_points (float):
The size of the scatterplot points. (Default is *20*)
normalize_images (bool):
Whether to normalize the images between min and max values. (Default
is ``True``)
verbose (bool):
Whether to print status updates. (Default is ``True``)
path_csv (Optional[str]):
Path to the CSV file for saving results. If ``None``, results will
not be saved.
save_csv (bool):
Whether to save the results to a CSV. (Default is ``True``)
key_end (str):
Key to press to end the session. (Default is ``'Escape'``)
key_prev (str):
Key to press to go back to the previous image. (Default is
``'Left'``)
key_next (str):
Key to press to go to the next image. (Default is ``'Right'``)
Example:
.. highlight:: python
.. code-block:: python
with IntegratedLabeler(
images,
embeddings=emb,
idx_images_overlay=idx_images_overlay,
size_images_overlay=0.25,
frac_overlap_allowed=0.25,
crop_images_overlay=0.25,
alpha_points=1.0,
size_points=3.0,
) as labeler:
labeler.run()
path_csv, labels = labeler.path_csv, labeler.labels_
Attributes:
path_csv (str):
Path to the CSV file for saving results. If ``None``, results will
not be saved.
save_csv (bool):
Whether to save the results to a CSV. (Default is ``True``)
labels_ (list):
A list of tuples containing the image index and classification label
for each image. The list is saved to a CSV file in self.path_csv.
"""
def __init__(
self,
images: np.ndarray,
embeddings: np.ndarray = None,
idx_images_overlay: Optional[np.ndarray] = None,
idx_selection: Optional[List[int]] = None,
figsize: float = 5,
size_images_overlay: Optional[float] = None,
crop_images_overlay: Optional[float] = 0.35,
frac_overlap_allowed: float = 0.5,
image_overlay_raster_size: Tuple[int, int] = (1000, 1000),
alpha_points: float = 0.5,
size_points: float = 20,
normalize_images: bool = True,
verbose: bool = True,
path_csv: Optional[str] = None,
save_csv: bool = True,
key_end: str = 'Escape',
key_prev: str = 'Left',
key_next: str = 'Right',
):
"""Build the IntegratedLabeler Object."""
# Data attributes
self.images = images
self.embeddings = embeddings
self.idx_images_overlay = idx_images_overlay
# Plotting properties
self.size_images_overlay = size_images_overlay
self.crop_images_overlay = crop_images_overlay
self.frac_overlap_allowed = frac_overlap_allowed
self.image_overlay_raster_size = image_overlay_raster_size
self.alpha_points = alpha_points
self.size_points = size_points
self._normalize_images = normalize_images
self.figsize = figsize
self._verbose = verbose
# Saving properties
import tempfile
import datetime
self.path_csv = path_csv if path_csv is not None else str(Path(tempfile.gettempdir()) / ('roicat_labels_' + datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + '.csv'))
self._save_csv = save_csv
# Initialize the selection and current index
self._index = -1 # Start at -1 so that the first image is 0
self.update_selection(idx_selection)
# Calculate data limits with padding
self.data_limits = self._calculate_data_limits()
# Create composite overlay if idx_images_overlay provided
if self.idx_images_overlay is not None:
self._create_composite_overlay()
self.show_overlay = True
# Results will be stored here
self.labels_ = {}
# Initialize GUI elements
self._img_tk = None
self._key_end = key_end
self._key_prev = key_prev
self._key_next = key_next
self._root = None
self.__call__ = self.run
[docs]
def run(self):
"""
Runs the image labeler with both image display and matplotlib panels.
"""
try:
self._root = tk.Tk()
self._root.title("Image Labeler")
# Create main container
main_container = ttk.PanedWindow(self._root, orient=tk.HORIZONTAL)
main_container.pack(fill=tk.BOTH, expand=True, padx=3, pady=3)
# Left panel for image display
left_panel = ttk.Frame(main_container)
self._img_fig, self._img_ax = plt.subplots(figsize=(self.figsize, self.figsize))
self._img_canvas = FigureCanvasTkAgg(self._img_fig, master=left_panel)
self._img_canvas.draw()
self._img_canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
main_container.add(left_panel)
# Right panel for scatterplot
right_panel = ttk.Frame(main_container)
self._scatter_fig, self._scatter_ax = plt.subplots(figsize=(self.figsize, self.figsize))
self._scatter_canvas = FigureCanvasTkAgg(self._scatter_fig, master=right_panel)
self._scatter_canvas.draw()
self._scatter_canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
main_container.add(right_panel)
# Make figures fill the space
self._img_fig.subplots_adjust(left=0.02, right=0.98, bottom=0.02, top=0.98)
self._scatter_fig.subplots_adjust(left=0.02, right=0.98, bottom=0.02, top=0.98)
# Build scatter plot
self._build_scatter_plot()
# Bind keys
self._root.bind("<Key>", self.classify)
if self._key_end:
self._root.bind(f'<Key-{self._key_end}>', self.end_session)
if self._key_prev:
self._root.bind(f'<Key-{self._key_prev}>', self.prev_img)
if self._key_next:
self._root.bind(f'<Key-{self._key_next}>', self.next_img)
if self.idx_images_overlay is not None:
self._root.bind(f'<Control-T>', self._toggle_overlay)
self._root.protocol("WM_DELETE_WINDOW", self._on_closing)
# Start the session
self.next_img()
self._root.mainloop()
except Exception as e:
warnings.warn('Error initializing image labeler: ' + str(e))
[docs]
def update_selection(self, idx_selection: List[int]):
"""
Updates the selection of images to classify. The selection is a list of
indices to select from the image array. Will show the first image in the
new selection.
Args:
idx_selection (List[int]):
A list of indices to select from the image array.
"""
if idx_selection is not None:
# Check if provided list is valid
if min(idx_selection) < 0 or max(idx_selection) >= len(self.images):
raise ValueError('idx_selection exceeds range of images (must be in [0, len(image_array)-1]).')
self._idx_selection = idx_selection if idx_selection is not None else list(range(len(self.images)))
self._index = -1
if hasattr(self, '_root') and self._root is not None:
# Only attempt to go to next image if the window is open
self.next_img()
def _on_closing(self):
from tkinter import messagebox
if messagebox.askokcancel("Quit", "Do you want to quit?"):
self.end_session(None)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.end_session(None)
def _get_current_idx(self):
"""Central method for getting the current image index."""
return int(self._idx_selection[self._index])
[docs]
def next_img(self, event=None):
"""Displays the next image and updates the matplotlib plot."""
self._index += 1
if self._index < len(self._idx_selection):
# Update image display
im = self.images[self._get_current_idx()]
im = (im / np.max(im)) * 255 if self._normalize_images else im
# Update data of _img_ax
self._img_ax.clear()
self._img_ax.imshow(im, cmap='gray')
self._img_ax.text(0, 0, "Press any key to label ROI\nPress left or right to switch ROIs", color="white", fontsize=12, ha="left", va="top")
self._img_ax.set_xticks([])
self._img_ax.set_yticks([])
self._img_canvas.draw()
self._root.title(str(self._get_current_idx()))
if hasattr(self, 'scatter'):
# Update colors to show selection
colors = np.array(['gray'] * len(self.embeddings))
colors[list(self._idx_selection)] = 'red'
self.scatter.set_color(colors)
# Update current scatter point
self.current_scatter.set_offsets([self.embeddings[self._get_current_idx(), :]])
self._scatter_canvas.draw()
else:
# Loop index back to start of current selection
self._index = 0
[docs]
def prev_img(self, event=None):
"""
Displays the previous image in the array.
"""
self._index -= 2
self.next_img()
[docs]
def end_session(self, event):
"""Ends the session and cleans up matplotlib resources."""
try:
if self._root is not None:
self._root.quit()
self._root.destroy()
self._root = None
except:
pass
try:
if hasattr(self, '_img_fig') and self._img_fig is not None:
plt.close(self._img_fig)
if hasattr(self, '_scatter_fig') and self._scatter_fig is not None:
plt.close(self._scatter_fig)
self._img_fig = None
self._scatter_fig = None
self._img_canvas = None
self._scatter_canvas = None
import gc
gc.collect()
except:
pass
import gc
gc.collect()
gc.collect()
[docs]
def save_classification(self):
"""
Saves the classification results to a CSV file.
This function does not append, it overwrites the entire file.
The file contains two columns: 'image_index' and 'label'.
"""
## make directory if it doesn't exist
Path(self.path_csv).parent.mkdir(parents=True, exist_ok=True)
## Save the results
with open(self.path_csv, 'w', newline='') as f:
writer = csv.writer(f)
writer.writerow(('image_index', 'label'))
writer.writerows(self.labels_.items())
[docs]
def get_labels(self, kind: str = 'dict') -> Union[dict, List[Tuple[int, str]], dict]:
"""
Returns the labels. The format of the output is determined by the ``kind`` parameter.
If the labels dictionary is empty, returns ``None``. RH 2023
Args:
kind (str):
The type of object to return. (Default is ``'dict'``) \n
* ``'dict'``: {idx: label, idx: label, ...}
* ``'list'``: [label, label, ...] where the index is the image
index and unlabeled images are represented as ``'None'``.
* ``'dataframe'``: {'index': [idx, idx, ...], 'label': [label, label, ...]}
This can be converted to a pandas dataframe with:
pd.DataFrame(self.get_labels('dataframe'))
Returns:
(Union[dict, List[Tuple[int, str]], dict]):
Depending on the ``kind`` parameter, it returns either: \n
* dict:
A dictionary where keys are the image indices and values are
the labels.
* List[Tuple[int, str]]:
A list of tuples, where each tuple contains an image index
and a label.
* dict:
A dictionary with keys 'index' and 'label' where values are
lists of indices and labels respectively.
"""
## if the dict is empty, return None
if len(self.labels_) == 0:
return None
if kind == 'dict':
return self.labels_
elif kind == 'list':
out = ['None',] * len(self.images)
for idx, label in self.labels_.items():
out[idx] = label
return out
elif kind == 'dataframe':
import pandas as pd
return pd.DataFrame(index=list(self.labels_.keys()), data={'label': list(self.labels_.values())})
[docs]
def classify(self, event):
"""
Adds the current image index and pressed key as a label.
Then saves the results and moves to the next image.
Args:
event (tkinter.Event):
A tkinter event object.
"""
label = event.char
if label != '':
print(f'Image {self._get_current_idx()}: {str(repr(label))}') if self._verbose else None
self.labels_.update({self._get_current_idx(): str(label)}) ## Store the label
self.save_classification() if self._save_csv else None ## Save the results
self.next_img() ## Move to the next image
def _toggle_overlay(self, event):
"""Toggle the overlay on and off"""
if self.idx_images_overlay is not None:
self.show_overlay = not self.show_overlay
self.im_composite.set_visible(self.show_overlay)
self._scatter_canvas.draw()
def _on_select(self, verts):
"""Handle lasso selection"""
path = mplPath(verts)
points = self.scatter.get_offsets()
mask = path.contains_points(points)
selected = np.sort(np.array(np.where(mask)[0]))
self.update_selection(selected)
def _build_scatter_plot(self):
self.scatter = self._scatter_ax.scatter(
self.embeddings[:, 0], self.embeddings[:, 1],
alpha=self.alpha_points,
s=self.size_points,
c="gray",
picker=True,
zorder=1,
)
self.current_scatter = self._scatter_ax.scatter(
self.embeddings[self._get_current_idx(), 0],
self.embeddings[self._get_current_idx(), 1],
alpha=1.0,
s=self.size_points * 4,
c="blue",
zorder=2,
)
# Create composite overlay if images provided
if self.idx_images_overlay is not None:
# Show the overlay
self.im_composite = self._scatter_ax.imshow(
self.composite_overlay,
extent=[
self.data_limits[0][0],
self.data_limits[1][0],
self.data_limits[0][1],
self.data_limits[1][1]
],
aspect='auto',
zorder=1000,
)
self._scatter_ax.set_xlim(self.data_limits[0][0], self.data_limits[1][0])
self._scatter_ax.set_ylim(self.data_limits[0][1], self.data_limits[1][1])
self._scatter_ax.set_xticks([])
self._scatter_ax.set_yticks([])
self._scatter_ax.text(
self.data_limits[0][0],
self.data_limits[1][1],
"Draw lasso to select points for labeling",
color="black",
fontsize=12,
ha="left",
va="top"
)
self._scatter_ax.text(
self.data_limits[0][0],
self.data_limits[0][1],
"Press Control-Shift-T to toggle overlay",
color="black",
fontsize=12,
ha="left",
va="bottom"
)
# Add selector for picking points
self.lasso = LassoSelector(self._scatter_ax, onselect=self._on_select, button=1)
def _create_composite_overlay(self):
"""Create a single composite image with all overlays"""
if self.size_images_overlay is None:
# Calculate optimal size based on nearest neighbors
from sklearn.neighbors import NearestNeighbors
nn = NearestNeighbors(n_neighbors=2).fit(self.embeddings[self.idx_images_overlay])
distances = nn.kneighbors(self.embeddings[self.idx_images_overlay])[0]
min_dist = np.min(distances[:, 1])
self.size_images_overlay = min_dist * (1 + self.frac_overlap_allowed)
min_emb = np.nanmin(self.embeddings, axis=0) ## shape (2,)
max_emb = np.nanmax(self.embeddings, axis=0) ## shape (2,)
range_emb = max_emb - min_emb ## shape (2,)
aspect_ratio_ims = (range_emb[1] / range_emb[0]) ## shape (1,)
assert isinstance(self.size_images_overlay, (int, float, np.ndarray)), 'size_images_overlay must be an int, float, or shape (2,) numpy array'
if isinstance(self.size_images_overlay, (int, float)):
self.size_images_overlay = np.array([self.size_images_overlay / aspect_ratio_ims, self.size_images_overlay])
assert self.size_images_overlay.shape == (2,), 'size_images_overlay must be an int, float, or shape (2,) numpy array'
# Create empty canvas
iors = self.image_overlay_raster_size
canvas = np.zeros((*iors, 4)) # RGBA
# Create interpolators for mapping data coordinates to pixel coordinates
interp_x = scipy.interpolate.interp1d(
[self.data_limits[0][0], self.data_limits[1][0]],
[0, iors[0]]
)
interp_y = scipy.interpolate.interp1d(
[self.data_limits[0][1], self.data_limits[1][1]],
[0, iors[1]]
)
# Calculate size of each image in pixels
range_x = self.data_limits[1][0] - self.data_limits[0][0]
range_y = self.data_limits[1][1] - self.data_limits[0][1]
size_x = int((self.size_images_overlay[0] / range_x) * iors[0])
size_y = int((self.size_images_overlay[1] / range_y) * iors[1])
xwidth = self.images.shape[2]
ywidth = self.images.shape[1]
crop_value = min(1.0, self.crop_images_overlay)
crop_value = max(0.1, crop_value)
x_crop_points = int((xwidth - crop_value * xwidth)/2)
y_crop_points = int((ywidth - crop_value * ywidth)/2)
for idx in self.idx_images_overlay:
# Normalize and convert to RGB if grayscale
img = self.images[idx][x_crop_points:-x_crop_points, y_crop_points:-y_crop_points]
if img.ndim == 2:
img = (img - img.min()) / (img.max() - img.min())
img = np.stack([img] * 3, axis=-1)
elif img.ndim == 3:
img = (img - img.min()) / (img.max() - img.min())
# Resize image
coords = np.stack(np.meshgrid(
np.linspace(0, img.shape[0], size_x),
np.linspace(0, img.shape[1], size_y)
), axis=-1)
img_resized = scipy.interpolate.interpn(
(np.arange(img.shape[0]), np.arange(img.shape[1])),
img,
coords,
method='linear',
bounds_error=False,
fill_value=0
)
# Calculate position
x = int(interp_x(self.embeddings[idx, 0]))
y = int(interp_y(self.embeddings[idx, 1]))
# Calculate bounds
x1 = max(0, x - size_x // 2)
x2 = min(iors[0], x + size_x // 2)
y1 = max(0, y - size_y // 2)
y2 = min(iors[1], y + size_y // 2)
# Add to canvas
canvas[y1:y2, x1:x2, :3] = img_resized[:y2-y1, :x2-x1]
canvas[y1:y2, x1:x2, 3] = 1.0 # Alpha channel
self.composite_overlay = np.flipud(canvas) # Flip because imshow origin is bottom left
def _calculate_data_limits(self):
"""Calculate data limits with padding"""
pad = 0.07
min_vals = np.min(self.embeddings, axis=0)
max_vals = np.max(self.embeddings, axis=0)
range_vals = max_vals - min_vals
return (
min_vals - range_vals * pad,
max_vals + range_vals * pad
)
[docs]
def export_svg_hv_bokeh(
obj: object,
path_save: str
) -> None:
"""
Saves a scatterplot from holoviews as an SVG file.
RH 2023
Args:
obj (object):
Holoviews plot object.
path_save (str):
Path to save the SVG file.
"""
import holoviews as hv
import bokeh
plot_state = hv.renderer('bokeh').get_plot(obj).state
plot_state.output_backend = 'svg'
bokeh.io.export_svgs(plot_state, filename=path_save)
######################################################################################################################################
######################################################## H5 HANDLING #################################################################
######################################################################################################################################
## below is actually 'simple_load' from h5_handling
[docs]
def h5_load(
filepath: Union[str, Path],
return_dict: bool = True,
verbose: bool = False
) -> Union[dict, object]:
"""
Returns a dictionary or an H5PY object from a given HDF file.
RH 2023
Args:
filepath (Union[str, Path]):
Full pathname of the file to read.
return_dict (bool):
Whether or not to return a dict object. (Default is ``True``). \n
* ``True``: a dict object is returned.
* ``False``: an H5PY object is returned.
verbose (bool):
Whether to print detailed information during the execution. (Default
is ``False``)
Returns:
(Union[dict, object]):
result (Union[dict, object]):
Either a dictionary containing the groups as keys and the
datasets as values from the HDF file or an H5PY object,
depending on the ``return_dict`` parameter.
"""
import h5py
if return_dict:
with h5py.File(filepath, 'r') as h5_file:
if verbose:
print(f'==== Loading h5 file with hierarchy: ====')
show_item_tree(h5_file)
result = {}
def visitor_func(name, node):
# Split name by '/' and reduce to nested dict
keys = name.split('/')
sub_dict = result
for key in keys[:-1]:
sub_dict = sub_dict.setdefault(key, {})
if isinstance(node, h5py.Dataset):
sub_dict[keys[-1]] = node[...]
elif isinstance(node, h5py.Group):
sub_dict.setdefault(keys[-1], {})
h5_file.visititems(visitor_func)
return result
else:
return h5py.File(filepath, 'r')
[docs]
def show_item_tree(
hObj: Optional[Union[object, dict]] = None,
path: Optional[Union[str, Path]] = None,
depth: Optional[int] = None,
show_metadata: bool = True,
print_metadata: bool = False,
indent_level: int = 0
) -> None:
'''
Recursively displays all the items and groups in an HDF5 object or Python dictionary.
RH 2021
Args:
hObj (Optional[Union[object, dict]]):
Hierarchical object, which can be an HDF5 object or a Python
dictionary. (Default is ``None``)
path (Optional[Union[str, Path]]):
If not ``None``, then the path to the HDF5 object is used instead of
``hObj``. (Default is ``None``)
depth (Optional[int]):
How many levels deep to show the tree. (Default is ``None`` which
shows all levels)
show_metadata (bool):
Whether or not to list metadata with items. (Default is ``True``)
print_metadata (bool):
Whether or not to show values of metadata items. (Default is
``False``)
indent_level (int):
Used internally to the function. User should leave this as the
default. (Default is *0*)
Example:
.. highlight:: python
.. code-block:: python
import h5py
with h5py.File('test.h5', 'r') as f:
show_item_tree(f)
'''
import h5py
if depth is None:
depth = int(10000000000000000000)
else:
depth = int(depth)
if depth < 0:
return
if path is not None:
with h5py.File(path , 'r') as f:
show_item_tree(hObj=f, path=None, depth=depth-1, show_metadata=show_metadata, print_metadata=print_metadata, indent_level=indent_level)
else:
indent = f' '*indent_level
if hasattr(hObj, 'attrs') and show_metadata:
for ii,val in enumerate(list(hObj.attrs.keys()) ):
if print_metadata:
print(f'{indent}METADATA: {val}: {hObj.attrs[val]}')
else:
print(f'{indent}METADATA: {val}: shape={hObj.attrs[val].shape} , dtype={hObj.attrs[val].dtype}')
for ii,val in enumerate(list(iter(hObj))):
if isinstance(hObj[val], h5py.Group):
print(f'{indent}{ii+1}. {val}:----------------')
show_item_tree(hObj[val], depth=depth-1, show_metadata=show_metadata, print_metadata=print_metadata , indent_level=indent_level+1)
elif isinstance(hObj[val], dict):
print(f'{indent}{ii+1}. {val}:----------------')
show_item_tree(hObj[val], depth=depth-1, show_metadata=show_metadata, print_metadata=print_metadata , indent_level=indent_level+1)
else:
if hasattr(hObj[val], 'shape') and hasattr(hObj[val], 'dtype'):
print(f'{indent}{ii+1}. {val}: '.ljust(20) + f'shape={hObj[val].shape} ,'.ljust(20) + f'dtype={hObj[val].dtype}')
else:
print(f'{indent}{ii+1}. {val}: '.ljust(20) + f'type={type(hObj[val])}')
######################################################################################################################################
####################################################### DECOMPOSITION ################################################################
######################################################################################################################################
[docs]
def torch_pca(
X_in: Union[torch.Tensor, np.ndarray],
device: str = 'cpu',
mean_sub: bool = True,
zscore: bool = False,
rank: Optional[int] = None,
return_cpu: bool = True,
return_numpy: bool = False,
) -> Tuple[Union[torch.Tensor, np.ndarray], Union[torch.Tensor, np.ndarray], Union[torch.Tensor, np.ndarray], Union[torch.Tensor, np.ndarray]]:
"""
Conducts Principal Components Analysis using the Pytorch library. This
function can run on either CPU or GPU devices.
RH 2021
Args:
X_in (Union[torch.Tensor, np.ndarray]):
The data to be decomposed. This should be a 2-D array, with columns
representing features and rows representing samples. PCA is
performed column-wise.
device (str):
The device to use for computation, e.g., 'cuda' or 'cpu'. (Default
is ``'cpu'``)
mean_sub (bool):
If ``True``, subtract the mean ('center') from the columns. (Default
is ``True``)
zscore (bool):
If ``True``, z-score the columns. This is equivalent to conducting
PCA on the correlation-matrix. (Default is ``False``)
rank (int):
Maximum estimated rank of the decomposition. If ``None``, then the
rank is assumed to be X.shape[1]. (Default is ``None``)
return_cpu (bool):
(Default is ``True``) \n
* ``True``, all outputs are forced to be on the 'cpu' device.
* ``False``, and device is not 'cpu', then the returns will be on the
provided device.
return_numpy (bool):
If ``True``, all outputs are forced to be of type numpy.ndarray.
(Default is ``False``)
Returns:
(tuple): tuple containing:
components (torch.Tensor or np.ndarray):
The components of the decomposition, represented as a 2-D array.
Each column is a component vector and each row is a feature
weight.
scores (torch.Tensor or np.ndarray):
The scores of the decomposition, represented as a 2-D array.
Each column is a score vector and each row is a sample weight.
singVals (torch.Tensor or np.ndarray):
The singular values of the decomposition, represented as a 1-D
array. Each element is a singular value.
EVR (torch.Tensor or np.ndarray):
The explained variance ratio of each component, represented as a
1-D array. Each element is the explained variance ratio of the
corresponding component.
Example:
.. highlight:: python
.. code-block:: python
components, scores, singVals, EVR = torch_pca(X_in)
"""
if isinstance(X_in, torch.Tensor) == False:
X = torch.from_numpy(X_in).to(device)
elif X_in.device != device:
X = X_in.to(device)
else:
X = copy.copy(X_in)
if mean_sub and not zscore:
X = X - torch.mean(X, dim=0)
if zscore:
X = X - torch.mean(X, dim=0)
stds = torch.std(X, dim=0)
X = X / stds[None,:]
if rank is None:
rank = min(list(X.shape))
(U,S,V) = torch.pca_lowrank(X, q=rank, center=False, niter=2)
components = V
scores = torch.matmul(X, V[:, :rank])
singVals = (S**2)/(len(S)-1)
EVR = (singVals) / torch.sum(singVals)
if return_cpu:
components = components.cpu()
scores = scores.cpu()
singVals = singVals.cpu()
EVR = EVR.cpu()
if return_numpy:
components = components.cpu().numpy()
scores = scores.cpu().numpy()
singVals = singVals.cpu().numpy()
EVR = EVR.cpu().numpy()
gc.collect()
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()
gc.collect()
return components, scores, singVals, EVR
######################################################################################################################################
############################################################ VIDEO ###################################################################
######################################################################################################################################
[docs]
def grayscale_to_rgb(
array: Union[np.ndarray, torch.Tensor, List]
) -> Union[np.ndarray, torch.Tensor]:
"""
Converts a grayscale image (2D array) or movie (3D array) to RGB (3D or 4D
array).
RH 2023
Args:
array (Union[np.ndarray, torch.Tensor, list]):
The 2D or 3D array of grayscale images.
Returns:
(Union[np.ndarray, torch.Tensor]):
array (Union[np.ndarray, torch.Tensor]):
The converted 3D or 4D array of RGB images.
"""
if isinstance(array, list):
if isinstance(array[0], np.ndarray):
array = np.stack(array, axis=0)
elif isinstance(array[0], torch.Tensor):
array = torch.stack(array, axis=0)
else:
raise Exception(f'Failed to convert list of type {type(array[0])} to array')
if isinstance(array, np.ndarray):
return np.stack([array, array, array], axis=-1)
elif isinstance(array, torch.Tensor):
return torch.stack([array, array, array], dim=-1)
[docs]
def save_gif(
array: Union[np.ndarray, List],
path: str,
frameRate: float = 5.0,
loop: int = 0,
# backend='PIL',
kwargs_backend: Dict = {},
):
"""
Saves an array of images as a gif.
RH 2023
Args:
array (Union[np.ndarray, list]):
The 3D (grayscale) or 4D (color) array of images. \n
* If dtype is ``float`` type, then scale is from 0 to 1.
* If dtype is ``int``, then scale is from 0 to 255.
path (str):
The path where the gif is saved.
frameRate (float):
The frame rate of the gif. (Default is ``5.0``)
loop (int):
The number of times to loop the gif. (Default is ``0``) \n
* 0 means loop forever
* 1 means play once
* 2 means play twice (loop once)
* etc.
# backend (str):
# Which backend to use.
# Options: 'imageio' or 'PIL'
kwargs_backend (Dict):
The keyword arguments for the backend.
"""
array = np.stack(array, axis=0) if isinstance(array, list) else array
array = grayscale_to_rgb(array) if array.ndim == 3 else array
if np.issubdtype(array.dtype, np.floating):
array = (array*255).astype('uint8')
kwargs_backend.update({'loop': loop} if loop != 1 else {})
# if backend == 'imageio':
# import imageio
# imageio.mimsave(
# path,
# array,
# format='GIF',
# duration=1000/frameRate,
# **kwargs_backend,
# )
# elif backend == 'PIL':
from PIL import Image
frames = [Image.fromarray(array[i_frame]) for i_frame in range(array.shape[0])]
Path(path).parent.mkdir(parents=True, exist_ok=True)
frames[0].save(
path,
format='GIF',
append_images=frames[1:],
save_all=True,
duration=1000/frameRate,
**kwargs_backend,
)
# else:
# raise Exception(f'Unsupported backend {backend}')
######################################################################################################################################
###################################################### IMAGE_PROCESSING ##############################################################
######################################################################################################################################
[docs]
def mask_image_border(
im: np.ndarray,
border_outer: Optional[Union[int, Tuple[int, int, int, int]]] = None,
border_inner: Optional[int] = None,
mask_value: float = 0,
) -> np.ndarray:
"""
Masks an image within specified outer and inner borders.
RH 2022
Args:
im (np.ndarray):
Input image of shape: *(height, width)*.
border_outer (Union[int, tuple[int, int, int, int], None]):
Number of pixels along the border to mask. If ``None``, the border
is not masked. If an int is provided, all borders are equally
masked. If a tuple of ints is provided, borders are masked in the
order: *(top, bottom, left, right)*. (Default is ``None``)
border_inner (int, Optional):
Number of pixels in the center to mask. Will be a square with side
length equal to this value. (Default is ``None``)
mask_value (float):
Value to replace the masked pixels with. (Default is *0*)
Returns:
(np.ndarray):
im_out (np.ndarray):
Masked output image.
"""
## Find the center of the image
height, width = im.shape
center_y = cy = int(np.floor(height/2))
center_x = cx = int(np.floor(width/2))
## Mask the center
if border_inner is not None:
## make edge_lengths
center_edge_length = cel = int(np.ceil(border_inner/2)) if border_inner is not None else 0
im[cy-cel:cy+cel, cx-cel:cx+cel] = mask_value
## Mask the border
if border_outer is not None:
## make edge_lengths
if isinstance(border_outer, int):
border_outer = (border_outer, border_outer, border_outer, border_outer)
im[:border_outer[0], :] = mask_value
im[-border_outer[1]:, :] = mask_value
im[:, :border_outer[2]] = mask_value
im[:, -border_outer[3]:] = mask_value
return im
[docs]
def make_Fourier_mask(
frame_shape_y_x: Tuple[int, int] = (512,512),
bandpass_spatialFs_bounds: List[float] = [1/128, 1/3],
order_butter: int = 5,
mask: Optional[Union[np.ndarray, torch.Tensor]] = None,
dtype_fft: object = torch.complex64,
plot_pref: bool = False,
verbose: bool = False,
) -> torch.Tensor:
"""
Generates a Fourier domain mask for phase correlation, primarily used in
BWAIN.
Args:
frame_shape_y_x (Tuple[int, int]):
Shape of the images that will be processed through this function.
(Default is *(512, 512)*)
bandpass_spatialFs_bounds (List[float]):
Specifies the lowcut and highcut in spatial frequency for the
butterworth filter. (Default is *[1/128, 1/3]*)
order_butter (int):
Order of the butterworth filter. (Default is *5*)
mask (Union[np.ndarray, torch.Tensor, None]):
If not ``None``, this mask is used instead of creating a new one.
(Default is ``None``)
dtype_fft (object):
Data type for the Fourier transform, default is ``torch.complex64``.
plot_pref (bool):
If ``True``, the absolute value of the mask is plotted. (Default is
``False``)
verbose (bool):
If ``True``, enables the print statements for debugging. (Default is
``False``)
Returns:
(torch.Tensor):
mask_fft (torch.Tensor):
The generated mask in the Fourier domain.
"""
get_nd_butterworth_filter
bandpass_spatialFs_bounds = list(bandpass_spatialFs_bounds)
bandpass_spatialFs_bounds[0] = max(bandpass_spatialFs_bounds[0], 1e-9)
if (isinstance(mask, (np.ndarray, torch.Tensor))) or ((mask != 'None') and (mask is not None)):
mask = torch.as_tensor(mask, dtype=dtype_fft)
mask = mask / mask.sum()
mask_fftshift = torch.fft.fftshift(mask)
print(f'User provided mask of shape: {mask.shape} was normalized to sum=1, fftshift-ed, and converted to a torch.Tensor')
else:
wfilt_h = get_nd_butterworth_filter(
shape=frame_shape_y_x,
factor=bandpass_spatialFs_bounds[0],
order=order_butter,
high_pass=True,
real=False,
)
wfilt_l = get_nd_butterworth_filter(
shape=frame_shape_y_x,
factor=bandpass_spatialFs_bounds[1],
order=order_butter,
high_pass=False,
real=False,
)
kernel = torch.as_tensor(
wfilt_h * wfilt_l,
dtype=dtype_fft,
)
mask = kernel / kernel.sum()
# self.mask_fftshift = torch.fft.fftshift(self.mask)
mask_fftshift = mask
mask_fftshift = mask_fftshift.contiguous()
if plot_pref and plot_pref!='False':
import matplotlib.pyplot as plt
plt.figure()
plt.imshow(
torch.abs(kernel.cpu()).numpy(),
# clim=[0,1],
)
if verbose:
print(f'Created Fourier domain mask. self.mask_fftshift.shape: {mask_fftshift.shape}. Images input to find_translation_shifts will now be masked in the FFT domain.')
return mask_fftshift
## Fixed the indexing because the skimage nerds drink too much coffee. RH.
[docs]
def get_nd_butterworth_filter(
shape: Tuple[int, ...],
factor: float,
order: float,
high_pass: bool,
real: bool,
dtype: np.dtype = np.float64,
squared_butterworth: bool = True
) -> np.ndarray:
"""
Creates an N-dimensional Butterworth mask for an FFT.
Args:
shape (Tuple[int, ...]):
Shape of the n-dimensional FFT and mask.
factor (float):
Fraction of mask dimensions where the cutoff should be.
order (float):
Controls the slope in the cutoff region.
high_pass (bool):
Whether the filter is high pass (low frequencies attenuated) or low
pass (high frequencies are attenuated).
real (bool):
Whether the FFT is of a real (``True``) or complex (``False``)
image.
dtype (np.dtype):
The desired output data type of the Butterworth filter. (Default is
``np.float64``)
squared_butterworth (bool):
If ``True``, the square of the Butterworth filter is used. (Default
is ``True``)
Returns:
(np.ndarray):
wfilt (np.ndarray):
The FFT mask.
"""
import functools
ranges = []
for i, d in enumerate(shape):
# start and stop ensures center of mask aligns with center of FFT
# axis = np.arange(-(d - 1) // 2, (d - 1) // 2 + 1) / (d * factor)
axis = np.arange(-(d - 1) / 2, (d - 1) / 2 + 0.5) / (d * factor) ## FIXED, RH 2023
ranges.append(scipy.fft.ifftshift(axis ** 2))
# for real image FFT, halve the last axis
if real:
limit = d // 2 + 1
ranges[-1] = ranges[-1][:limit]
# q2 = squared Euclidean distance grid
q2 = functools.reduce(
np.add, np.meshgrid(*ranges, indexing="ij", sparse=True)
)
q2 = q2.astype(dtype)
q2 = np.power(q2, order)
wfilt = 1 / (1 + q2)
if high_pass:
wfilt *= q2
if not squared_butterworth:
np.sqrt(wfilt, out=wfilt)
return wfilt
[docs]
def warp_matrix_to_remappingIdx(
warp_matrix: Union[np.ndarray, torch.Tensor],
x: int,
y: int
) -> Union[np.ndarray, torch.Tensor]:
"""
Convert a warp matrix (2x3 or 3x3) into remapping indices (2D).
RH 2023
Args:
warp_matrix (Union[np.ndarray, torch.Tensor]):
Warp matrix of shape *(2, 3)* for affine transformations, and *(3,
3)* for homography.
x (int):
Width of the desired remapping indices.
y (int):
Height of the desired remapping indices.
Returns:
(Union[np.ndarray, torch.Tensor]):
remapIdx (Union[np.ndarray, torch.Tensor]):
Remapping indices of shape *(x, y, 2)* representing the x and y
displacements in pixels.
"""
assert warp_matrix.shape in [(2, 3), (3, 3)], f"warp_matrix.shape {warp_matrix.shape} not recognized. Must be (2, 3) or (3, 3)"
assert isinstance(x, int) and isinstance(y, int), f"x and y must be integers"
assert x > 0 and y > 0, f"x and y must be positive"
if isinstance(warp_matrix, torch.Tensor):
stack, meshgrid, arange, hstack, ones, float32, array = torch.stack, torch.meshgrid, torch.arange, torch.hstack, torch.ones, torch.float32, torch.as_tensor
stack_partial = lambda x: stack(x, dim=0)
elif isinstance(warp_matrix, np.ndarray):
stack, meshgrid, arange, hstack, ones, float32, array = np.stack, np.meshgrid, np.arange, np.hstack, np.ones, np.float32, np.array
stack_partial = lambda x: stack(x, axis=0)
else:
raise ValueError(f"warp_matrix must be a torch.Tensor or np.ndarray")
# create the grid
mesh = stack_partial(meshgrid(arange(x, dtype=float32), arange(y, dtype=float32)))
mesh_coords = hstack((mesh.reshape(2,-1).T, ones((x*y, 1), dtype=float32)))
# warp the grid
mesh_coords_warped = (mesh_coords @ warp_matrix.T)
mesh_coords_warped = mesh_coords_warped[:, :2] / mesh_coords_warped[:, 2:3] if warp_matrix.shape == (3, 3) else mesh_coords_warped ## if homography, divide by z
# reshape the warped grid
remapIdx = mesh_coords_warped.T.reshape(2, y, x)
# permute the axes to (x, y, 2)
remapIdx = remapIdx.permute(1, 2, 0) if isinstance(warp_matrix, torch.Tensor) else remapIdx.transpose(1, 2, 0)
return remapIdx
[docs]
def remap_images(
images: Union[np.ndarray, torch.Tensor],
remappingIdx: Union[np.ndarray, torch.Tensor],
backend: str = "torch",
interpolation_method: str = 'linear',
border_mode: str = 'constant',
border_value: float = 0,
device: str = 'cpu',
) -> Union[np.ndarray, torch.Tensor]:
"""
Applies remapping indices to a set of images. Remapping indices, similar to
flow fields, describe the index of the pixel to sample from rather than the
displacement of each pixel. RH 2023
Args:
images (Union[np.ndarray, torch.Tensor]):
The images to be warped. Shapes can be *(N, C, H, W)*, *(C, H, W)*,
or *(H, W)*.
remappingIdx (Union[np.ndarray, torch.Tensor]):
The remapping indices, describing the index of the pixel to sample
from. Shape is *(H, W, 2)*.
backend (str):
The backend to use. Can be either ``'torch'`` or ``'cv2'``. (Default
is ``'torch'``)
interpolation_method (str):
The interpolation method to use. Options are ``'linear'``,
``'nearest'``, ``'cubic'``, and ``'lanczos'``. Refer to `cv2.remap`
or `torch.nn.functional.grid_sample` for more details. (Default is
``'linear'``)
border_mode (str):
The border mode to use. Options include ``'constant'``,
``'reflect'``, ``'replicate'``, and ``'wrap'``. Refer to `cv2.remap`
for more details. (Default is ``'constant'``)
border_value (float):
The border value to use. Refer to `cv2.remap` for more details.
(Default is ``0``)
device (str):
The device to use for computations. Commonly either ``'cpu'`` or
``'gpu'``. (Default is ``'cpu'``)
Returns:
(Union[np.ndarray, torch.Tensor]):
warped_images (Union[np.ndarray, torch.Tensor]):
The warped images. The shape will be the same as the input
images, which can be *(N, C, H, W)*, *(C, H, W)*, or *(H, W)*.
"""
import cv2
# Check inputs
assert isinstance(images, (np.ndarray, torch.Tensor)), f"images must be a np.ndarray or torch.Tensor"
assert isinstance(remappingIdx, (np.ndarray, torch.Tensor)), f"remappingIdx must be a np.ndarray or torch.Tensor"
if images.ndim == 2:
images = images[None, None, :, :]
elif images.ndim == 3:
images = images[None, :, :, :]
elif images.ndim != 4:
raise ValueError(f"images must be a 2D, 3D, or 4D array. Got shape {images.shape}")
assert remappingIdx.ndim == 3, f"remappingIdx must be a 3D array of shape (H, W, 2). Got shape {remappingIdx.shape}"
# Check backend
if backend not in ["torch", "cv2"]:
raise ValueError("Invalid backend. Supported backends are 'torch' and 'cv2'.")
if backend == 'torch':
if isinstance(images, np.ndarray):
images = torch.as_tensor(images, device=device, dtype=torch.float32)
elif isinstance(images, torch.Tensor):
images = images.to(device=device, dtype=torch.float32)
if isinstance(remappingIdx, np.ndarray):
remappingIdx = torch.as_tensor(remappingIdx, device=device, dtype=torch.float32)
elif isinstance(remappingIdx, torch.Tensor):
remappingIdx = remappingIdx.to(device=device, dtype=torch.float32)
interpolation = {
'linear': 'bilinear',
'nearest': 'nearest',
'cubic': 'bicubic',
'lanczos': 'lanczos',
}[interpolation_method]
border = {
'constant': 'zeros',
'reflect': 'reflection',
'replicate': 'replication',
'wrap': 'circular',
}[border_mode]
## Convert remappingIdx to normalized grid
normgrid = cv2RemappingIdx_to_pytorchFlowField(remappingIdx)
# Apply remappingIdx
warped_images = torch.nn.functional.grid_sample(
images,
normgrid[None,...],
mode=interpolation,
padding_mode=border,
align_corners=True, ## align_corners=True is the default in cv2.remap. See documentation for details.
)
elif backend == 'cv2':
assert isinstance(images, np.ndarray), f"images must be a np.ndarray when using backend='cv2'"
assert isinstance(remappingIdx, np.ndarray), f"remappingIdx must be a np.ndarray when using backend='cv2'"
## convert to float32 if not uint8
images = images.astype(np.float32) if images.dtype != np.uint8 else images
remappingIdx = remappingIdx.astype(np.float32) if remappingIdx.dtype != np.uint8 else remappingIdx
interpolation = {
'linear': cv2.INTER_LINEAR,
'nearest': cv2.INTER_NEAREST,
'cubic': cv2.INTER_CUBIC,
'lanczos': cv2.INTER_LANCZOS4,
}[interpolation_method]
borderMode = {
'constant': cv2.BORDER_CONSTANT,
'reflect': cv2.BORDER_REFLECT,
'replicate': cv2.BORDER_REPLICATE,
'wrap': cv2.BORDER_WRAP,
}[border_mode]
# Apply remappingIdx
def remap(ims):
out = np.stack([cv2.remap(
im,
remappingIdx[..., 0],
remappingIdx[..., 1],
interpolation=interpolation,
borderMode=borderMode,
borderValue=border_value,
) for im in ims], axis=0)
return out
warped_images = np.stack([remap(im) for im in images], axis=0)
return warped_images.squeeze()
[docs]
def remap_sparse_images(
ims_sparse: Union[scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]],
remappingIdx: np.ndarray,
method: str = 'linear',
fill_value: float = 0,
dtype: Union[str, np.dtype] = None,
safe: bool = True,
n_workers: int = -1,
verbose: bool = True,
) -> List[scipy.sparse.csr_array]:
"""
Remaps a list of sparse images using the given remap field.
RH 2023
Args:
ims_sparse (Union[scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]]):
A single sparse image or a list of sparse images.
remappingIdx (np.ndarray):
An array of shape *(H, W, 2)* representing the remap field. It
should be the same size as the images in ims_sparse.
method (str):
Interpolation method to use. See ``scipy.interpolate.griddata``.
Options are:
\n
* ``'linear'``
* ``'nearest'``
* ``'cubic'`` \n
(Default is ``'linear'``)
fill_value (float):
Value used to fill points outside the convex hull. (Default is
``0.0``)
dtype (Union[str, np.dtype]):
The data type of the resulting sparse images. Default is ``None``,
which will use the data type of the input sparse images.
safe (bool):
If ``True``, checks if the image is 0D or 1D and applies a tiny
Gaussian blur to increase the image width. (Default is ``True``)
n_workers (int):
Number of parallel workers to use. Default is *-1*, which uses all
available CPU cores.
verbose (bool):
Whether or not to use a tqdm progress bar. (Default is ``True``)
Returns:
(List[scipy.sparse.csr_array]):
ims_sparse_out (List[scipy.sparse.csr_array]):
A list of remapped sparse images.
Raises:
AssertionError: If the image and remappingIdx have different spatial
dimensions.
"""
# Ensure ims_sparse is a list of sparse matrices
ims_sparse = [ims_sparse] if not isinstance(ims_sparse, list) else ims_sparse
# Assert that all images are sparse matrices
assert all(scipy.sparse.issparse(im) for im in ims_sparse), "All images must be sparse matrices."
# Assert and retrieve dimensions
dims_ims = ims_sparse[0].shape
dims_remap = remappingIdx.shape
assert dims_ims == dims_remap[:-1], "Image and remappingIdx should have same spatial dimensions."
dtype = ims_sparse[0].dtype if dtype is None else dtype
if safe:
conv2d = Toeplitz_convolution2d(
x_shape=(dims_ims[0], dims_ims[1]),
k=np.array([[0 , 1e-8, 0 ],
[1e-8, 1, 1e-8],
[0 , 1e-8, 0 ]], dtype=dtype),
dtype=dtype,
)
def warp_sparse_image(
im_sparse: scipy.sparse.csr_array,
remappingIdx: np.ndarray,
method: str = method,
fill_value: float = fill_value,
safe: bool = safe
) -> scipy.sparse.csr_array:
# Convert sparse image to COO format
im_coo = scipy.sparse.coo_array(im_sparse)
# Get coordinates and values from COO format
rows, cols = im_coo.row, im_coo.col
data = im_coo.data
if safe:
# can't use scipy.interpolate.griddata with 1d values
is_horz = np.unique(rows).size == 1
is_vert = np.unique(cols).size == 1
# check for diagonal pixels
# slope = rise / run --- don't need to check if run==0
rdiff = np.diff(rows)
cdiff = np.diff(cols)
is_diag = np.unique(cdiff / rdiff).size == 1 if not np.any(rdiff==0) else False
# best practice to just convolve instead of interpolating if too few pixels
is_smol = rows.size < 3
if is_horz or is_vert or is_smol or is_diag:
# warp convolved sparse image directly without interpolation
return warp_sparse_image(im_sparse=conv2d(im_sparse, batching=False), remappingIdx=remappingIdx)
# Get values at the grid points
try:
grid_values = scipy.interpolate.griddata(
points=(rows, cols),
values=data,
xi=remappingIdx[:,:,::-1],
method=method,
fill_value=fill_value,
)
except Exception as e:
raise Exception(f"Error interpolating sparse image. Something is either weird about one of the input images or the remappingIdx. Error: {e}")
# Create a new sparse image from the nonzero pixels
warped_sparse_image = scipy.sparse.csr_array(grid_values, dtype=dtype)
warped_sparse_image.eliminate_zeros()
return warped_sparse_image
wsi_partial = partial(warp_sparse_image, remappingIdx=remappingIdx)
ims_sparse_out = map_parallel(func=wsi_partial, args=[ims_sparse,], method='multithreading', n_workers=n_workers, prog_bar=verbose)
return ims_sparse_out
[docs]
def invert_remappingIdx(
remappingIdx: np.ndarray,
method: str = 'linear',
fill_value: Optional[float] = np.nan
) -> np.ndarray:
"""
Inverts a remapping index field.
Requires the assumption that the remapping index field is invertible or bijective/one-to-one and non-occluding.
Defined 'remap_AB' as a remapping index field that warps image A onto image B, then 'remap_BA' is the remapping index field that warps image B onto image A. This function computes 'remap_BA' given 'remap_AB'.
RH 2023
Args:
remappingIdx (np.ndarray):
An array of shape *(H, W, 2)* representing the remap field.
method (str):
Interpolation method to use. See ``scipy.interpolate.griddata``. Options are:
\n
* ``'linear'``
* ``'nearest'``
* ``'cubic'`` \n
(Default is ``'linear'``)
fill_value (Optional[float]):
Value used to fill points outside the convex hull.
(Default is ``np.nan``)
Returns:
(np.ndarray):
An array of shape *(H, W, 2)* representing the inverse remap field.
"""
H, W, _ = remappingIdx.shape
# Create the meshgrid of the original image
grid = np.mgrid[:H, :W][::-1].transpose(1,2,0).reshape(-1, 2)
# Flatten the original meshgrid and remappingIdx
remapIdx_flat = remappingIdx.reshape(-1, 2)
# Interpolate the inverse mapping using griddata
map_BA = scipy.interpolate.griddata(
points=remapIdx_flat,
values=grid,
xi=grid,
method=method,
fill_value=fill_value,
).reshape(H,W,2)
return map_BA
[docs]
def invert_warp_matrix(
warp_matrix: np.ndarray
) -> np.ndarray:
"""
Inverts a provided warp matrix for the transformation A->B to compute the
warp matrix for B->A.
RH 2023
Args:
warp_matrix (np.ndarray):
A 2x3 or 3x3 array representing the warp matrix. Shape: *(2, 3)* or
*(3, 3)*.
Returns:
(np.ndarray):
inverted_warp_matrix (np.ndarray):
The inverted warp matrix. Shape: same as input.
"""
if warp_matrix.shape == (2, 3):
# Convert 2x3 affine warp matrix to 3x3 by appending [0, 0, 1] as the last row
warp_matrix_3x3 = np.vstack((warp_matrix, np.array([0, 0, 1])))
elif warp_matrix.shape == (3, 3):
warp_matrix_3x3 = warp_matrix
else:
raise ValueError("Input warp_matrix must be of shape (2, 3) or (3, 3)")
# Compute the inverse of the 3x3 warp matrix
inverted_warp_matrix_3x3 = np.linalg.inv(warp_matrix_3x3)
if warp_matrix.shape == (2, 3):
# Convert the inverted 3x3 warp matrix back to 2x3 by removing the last row
inverted_warp_matrix = inverted_warp_matrix_3x3[:2, :]
else:
inverted_warp_matrix = inverted_warp_matrix_3x3
return inverted_warp_matrix
[docs]
def compose_remappingIdx(
remap_AB: np.ndarray,
remap_BC: np.ndarray,
method: str = 'linear',
fill_value: Optional[float] = np.nan,
bounds_error: bool = False,
) -> np.ndarray:
"""
Composes two remapping index fields using scipy.interpolate.interpn.
This function computes 'remap_AC' from 'remap_AB' and 'remap_BC', where
'remap_AB' is a remapping index field that warps image A onto image B, and
'remap_BC' is a remapping index field that warps image B onto image C.
RH 2023
Args:
remap_AB (np.ndarray):
An array of shape *(H, W, 2)* representing the remap field from
image A to image B.
remap_BC (np.ndarray):
An array of shape *(H, W, 2)* representing the remap field from
image B to image C.
method (str):
Interpolation method to use. Either \n
* ``'linear'``: Use linear interpolation (default).
* ``'nearest'``: Use nearest interpolation.
* ``'cubic'``: Use cubic interpolation.
fill_value (Optional[float]):
The value used for points outside the interpolation domain. (Default
is ``np.nan``)
bounds_error (bool):
If ``True``, a ValueError is raised when interpolated values are
requested outside of the domain of the input data. (Default is
``False``)
Returns:
(np.ndarray):
remap_AC (np.ndarray):
An array of shape *(H, W, 2)* representing the remap field from
image A to image C.
"""
# Get the shape of the remap fields
H, W, _ = remap_AB.shape
# Combine the x and y components of remap_AB into a complex number
# This is done to simplify the interpolation process
AB_complex = remap_AB[:,:,0] + remap_AB[:,:,1]*1j
# Perform the interpolation using interpn
AC = scipy.interpolate.interpn(
(np.arange(H), np.arange(W)),
AB_complex,
remap_BC.reshape(-1, 2)[:, ::-1],
method=method,
bounds_error=bounds_error,
fill_value=fill_value
).reshape(H, W)
# Split the real and imaginary parts of the interpolated result to get the x and y components
remap_AC = np.stack((AC.real, AC.imag), axis=-1)
return remap_AC
def _make_idx_grid(
im: Union[np.ndarray, object],
) -> Union[np.ndarray, object]:
"""
Helper function to make a grid of indices for an image. Used in
``flowField_to_remappingIdx`` and ``remappingIdx_to_flowField``.
Args:
im (Union[np.ndarray, object]):
An image represented as a numpy ndarray or torch Tensor.
Returns:
(Union[np.ndarray, object]):
idx_grid (Union[np.ndarray, object]):
Index grid for the given image.
"""
if isinstance(im, torch.Tensor):
stack, meshgrid, arange = partial(torch.stack, dim=-1), partial(torch.meshgrid, indexing='xy'), partial(torch.arange, device=im.device, dtype=im.dtype)
elif isinstance(im, np.ndarray):
stack, meshgrid, arange = partial(np.stack, axis=-1), partial(np.meshgrid, indexing='xy'), partial(np.arange, dtype=im.dtype)
return stack(meshgrid(arange(im.shape[1]), arange(im.shape[0]))) # (H, W, 2). Last dimension is (x, y).
[docs]
def flowField_to_remappingIdx(
ff: Union[np.ndarray, object],
) -> Union[np.ndarray, object]:
"""
Convert a flow field to a remapping index. **WARNING**: Technically, it is
not possible to convert a flow field to a remapping index, since the
remapping index describes an interpolation mapping, while the flow field
describes a displacement.
RH 2023
Args:
ff (Union[np.ndarray, object]):
Flow field represented as a numpy ndarray or torch Tensor.
It describes the displacement of each pixel.
Shape *(H, W, 2)*. Last dimension is *(x, y)*.
Returns:
(Union[np.ndarray, object]):
ri (Union[np.ndarray, object]):
Remapping index. It describes the index of the pixel in
the original image that should be mapped to the new pixel.
Shape *(H, W, 2)*.
"""
ri = ff + _make_idx_grid(ff)
return ri
[docs]
def remappingIdx_to_flowField(
ri: Union[np.ndarray, object],
) -> Union[np.ndarray, object]:
"""
Convert a remapping index to a flow field. **WARNING**: Technically, it is
not possible to convert a remapping index to a flow field, since the
remapping index describes an interpolation mapping, while the flow field
describes a displacement.
RH 2023
Args:
ri (Union[np.ndarray, object]):
Remapping index represented as a numpy ndarray or torch Tensor.
It describes the index of the pixel in the original image that
should be mapped to the new pixel. Shape *(H, W, 2)*. Last
dimension is *(x, y)*.
Returns:
(Union[np.ndarray, object]):
ff (Union[np.ndarray, object]):
Flow field. It describes the displacement of each pixel.
Shape *(H, W, 2)*.
"""
ff = ri - _make_idx_grid(ri)
return ff
[docs]
def cv2RemappingIdx_to_pytorchFlowField(
ri: Union[np.ndarray, torch.Tensor]
) -> Union[np.ndarray, torch.Tensor]:
"""
Converts remapping indices from the OpenCV format to the PyTorch format. In
the OpenCV format, the displacement is in pixels relative to the top left
pixel of the image. In the PyTorch format, the displacement is in pixels
relative to the center of the image. RH 2023
Args:
ri (Union[np.ndarray, torch.Tensor]):
Remapping indices. Each pixel describes the index of the pixel in
the original image that should be mapped to the new pixel. Shape:
*(H, W, 2)*. The last dimension is (x, y).
Returns:
(Union[np.ndarray, torch.Tensor]):
normgrid (Union[np.ndarray, torch.Tensor]):
"Flow field", in the PyTorch format. Technically not a flow
field, since it doesn't describe displacement. Rather, it is a
remapping index relative to the center of the image. Shape: *(H,
W, 2)*. The last dimension is (x, y).
"""
assert isinstance(ri, torch.Tensor), f"ri must be a torch.Tensor. Got {type(ri)}"
im_shape = torch.flipud(torch.as_tensor(ri.shape[:2], dtype=torch.float32, device=ri.device)) ## (W, H)
normgrid = ((ri / (im_shape[None, None, :] - 1)) - 0.5) * 2 ## PyTorch's grid_sample expects grid values in [-1, 1] because it's a relative offset from the center pixel. CV2's remap expects grid values in [0, 1] because it's an absolute offset from the top-left pixel.
## note also that pytorch's grid_sample expects align_corners=True to correspond to cv2's default behavior.
return normgrid
[docs]
def pytorchFlowField_to_cv2RemappingIdx(
normgrid: Union[np.ndarray, torch.Tensor]
) -> Union[np.ndarray, torch.Tensor]:
"""
Converts remapping indices from the PyTorch format to the OpenCV format. In
the OpenCV format, the displacement is in pixels relative to the top left
pixel of the image. In the PyTorch format, the displacement is in pixels
relative to the center of the image. RH 2023
Args:
normgrid (Union[np.ndarray, torch.Tensor]):
"Flow field", in the PyTorch format. Technically not a flow field,
since it doesn't describe displacement. Rather, it is a remapping
index relative to the center of the image. Shape: *(H, W, 2)*. The
last dimension is (x, y).
Returns:
(Union[np.ndarray, torch.Tensor]):
ri (Union[np.ndarray, torch.Tensor]):
Remapping indices. Each pixel describes the index of the pixel
in the original image that should be mapped to the new pixel.
Shape: *(H, W, 2)*. The last dimension is (x, y).
"""
assert isinstance(normgrid, torch.Tensor), f"normgrid must be a torch.Tensor. Got {type(normgrid)}"
im_shape = torch.flipud(torch.as_tensor(normgrid.shape[:2], dtype=torch.float32, device=normgrid.device)) ## (W, H)
ri = ((normgrid / 2) + 0.5) * (im_shape[None, None, :] - 1) ## PyTorch's grid_sample expects grid values in [-1, 1] because it's a relative offset from the center pixel. CV2's remap expects grid values in [0, 1] because it's an absolute offset from the top-left pixel.
return ri
[docs]
def resize_remappingIdx(
ri: Union[np.ndarray, torch.Tensor],
new_shape: Tuple[int, int],
interpolation: str = 'BILINEAR',
) -> Union[np.ndarray, torch.Tensor]:
"""
Resize a remapping index field. This function both resizes the shape of the
actual remappingIdx arrays and scales the values to match the new shape.
RH 2024
Args:
ri (np.ndarray or torch.Tensor):
Remapping index field(s). Describes the index of the pixel in the
original image that should be mapped to the new pixel. Shape (H, W,
2) or (B, H, W, 2). Last dimension is (x, y).
new_shape (Tuple[int, int]):
New shape of the remapping index field.
Shape (H', W').
interpolation (str):
The interpolation method to use. See ``torchvision.transforms.Resize``
for options. \n
* ``'NEAREST'``: Nearest neighbor interpolation
* ``'NEAREST_EXACT'``: Nearest neighbor interpolation
* ``'BILINEAR'``: Bilinear interpolation
* ``'BICUBIC'``: Bicubic interpolation
antialias (bool):
If ``True``, antialiasing will be used. (Default is ``False``)
Returns:
ri_resized (np.ndarray or torch.Tensor):
Resized remapping index field.
Shape (H', W', 2). Last dimension is (x, y).
"""
assert isinstance(ri, (np.ndarray, torch.Tensor)), f"ri must be a np.ndarray or torch.Tensor. Got {type(ri)}"
assert ri.ndim in [3, 4], f"ri must have shape (H, W, 2) or (B, H, W, 2). Got shape {ri.shape}"
assert ri.shape[-1] == 2, f"ri must have shape (H, W, 2). Got shape {ri.shape}"
assert isinstance(new_shape, (tuple, list, np.ndarray, torch.Tensor)), f"new_shape must be a tuple, list, np.ndarray, or torch.Tensor. Got {type(new_shape)}"
assert len(new_shape) == 2, f"new_shape must have length 2. Got length {len(new_shape)}"
new_shape = (int(new_shape[0]), int(new_shape[1]))
if ri.ndim == 3:
ri = ri[None, ...]
return_3D = True
else:
return_3D = False
hw_ri = ri.shape[1:3]
if isinstance(ri, np.ndarray):
ri = torch.as_tensor(ri)
return_numpy = True
else:
return_numpy = False
device = ri.device
offsets = torch.as_tensor([(new_shape[1] - 1) / (hw_ri[1] - 1), (new_shape[0] - 1) / (hw_ri[0] - 1)], dtype=torch.float32, device=device)[None, None, None, ...]
ri_resized = resize_images(
images=ri.permute(3, 0, 1, 2),
new_shape=new_shape,
interpolation=interpolation,
).permute(1, 2, 3, 0) * offsets
if return_numpy:
ri_resized = ri_resized.cpu().numpy()
if return_3D:
ri_resized = ri_resized[0]
return ri_resized
[docs]
def add_text_to_images(
images: np.array,
text: List[List[str]],
position: Tuple[int, int] = (10,10),
font_size: int = 1,
color: Tuple[int, int, int] = (255,255,255),
line_width: int = 1,
font: Optional[str] = None,
frameRate: int = 30
) -> np.array:
"""
Adds text to images using ``cv2.putText()``.
RH 2022
Args:
images (np.array):
Frames of video or images. Shape: *(n_frames, height, width, n_channels)*.
text (list of lists):
Text to add to images.
The outer list: one element per frame.
The inner list: each element is a line of text.
position (tuple):
(x, y) position of the text (top left corner). (Default is *(10,10)*)
font_size (int):
Font size of the text. (Default is *1*)
color (tuple):
(r, g, b) color of the text. (Default is *(255,255,255)*)
line_width (int):
Line width of the text. (Default is *1*)
font (str):
Font to use. If ``None``, then will use ``cv2.FONT_HERSHEY_SIMPLEX``.
See ``cv2.FONT...`` for more options. (Default is ``None``)
frameRate (int):
Frame rate of the video. (Default is *30*)
Returns:
(np.array):
images_with_text (np.array):
Frames of video or images with text added.
"""
import cv2
import copy
if font is None:
font = cv2.FONT_HERSHEY_SIMPLEX
images_cp = copy.deepcopy(images)
for i_f, frame in enumerate(images_cp):
for i_t, t in enumerate(text[i_f]):
fn_putText = lambda frame_gray: cv2.putText(frame_gray, t, [position[0] , position[1] + i_t*font_size*30], font, font_size, color, line_width)
if frame.ndim == 3:
[fn_putText(frame[:,:,ii]) for ii in range(frame.shape[2])]
else:
fn_putText(frame)
return images_cp
[docs]
def resize_images(
images: Union[np.ndarray, List[np.ndarray], torch.Tensor, List[torch.Tensor]],
new_shape: Tuple[int, int] = (100,100),
interpolation: str = 'BILINEAR',
antialias: bool = False,
device: Optional[str] = None,
return_numpy: Optional[bool] = None,
) -> np.ndarray:
"""
Resizes images using the ``torchvision.transforms.Resize`` method.
RH 2023
Args:
images (Union[np.ndarray, List[np.ndarray]], torch.Tensor, List[torch.Tensor]):
Images or frames of a video. Can be 2D, 3D, or 4D.
* For a 2D array: shape is *(height, width)*
* For a 3D array: shape is *(n_frames, height, width)*
* For a 4D array: shape is *(n_frames, n_channels, height, width)*
new_shape (Tuple[int, int]):
The desired height and width of resized images as a tuple.
(Default is *(100, 100)*)
interpolation (str):
The interpolation method to use. See ``torchvision.transforms.Resize``
for options.
* ``'NEAREST'``: Nearest neighbor interpolation
* ``'NEAREST_EXACT'``: Nearest neighbor interpolation
* ``'BILINEAR'``: Bilinear interpolation
* ``'BICUBIC'``: Bicubic interpolation
antialias (bool):
If ``True``, antialiasing will be used. (Default is ``False``)
device Optional[str]:
The device to use for ``torchvision.transforms.Resize``. If None,
will use the device of the input images. (Default is ``None``)
return_numpy Optional[bool]:
If ``True``, then will return a numpy array. Otherwise, will return
a torch tensor on the defined device. If None, will return a numpy
array only if the input is a numpy array. (Default is ``None``)
Returns:
(np.ndarray):
images_resized (np.ndarray):
Frames of video or images with overlay added.
"""
## Convert images to torch tensor
if isinstance(images, list):
if isinstance(images[0], np.ndarray):
device = device if device is not None else 'cpu'
images = torch.stack([torch.as_tensor(im, device=device) for im in images], dim=0)
return_numpy = True if return_numpy is None else return_numpy
elif isinstance(images, np.ndarray):
device = device if device is not None else 'cpu'
images = torch.as_tensor(images, device=device)
return_numpy = True if return_numpy is None else return_numpy
elif isinstance(images, torch.Tensor):
images = images.to(device=device)
else:
raise ValueError(f"images must be a np.ndarray or torch.Tensor or a list of np.ndarray or torch.Tensor. Got {type(images)}")
## Convert images to 4D
def pad_to_4D(ims):
if ims.ndim == 2:
ims = ims[None, None, :, :]
elif ims.ndim == 3:
ims = ims[None, :, :, :]
elif ims.ndim != 4:
raise ValueError(f"images must be a 2D, 3D, or 4D array. Got shape {ims.shape}")
return ims
ndim_orig = images.ndim
images = pad_to_4D(images)
## Get interpolation method
try:
interpolation = getattr(torchvision.transforms.InterpolationMode, interpolation.upper())
except Exception as e:
raise Exception(f"Invalid interpolation method. See torchvision.transforms.InterpolationMode for options. Error: {e}")
resizer = torchvision.transforms.Resize(
size=new_shape,
interpolation=interpolation,
antialias=antialias,
).to(device=device)
images_resized = resizer(images)
## Convert images back to original shape
def unpad_to_orig(ims, ndim_orig):
if ndim_orig == 2:
ims = ims[0,0,:,:]
elif ndim_orig == 3:
ims = ims[0,:,:,:]
elif ndim_orig != 4:
raise ValueError(f"images must be a 2D, 3D, or 4D array. Got shape {ims.shape}")
return ims
images_resized = unpad_to_orig(images_resized, ndim_orig)
## Convert images to numpy
if return_numpy == True:
images_resized = images_resized.detach().cpu().numpy()
return images_resized
[docs]
class ImageAlignmentChecker:
"""
Class to check the alignment of images using phase correlation.
RH 2024
Args:
hw (Tuple[int, int]):
Height and width of the images.
radius_in (Union[float, Tuple[float, float]]):
Radius of the pixel shift / offset that can be considered as
'aligned'. Used to create the 'in' filter which is an image of a
small centered circle that is used as a filter and multiplied by
the phase correlation images. If a single value is provided, the
filter will be a circle with radius 0 to that value; it will be
converted to a tuple representing a bandpass filter (0, radius_in).
radius_out (Union[float, Tuple[float, float]]):
Similar to radius_in, but for the 'out' filter, which defines the
'null distribution' for defining what is 'aligned'. Should be a
value larger than the expected maximum pixel shift / offset. If a
single value is provided, the filter will be a donut / taurus
starting at that value and ending at the edge of the smallest
dimension of the image; it will be converted to a tuple representing
a bandpass filter (radius_out, min(hw)).
order (int):
Order of the butterworth bandpass filters used to define the 'in'
and 'out' filters. Larger values will result in a sharper edges, but
values higher than 5 can lead to collapse of the filter.
device (str):
Torch device to use for computations. (Default is 'cpu')
Attributes:
hw (Tuple[int, int]):
Height and width of the images.
order (int):
Order of the butterworth bandpass filters used to define the 'in'
and 'out' filters.
device (str):
Torch device to use for computations.
filt_in (torch.Tensor):
The 'in' filter used for scoring the alignment.
filt_out (torch.Tensor):
The 'out' filter used for scoring the alignment.
"""
def __init__(
self,
hw: Tuple[int, int],
radius_in: Union[float, Tuple[float, float]],
radius_out: Union[float, Tuple[float, float]],
order: int = 5,
device: str = 'cpu',
):
## Set attributes
### Convert to torch.Tensor
self.hw = tuple(hw)
### Set other attributes
self.order = int(order)
self.device = str(device)
### Set filter attributes
if isinstance(radius_in, (int, float, complex)):
radius_in = (float(0.0), float(radius_in))
elif isinstance(radius_in, (tuple, list, np.ndarray, torch.Tensor)):
radius_in = tuple(float(r) for r in radius_in)
else:
raise ValueError(f'radius_in must be a float or tuple of floats. Found type: {type(radius_in)}')
if isinstance(radius_out, (int, float, complex)):
radius_out = (float(radius_out), float(min(self.hw)) / 2)
elif isinstance(radius_out, (tuple, list, np.ndarray, torch.Tensor)):
radius_out = tuple(float(r) for r in radius_out)
else:
raise ValueError(f'radius_out must be a float or tuple of floats. Found type: {type(radius_out)}')
## Make filters
self.filt_in, self.filt_out = (torch.as_tensor(make_2D_frequency_filter(
hw=self.hw,
low=bp[0],
high=bp[1],
order=order,
), dtype=torch.float32, device=device) for bp in [radius_in, radius_out])
[docs]
def score_alignment(
self,
images: Union[np.ndarray, torch.Tensor],
images_ref: Optional[Union[np.ndarray, torch.Tensor]] = None,
):
"""
Score the alignment of a set of images using phase correlation. Computes
the stats of the center ('in') of the phase correlation image over the
stats of the outer region ('out') of the phase correlation image.
RH 2024
Args:
images (Union[np.ndarray, torch.Tensor]):
A 3D array of images. Shape: *(n_images, height, width)*
images_ref (Optional[Union[np.ndarray, torch.Tensor]]):
Reference images to compare against. If provided, the images
will be compared against these images. If not provided, the
images will be compared against themselves. (Default is
``None``)
Returns:
(Dict):
Dictionary containing the following keys:
* 'mean_out':
Mean of the phase correlation image weighted by the
'out' filter
* 'mean_in':
Mean of the phase correlation image weighted by the
'in' filter
* 'ptile95_out':
95th percentile of the phase correlation image multiplied by
the 'out' filter
* 'max_in':
Maximum value of the phase correlation image multiplied by
the 'in' filter
* 'std_out':
Standard deviation of the phase correlation image weighted by
the 'out' filter
* 'std_in':
Standard deviation of the phase correlation image weighted by
the 'in' filter
* 'max_diff':
Difference between the 'max_in' and 'ptile95_out' values
* 'z_in':
max_diff divided by the 'std_out' value
* 'r_in':
max_diff divided by the 'ptile95_out' value
"""
def _fix_images(ims):
assert isinstance(ims, (np.ndarray, torch.Tensor, list, tuple)), f'images must be np.ndarray, torch.Tensor, or a list/tuple of np.ndarray or torch.Tensor. Found type: {type(ims)}'
if isinstance(ims, (list, tuple)):
assert all(isinstance(im, (np.ndarray, torch.Tensor)) for im in ims), f'images must be np.ndarray or torch.Tensor. Found types: {set(type(im) for im in ims)}'
assert all(im.ndim == 2 for im in ims), f'images must be 2D arrays (height, width). Found shapes: {set(im.shape for im in ims)}'
if isinstance(ims[0], np.ndarray):
ims = np.stack([np.array(im) for im in ims], axis=0)
else:
ims = torch.stack([torch.as_tensor(im) for im in ims], dim=0)
else:
if ims.ndim == 2:
ims = ims[None, :, :]
assert ims.ndim == 3, f'images must be a 3D array (n_images, height, width). Found shape: {ims.shape}'
assert ims.shape[1:] == self.hw, f'images must have shape (n_images, {self.hw[0]}, {self.hw[1]}). Found shape: {ims.shape}'
ims = torch.as_tensor(ims, dtype=torch.float32, device=self.device)
return ims
images = _fix_images(images)
images_ref = _fix_images(images_ref) if images_ref is not None else images
pc = phase_correlation(images_ref[None, :, :, :], images[:, None, :, :]) ## All to all phase correlation. Shape: (n_images, n_images, height, width)
## metrics
filt_in, filt_out = self.filt_in[None, None, :, :], self.filt_out[None, None, :, :]
mean_out = (pc * filt_out).sum(dim=(-2, -1)) / filt_out.sum(dim=(-2, -1))
mean_in = (pc * filt_in).sum(dim=(-2, -1)) / filt_in.sum(dim=(-2, -1))
ptile95_out = torch.quantile((pc * filt_out).reshape(pc.shape[0], pc.shape[1], -1)[:, :, filt_out.reshape(-1) > 1e-3], 0.95, dim=-1)
max_in = (pc * filt_in).amax(dim=(-2, -1))
std_out = torch.sqrt(torch.mean((pc - mean_out[:, :, None, None])**2 * filt_out, dim=(-2, -1)))
std_in = torch.sqrt(torch.mean((pc - mean_in[:, :, None, None])**2 * filt_in, dim=(-2, -1)))
max_diff = max_in - ptile95_out
z_in = max_diff / std_out
r_in = max_diff / ptile95_out
outs = {
'pc': pc.cpu().numpy(),
'mean_out': mean_out,
'mean_in': mean_in,
'ptile95_out': ptile95_out,
'max_in': max_in,
'std_out': std_out,
'std_in': std_in,
'max_diff': max_diff,
'z_in': z_in, ## z-score of in value over out distribution
'r_in': r_in,
}
outs = {k: val.cpu().numpy() if isinstance(val, torch.Tensor) else val for k, val in outs.items()}
return outs
def __call__(
self,
images: Union[np.ndarray, torch.Tensor],
):
"""
Calls the `score_alignment` method. See `self.score_alignment` docstring
for more info.
"""
return self.score_alignment(images)
[docs]
def make_2D_frequency_filter(
hw: tuple,
low: float = 5,
high: float = 6,
order: int = 3,
distance_p: int = 100,
):
"""
Make a filter for scoring the alignment of images using phase correlation.
RH 2024
Args:
hw (tuple):
Height and width of the images.
low (float):
Low cutoff frequency for the bandpass filter. Units are in
pixels.
high (float):
High cutoff frequency for the bandpass filter. Units are in
pixels.
order (int):
Order of the butterworth bandpass filter. (Default is *3*)
distance_p (int):
Distance parameter for the distance grid. Defines the Minkowski
distance used to compute the distance grid.
Returns:
(np.ndarray):
Filter for scoring the alignment. Shape: *(height, width)*
"""
## Make a distance grid starting from the fftshifted center
grid = make_distance_grid(shape=hw, p=distance_p, use_fftshift_center=True)
## Make the number of datapoints for the kernel large
n_x = max(hw) * 10
fs = max(hw) * 1
low = max(0, low)
high = min((max(hw) / 2) - 1, high)
b, a = design_butter_bandpass(lowcut=low, highcut=high, fs=fs, order=order, plot_pref=False)
w, h = scipy.signal.freqz(b, a, worN=n_x)
x_kernel = (fs * 0.5 / np.pi) * w
kernel = np.abs(h)
## Interpolate the kernel to the distance grid
filt = np.interp(
x=grid,
xp=x_kernel,
fp=kernel,
)
return filt
[docs]
def phase_correlation(
im_template: Union[np.ndarray, torch.Tensor],
im_moving: Union[np.ndarray, torch.Tensor],
mask_fft: Optional[Union[np.ndarray, torch.Tensor]] = None,
return_filtered_images: bool = False,
eps: float = 1e-8,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
"""
Perform phase correlation on two images. Calculation performed along the
last two axes of the input arrays (-2, -1) corresponding to the (height,
width) of the images.
RH 2024
Args:
im_template (np.ndarray):
The template image(s). Shape: (..., height, width). Can be any
number of dimensions; last two dimensions must be height and width.
im_moving (np.ndarray):
The moving image. Shape: (..., height, width). Leading dimensions
must broadcast with the template image.
mask_fft (Optional[np.ndarray]):
2D array mask for the FFT. If ``None``, no mask is used. Assumes mask_fft is
fftshifted. (Default is ``None``)
return_filtered_images (bool):
If set to ``True``, the function will return filtered images in
addition to the phase correlation coefficient. (Default is
``False``)
eps (float):
Epsilon value to prevent division by zero. (Default is ``1e-8``)
Returns:
(Tuple[np.ndarray, np.ndarray, np.ndarray]): tuple containing:
cc (np.ndarray):
The phase correlation coefficient.
fft_template (np.ndarray):
The filtered template image. Only returned if
return_filtered_images is ``True``.
fft_moving (np.ndarray):
The filtered moving image. Only returned if
return_filtered_images is ``True``.
"""
fft2, fftshift, ifft2 = torch.fft.fft2, torch.fft.fftshift, torch.fft.ifft2
abs, conj = torch.abs, torch.conj
axes = (-2, -1)
return_numpy = isinstance(im_template, np.ndarray)
im_template = torch.as_tensor(im_template)
im_moving = torch.as_tensor(im_moving)
fft_template = fft2(im_template, dim=axes)
fft_moving = fft2(im_moving, dim=axes)
if mask_fft is not None:
mask_fft = torch.as_tensor(mask_fft)
# Normalize and shift the mask
mask_fft = fftshift(mask_fft, dim=axes)
mask = mask_fft[tuple([None] * (im_template.ndim - 2) + [slice(None)] * 2)]
fft_template *= mask
fft_moving *= mask
# Compute the cross-power spectrum
R = fft_template * conj(fft_moving)
# Normalize to obtain the phase correlation function
R /= abs(R) + eps # Add epsilon to prevent division by zero
# Compute the magnitude of the inverse FFT to ensure symmetry
# cc = abs(fftshift(ifft2(R, dim=axes), dim=axes))
# Compute the real component of the inverse FFT (not symmetric)
cc = fftshift(ifft2(R, dim=axes), dim=axes).real
if return_filtered_images == False:
return cc.cpu().numpy() if return_numpy else cc
else:
if return_numpy:
return (
cc.cpu().numpy(),
abs(ifft2(fft_template, dim=axes)).cpu().numpy(),
abs(ifft2(fft_moving, dim=axes)).cpu().numpy()
)
else:
return cc, abs(ifft2(fft_template, dim=axes)), abs(ifft2(fft_moving, dim=axes))
######################################################################################################################################
######################################################## TIME SERIES #################################################################
######################################################################################################################################
[docs]
class Convolver_1d():
"""
Class for 1D convolution.
Uses torch.nn.functional.conv1d.
Stores the convolution and edge correction kernels for repeated use.
RH 2023
Attributes:
pad_mode (str):
Mode for padding. See ``torch.nn.functional.conv1d`` for details.
dtype (object):
Data type for the convolution. Default is ``torch.float32``.
kernel (object):
Convolution kernel as a tensor.
trace_correction (object):
Kernel for edge correction.
Args:
kernel (Union[np.ndarray, object]):
1D array to convolve with. The array can be a numpy array or a
tensor.
length_x (Optional[int]):
Length of the array to be convolved.
Must not be ``None`` if pad_mode is not 'valid'. (Default is
``None``)
dtype (object):
Data type to use for the convolution.
(Default is ``torch.float32``)
pad_mode (str):
Mode for padding.
See ``torch.nn.functional.conv1d`` for details.
(Default is 'same')
correct_edge_effects (bool):
Whether or not to correct for edge effects.
(Default is ``True``)
device (str):
Device to use for computation.
(Default is 'cpu')
"""
def __init__(
self,
kernel: Union[np.ndarray, object],
length_x: Optional[int] = None,
dtype: object = torch.float32,
pad_mode: str = 'same',
correct_edge_effects: bool = True,
device: str = 'cpu',
):
"""
Initializes the Convolver_1d with the given kernel, length of array to
be convolved, dtype, padding mode, edge effect correction setting, and
device.
"""
self.pad_mode = pad_mode
self.dtype = dtype
## convert kernel to torch tensor
self.kernel = torch.as_tensor(kernel, dtype=dtype, device=device)[None,None,:]
## compute edge correction kernel
if pad_mode != 'valid':
assert length_x is not None, "Must provide length_x if pad_mode is not 'valid'"
assert length_x >= kernel.shape[0], "length_x must be >= kernel.shape[0]"
self.trace_correction = torch.conv1d(
input=torch.ones((1,1,length_x), dtype=dtype, device=device),
weight=self.kernel,
padding=pad_mode,
)[0,0,:] if correct_edge_effects else None
else:
self.trace_correction = None
self.__call__ = self.convolve
[docs]
def convolve(self, arr: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
"""
Convolve array with kernel.
Args:
arr (Union[np.ndarray, torch.Tensor]):
Array to convolve.
Convolution performed along the last axis.
Must be 1D, 2D, or 3D.
Returns:
(Union[np.ndarray, torch.Tensor]):
out (Union[np.ndarray, torch.Tensor]):
The output tensor after performing convolution and
correcting for edge effects.
Example:
.. highlight:: python
.. code-block:: python
convolver = Convolver_1d(kernel=my_kernel)
result = convolver.convolve(my_array)
"""
## make array 3D by adding singleton dimensions if necessary
ndim = arr.ndim
if ndim == 1:
arr = arr[None,None,:]
elif ndim == 2:
arr = arr[None,:,:]
assert arr.ndim == 3, "Array must be 1D or 2D or 3D"
## convolve along last axis
out = torch.conv1d(
input=torch.as_tensor(arr, dtype=self.dtype, device=self.kernel.device),
weight=self.kernel,
padding=self.pad_mode,
)
## correct for edge effects
if self.trace_correction is not None:
out = out / self.trace_correction[None,None,:]
## remove singleton dimensions if necessary
if ndim == 1:
out = out[0,0,:]
elif ndim == 2:
out = out[0,:,:]
return out
def __repr__(self) -> str:
return f"Convolver_1d(kernel shape={self.kernel.shape}, pad_mode={self.pad_mode})"
######################################################################################################################################
########################################################## SPECTRAL ##################################################################
######################################################################################################################################
[docs]
def design_butter_bandpass(lowcut, highcut, fs, order=5, plot_pref=False):
'''
designs a butterworth bandpass filter.
Makes a lowpass filter if lowcut is 0.
Makes a highpass filter if highcut is fs/2.
RH 2021
Args:
lowcut (scalar):
frequency (in Hz) of low pass band
highcut (scalar):
frequency (in Hz) of high pass band
fs (scalar):
sample rate (frequency in Hz)
order (int):
order of the butterworth filter
Returns:
b (ndarray):
Numerator polynomial coeffs of the IIR filter
a (ndarray):
Denominator polynomials coeffs of the IIR filter
'''
nyq = 0.5 * fs
low = lowcut / nyq
high = highcut / nyq
if low <= 0:
## Make a lowpass filter
b, a = scipy.signal.butter(N=order, Wn=high, btype='low')
elif high >= 1:
## Make a highpass filter
b, a = scipy.signal.butter(N=order, Wn=low, btype='high')
else:
b, a = scipy.signal.butter(N=order, Wn=[low, high], btype='band')
if plot_pref:
plot_digital_filter_response(b=b, a=a, fs=fs, worN=100000)
return b, a
[docs]
def plot_digital_filter_response(b, a=None, fs=30, worN=100000, plot_pref=True):
'''
plots the frequency response of a digital filter
RH 2021
Args:
b (ndarray):
Numerator polynomial coeffs of the IIR filter
a (ndarray):
Denominator polynomials coeffs of the IIR filter
fs (scalar):
sample rate (frequency in Hz)
worN (int):
number of frequencies at which to evaluate the filter
'''
w, h = scipy.signal.freqz(b, a, worN=worN) if a is not None else scipy.signal.freqz(b, worN=worN)
xAxis = (fs * 0.5 / np.pi) * w
if plot_pref:
plt.figure()
plt.plot(xAxis, abs(h))
plt.xlabel('frequency (Hz)')
plt.ylabel('frequency response (a.u)')
plt.xscale('log')
return xAxis, abs(h)
######################################################################################################################################
####################################################### FEATURIZATION ################################################################
######################################################################################################################################
[docs]
class Toeplitz_convolution2d():
"""
Convolve a 2D array with a 2D kernel using the Toeplitz matrix
multiplication method. This class is ideal when 'x' is very sparse
(density<0.01), 'x' is small (shape <(1000,1000)), 'k' is small (shape
<(100,100)), and the batch size is large (e.g. 1000+). Generally, it is
faster than scipy.signal.convolve2d when convolving multiple arrays with the
same kernel. It maintains a low memory footprint by storing the toeplitz
matrix as a sparse matrix.
RH 2022
Attributes:
x_shape (Tuple[int, int]):
The shape of the 2D array to be convolved.
k (np.ndarray):
2D kernel to convolve with.
mode (str):
Either ``'full'``, ``'same'``, or ``'valid'``. See
scipy.signal.convolve2d for details.
dtype (Optional[np.dtype]):
The data type to use for the Toeplitz matrix.
If ``None``, then the data type of the kernel is used.
Args:
x_shape (Tuple[int, int]):
The shape of the 2D array to be convolved.
k (np.ndarray):
2D kernel to convolve with.
mode (str):
Convolution method to use, either ``'full'``, ``'same'``, or
``'valid'``.
See scipy.signal.convolve2d for details. (Default is 'same')
dtype (Optional[np.dtype]):
The data type to use for the Toeplitz matrix. Ideally, this matches
the data type of the input array. If ``None``, then the data type of
the kernel is used. (Default is ``None``)
Example:
.. highlight:: python
.. code-block:: python
# create Toeplitz_convolution2d object
toeplitz_convolution2d = Toeplitz_convolution2d(
x_shape=(100,30),
k=np.random.rand(10,10),
mode='same',
)
toeplitz_convolution2d(
x=scipy.sparse.csr_array(np.random.rand(5,3000)),
batch_size=True,
)
"""
def __init__(
self,
x_shape: Tuple[int, int],
k: np.ndarray,
mode: str = 'same',
dtype: Optional[np.dtype] = None,
):
"""
Initializes the Toeplitz_convolution2d object and stores the Toeplitz
matrix.
"""
self.k = k = np.flipud(k.copy())
self.mode = mode
self.x_shape = x_shape
dtype = k.dtype if dtype is None else dtype
if mode == 'valid':
assert x_shape[0] >= k.shape[0] and x_shape[1] >= k.shape[1], "x must be larger than k in both dimensions for mode='valid'"
self.so = so = size_output_array = ( (k.shape[0] + x_shape[0] -1), (k.shape[1] + x_shape[1] -1)) ## 'size out' is the size of the output array
## make the toeplitz matrices
t = toeplitz_matrices = [scipy.sparse.diags_array(
np.ones((k.shape[1], x_shape[1]), dtype=dtype) * k_i[::-1][:,None],
offsets=np.arange(-k.shape[1]+1, 1),
shape=(so[1], x_shape[1]),
dtype=dtype,
) for k_i in k[::-1]] ## make the toeplitz matrices for the rows of the kernel
tc = toeplitz_concatenated = scipy.sparse.vstack(t + [scipy.sparse.dia_array((t[0].shape), dtype=dtype)]*(x_shape[0]-1)) ## add empty matrices to the bottom of the block due to padding, then concatenate
## make the double block toeplitz matrix
self.dt = double_toeplitz = scipy.sparse.hstack([self._roll_sparse(
x=tc,
shift=(ii>0)*ii*(so[1]) ## shift the blocks by the size of the output array
) for ii in range(x_shape[0])]).tocsr()
def __call__(
self,
x: Union[np.ndarray, scipy.sparse.csc_array, scipy.sparse.csr_array],
batching: bool = True,
mode: Optional[str] = None,
) -> Union[np.ndarray, scipy.sparse.csr_array]:
"""
Convolve the input array with the kernel.
Args:
x (Union[np.ndarray, scipy.sparse.csc_array,
scipy.sparse.csr_array]):
Input array(s) (i.e. image(s)) to convolve with the kernel. \n
* If ``batching==False``: Single 2D array to convolve with the
kernel. Shape: *(self.x_shape[0], self.x_shape[1])*
* If ``batching==True``: Multiple 2D arrays that have been
flattened into row vectors (with order='C'). \n
Shape: *(n_arrays, self.x_shape[0]*self.x_shape[1])*
batching (bool):
* ``False``: x is a single 2D array.
* ``True``: x is a 2D array where each row is a flattened 2D
array. \n
(Default is ``True``)
mode (Optional[str]):
Defines the mode of the convolution. Options are 'full', 'same'
or 'valid'. See `scipy.signal.convolve2d` for details. Overrides
the mode set in __init__. (Default is ``None``)
Returns:
(Union[np.ndarray, scipy.sparse.csr_array]):
out (Union[np.ndarray, scipy.sparse.csr_array]):
* ``batching==True``: Multiple convolved 2D arrays that have
been flattened into row vectors (with order='C'). Shape:
*(n_arrays, height*width)*
* ``batching==False``: Single convolved 2D array of shape
*(height, width)*
"""
if mode is None:
mode = self.mode ## use the mode that was set in the init if not specified
issparse = scipy.sparse.issparse(x)
if batching:
x_v = x.T ## transpose into column vectors
else:
x_v = x.reshape(-1, 1) ## reshape 2D array into a column vector
if issparse:
x_v = x_v.tocsc()
out_v = self.dt @ x_v ## if sparse, then 'out_v' will be a csc matrix
## crop the output to the correct size
if mode == 'full':
p_t = 0
p_b = self.so[0]+1
p_l = 0
p_r = self.so[1]+1
if mode == 'same':
p_t = (self.k.shape[0]-1)//2
p_b = -(self.k.shape[0]-1)//2
p_l = (self.k.shape[1]-1)//2
p_r = -(self.k.shape[1]-1)//2
p_b = self.x_shape[0]+1 if p_b==0 else p_b
p_r = self.x_shape[1]+1 if p_r==0 else p_r
if mode == 'valid':
p_t = (self.k.shape[0]-1)
p_b = -(self.k.shape[0]-1)
p_l = (self.k.shape[1]-1)
p_r = -(self.k.shape[1]-1)
p_b = self.x_shape[0]+1 if p_b==0 else p_b
p_r = self.x_shape[1]+1 if p_r==0 else p_r
if batching:
idx_crop = np.zeros((self.so), dtype=bool)
idx_crop[p_t:p_b, p_l:p_r] = True
idx_crop = idx_crop.reshape(-1)
out = out_v[idx_crop,:].T
else:
if issparse:
out = out_v.reshape((self.so)).tocsc()[p_t:p_b, p_l:p_r]
else:
out = out_v.reshape((self.so))[p_t:p_b, p_l:p_r] ## reshape back into 2D array and crop
return out
def _roll_sparse(
self,
x: scipy.sparse.csr_array,
shift: int,
):
"""
Roll columns of a sparse matrix.
"""
out = x.copy()
out.row += shift
return out
[docs]
def make_distance_grid(shape=(512,512), p=2, idx_center=None, return_axes=False, use_fftshift_center=False):
"""
Creates a matrix of distances from the center.
Can calculate the Minkowski distance for any p.
RH 2023
Args:
shape (Tuple[int, int, ...]):
Shape of the n-dimensional grid (i,j,k,...)
If a shape value is odd, the center will be the center
of that dimension. If a shape value is even, the center
will be between the two center points.
p (int):
Order of the Minkowski distance.
p=1 is the Manhattan distance
p=2 is the Euclidean distance
p=inf is the Chebyshev distance
idx_center Optional[Tuple[int, int, ...]]:
The index of the center of the grid. If None, the center is
assumed to be the center of the grid. If provided, the center
will be set to this index. This is useful for odd shaped grids
where the center is not obvious.
return_axes (bool):
If True, return the axes of the grid as well. Return will be a
tuple.
use_fft_center (bool):
If True, the center of the grid will be the center of the FFT
grid. This is useful for FFT operations where the center is
assumed to be the top left corner.
Returns:
Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
distance_image (np.ndarray):
array of distances to the center index
axes (Optional[np.ndarray]):
axes of the grid as well. Only returned if return_axes=True
"""
if use_fftshift_center:
## Find idx wheter freq=0. Use np.fft.fftfreq
freqs_h, freqs_w = np.fft.fftshift(np.fft.fftfreq(shape[0])), np.fft.fftshift(np.fft.fftfreq(shape[1]))
idx_center = (np.argmin(np.abs(freqs_h)), np.argmin(np.abs(freqs_w)))
shape = np.array(shape)
if idx_center is not None:
axes = [np.linspace(-idx_center[i], shape[i] - idx_center[i] - 1, shape[i]) for i in range(len(shape))]
else:
axes = [np.arange(-(d - 1) / 2, (d - 1) / 2 + 0.5) for d in shape]
grid = np.stack(
np.meshgrid(*axes, indexing="ij"),
axis=0,
)
if idx_center is not None:
grid_dist = np.linalg.norm(
grid ,
ord=p,
axis=0,
)
else:
grid_dist = np.linalg.norm(
grid,
ord=p,
axis=0,
)
return grid_dist if not return_axes else (grid_dist, axes)
######################################################################################################################################
##################################################### PARALLEL HELPERS ###############################################################
######################################################################################################################################
[docs]
class ParallelExecutionError(Exception):
"""
Exception class for errors that occur during parallel execution.
Intended to be used with the ``map_parallel`` function.
RH 2023
Attributes:
index (int):
Index of the job that failed.
original_exception (Exception):
The original exception that was raised.
"""
def __init__(self, index, original_exception):
self.index = index
self.original_exception = original_exception
def __str__(self):
return f"Job {self.index} raised an exception: {self.original_exception}"
[docs]
def map_parallel(
func: Callable,
args: List[Any],
method: str = 'multithreading',
n_workers: int = -1,
prog_bar: bool = True
) -> List[Any]:
"""
Maps a function to a list of arguments in parallel.
RH 2022
Args:
func (Callable):
The function to be mapped.
args (List[Any]):
List of arguments to which the function should be mapped.
Length of list should be equal to the number of arguments.
Each element should then be an iterable for each job that is run.
method (str):
Method to use for parallelization. Either \n
* ``'multithreading'``: Use multithreading from concurrent.futures.
* ``'multiprocessing'``: Use multiprocessing from concurrent.futures.
* ``'mpire'``: Use mpire.
* ``'serial'``: Use list comprehension. \n
(Default is ``'multithreading'``)
workers (int):
Number of workers to use. If set to -1, all available workers are used. (Default is ``-1``)
prog_bar (bool):
Whether to display a progress bar using tqdm. (Default is ``True``)
Returns:
(List[Any]):
output (List[Any]):
List of results from mapping the function to the arguments.
Example:
.. highlight::python
.. code-block::python
result = map_parallel(max, [[1,2,3,4],[5,6,7,8]], method='multiprocessing', n_workers=3)
"""
if n_workers == -1:
n_workers = mp.cpu_count()
## Assert that args is a list
assert isinstance(args, list), "args must be a list"
## Assert that each element of args is an iterable
assert all([hasattr(arg, '__iter__') for arg in args]), "All elements of args must be iterable"
## Assert that each element has a length
assert all([hasattr(arg, '__len__') for arg in args]), "All elements of args must have a length"
## Get number of arguments. If args is a generator, make None.
n_args = len(args[0]) if hasattr(args, '__len__') else None
## Assert that all args are the same length
assert all([len(arg) == n_args for arg in args]), "All args must be the same length"
## Make indices
indices = np.arange(n_args)
## Prepare args_map (input to map function)
args_map = [[func] * n_args, *args, indices]
if method == 'multithreading':
executor = ThreadPoolExecutor
elif method == 'multiprocessing':
executor = ProcessPoolExecutor
elif method == 'mpire':
import mpire
executor = mpire.WorkerPool
# elif method == 'joblib':
# import joblib
# return joblib.Parallel(n_jobs=workers)(joblib.delayed(func)(arg) for arg in tqdm(args, total=n_args, disable=prog_bar!=True))
elif method == 'serial':
return list(tqdm(map(_func_wrapper_helper, *args_map), total=n_args, disable=prog_bar!=True))
else:
raise ValueError(f"method {method} not recognized")
with executor(n_workers) as ex:
return list(tqdm(ex.map(_func_wrapper_helper, *args_map), total=n_args, disable=prog_bar!=True))
def _func_wrapper_helper(*func_args_index):
"""
Wrapper function to catch exceptions.
Args:
*func_args_index (tuple):
Tuple of arguments to be passed to the function.
Should take the form of (func, arg1, arg2, ..., argN, index)
The last element is the index of the job.
"""
func = func_args_index[0]
args = func_args_index[1:-1]
index = func_args_index[-1]
try:
return func(*args)
except Exception as e:
raise ParallelExecutionError(index, e)
######################################################################################################################################
######################################################### CLUSTERING #################################################################
######################################################################################################################################
[docs]
def compute_cluster_similarity_matrices(
s: Union[scipy.sparse.csr_array, np.ndarray, sparse.COO],
l: np.ndarray,
verbose: bool = True,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Computes the similarity matrices for each cluster in ``l``. This algorithm
works best on large and sparse matrices.
RH 2023
Args:
s (Union[scipy.sparse.csr_array, np.ndarray, sparse.COO]):
Similarity matrix. Entries should be non-negative floats.
l (np.ndarray):
Labels for each row of ``s``. Labels should ideally be integers.
verbose (bool):
Whether to print warnings. (Default is ``True``)
Returns:
(tuple): tuple containing:
labels_unique (np.ndarray):
Unique labels in ``l``.
cs_mean (np.ndarray):
Similarity matrix for each cluster. Each element is the mean
similarity between all the pairs of samples in each cluster.
**Note** that the diagonal here only considers non-self
similarity, which excludes the diagonals of ``s``.
cs_max (np.ndarray):
Similarity matrix for each cluster. Each element is the maximum
similarity between all the pairs of samples in each cluster.
**Note** that the diagonal here only considers non-self
similarity, which excludes the diagonals of ``s``.
cs_min (np.ndarray):
Similarity matrix for each cluster. Each element is the minimum
similarity between all the pairs of samples in each cluster.
Will be 0 if there are any sparse elements between the two
clusters.
"""
import sparse
import scipy.sparse
l_arr = np.array(l)
ss = scipy.sparse.csr_array(s.astype(np.float32))
## assert that all labels have at least two samples
l_u ,l_c = np.unique(l_arr, return_counts=True)
# assert np.all(l_c >= 2), "All labels must have at least two samples."
## assert that s is a square matrix
assert ss.shape[0] == ss.shape[1], "Similarity matrix must be square."
## assert that s is non-negative
assert (ss < 0).sum() == 0, "Similarity matrix must be non-negative."
## assert that l is a 1-D array
assert len(l.shape) == 1, "Labels must be a 1-D array."
## assert that l is the same length as s
assert len(l) == ss.shape[0], "Labels must be the same length as the similarity matrix."
if verbose:
## Warn if s is not symmetric
if not (ss - ss.T).sum() == 0:
print("Warning: Similarity matrix is not symmetric.") if verbose else None
## Warn if s is not sparse
if not (isinstance(ss, (np.ndarray, sparse.COO)) or scipy.sparse.issparse(ss)):
print("Warning: Similarity matrix is not a recognized sparse type or np.ndarray. Will attempt to convert to sparse.COO") if verbose else None
## Warn if diagonal is not all ones. It will be converted
if not np.allclose(np.array(ss[range(ss.shape[0]), range(ss.shape[0])]), 1):
print("Warning: Similarity matrix diagonal is not all ones. Will set diagonal to all ones.") if verbose else None
## Warn if there are any values greater than 1
if (ss > 1).sum() > 0:
print("Warning: Similarity matrix has values greater than 1.") if verbose else None
## Warn if there are NaNs. Set to 0.
if (np.isnan(ss.data)).sum() > 0:
print("Warning: Similarity matrix has NaNs. Will set to 0.") if verbose else None
ss.data[np.isnan(ss.data)] = 0
## Make a boolean matrix for labels
l_bool = sparse.COO(np.stack([l_arr == u for u in l_u], axis=0))
samp_per_clust = l_bool.sum(1).todense()
n_clusters = len(samp_per_clust)
n_samples = ss.shape[0]
## Force diagonal to be 1s
ss = ss.tolil()
ss[range(n_samples), range(n_samples)] = 1
ss = sparse.COO(ss)
## Compute the similarity matrix for each pair of clusters
s_big_conj = ss[None,None,:,:] * l_bool[None,:,:,None] * l_bool[:,None,None,:] ## shape: (n_clusters, n_clusters, n_samples, n_samples)
s_big_diag = sparse.eye(n_samples) * l_bool[None,:,:,None] * l_bool[:,None,None,:]
## Compute the mean similarity matrix for each cluster
samp_per_clust_crossGrid = samp_per_clust[:,None] * samp_per_clust[None,:] ## shape: (n_clusters, n_clusters). This is the product of the number of samples in each cluster. Will be used to divide by the sum of similarities.
norm_mat = samp_per_clust_crossGrid.copy() ## above variable will be used again and this one will be mutated.
fixed_diag = samp_per_clust * (samp_per_clust - 1) ## shape: (n_clusters,). For the diagonal, we need to subtract 1 from the number of samples in each cluster because samples have only 1 similarity with themselves along the diagonal.
norm_mat[range(n_clusters), range(n_clusters)] = fixed_diag ## Correcting the diagonal
s_big_sum_raw = s_big_conj.sum(axis=(2,3)).todense()
s_big_sum_raw[range(n_clusters), range(n_clusters)] = s_big_sum_raw[range(n_clusters), range(n_clusters)] - samp_per_clust ## subtract off the number of samples in each cluster from the diagonal
cs_mean = s_big_sum_raw / norm_mat ## shape: (n_clusters, n_clusters). Compute mean by finding the sum of the similarities and dividing by the norm_mat.
## Compute the min similarity matrix for each cluster
### This is done in two steps:
#### 1. Compute the minimum similarity between each pair of clusters by inverting the similarity matrix and finding the maximum similarity between each pair of clusters.
#### 2. Since the first step doesn't invert any values that happen to be 0 (since they are sparse), we need to find out if there are any 0 values there are in each cluster pair, and if there then the minimum similarity between the two clusters is 0.
val_max = s_big_conj.max() + 1
cs_min = s_big_conj.copy()
cs_min.data = val_max - cs_min.data ## Invert the values
cs_min = cs_min.max(axis=(2,3)) ## Find the max similarity
cs_min.data = val_max - cs_min.data ## Invert the values back
cs_min.fill_value = 0.0 ## Set the fill value to 0.0 since it gets messed up by these subtraction operations
n_missing_values = (samp_per_clust_crossGrid - (s_big_conj > 0).sum(axis=(2,3)).todense()) ## shape: (n_clusters, n_clusters). Compute the number of missing values by subtracting the number of non-zero values from the number of samples in each cluster.
# n_missing_values[range(len(samp_per_clust)), range(len(samp_per_clust))] = (samp_per_clust**2 - samp_per_clust) - ((s_big_conj[range(len(samp_per_clust)), range(len(samp_per_clust))] > 0).sum(axis=(1,2))).todense() ## Correct the diagonal by subtracting the number of non-zero values from the number of samples in each cluster. This is because the diagonal is the number of samples in each cluster squared minus the number of samples in each cluster.
bool_nonMissing_values = (n_missing_values == 0) ## shape: (n_clusters, n_clusters). Make a boolean matrix for where there are no missing values.
cs_min = cs_min.todense() * bool_nonMissing_values ## Set the minimum similarity to 0 where there are missing values.
## Compute the max similarity matrix for each cluster
cs_max = (s_big_conj - s_big_diag).max(axis=(2,3))
return l_u, cs_mean, cs_max.todense(), cs_min
######################################################################################################################################
########################################################### STATS ####################################################################
######################################################################################################################################
[docs]
def zscore_to_pvalue(z, two_tailed=True):
"""
Convert a z-score to a p-value.
Args:
z (float):
The z-score.
two_tailed (bool):
If True, return a two-tailed p-value. If False, return a one-tailed
p-value.
Returns:
float:
The p-value.
"""
if two_tailed:
return 2 * scipy.stats.norm.sf(np.abs(z))
else:
return scipy.stats.norm.sf(np.abs(z))
[docs]
def pvalue_to_zscore(p, two_tailed=True):
"""
Convert a p-value to a z-score.
Args:
p (float):
The p-value.
two_tailed (bool):
If True, the p-value is two-tailed. If False, the p-value is one-tailed.
Returns:
float:
The z-score.
"""
if two_tailed:
return scipy.stats.norm.ppf(1 - p/2)
else:
return scipy.stats.norm.ppf(1 - p)
######################################################################################################################################
######################################################## SIMILARITY ##################################################################
######################################################################################################################################
[docs]
def get_path_between_nodes(
idx_start: int,
idx_end: int,
predecessors: np.ndarray,
max_length: int = 9999,
):
"""
Finds the indices corresponding to the shortest path between two nodes in a
graph. Uses a predecessor matrix from a shortest path algorithm (e.g.
scipy.sparse.csgraph.shortest_path)
RH 2024
Args:
idx_start (int):
Index of the starting node.
idx_end (int):
Index of the ending node.
predecessors (np.ndarray):
Predecessor matrix from a shortest path algorithm.
max_length (int):
Maximum length of the path. (Default is 9999)
Returns:
path (List[int]):
List of node indices corresponding to the shortest path from
idx_start to idx_end. [idx_start, ..., idx_end]
"""
## Check inputs
### Check that idx_start and idx_end are within the range of the predecessors matrix
assert idx_start < predecessors.shape[0], "idx_start is out of range"
assert idx_end < predecessors.shape[0], "idx_end is out of range"
### Check that the predecessors matrix is 2D
assert predecessors.ndim == 2, "predecessors matrix must be 2D"
### Check that the predecessors matrix is square
assert predecessors.shape[0] == predecessors.shape[1], "predecessors matrix must be square"
### Check that idx_start, idx_end, and max_length are integers
assert isinstance(idx_start, int), "idx_start must be an integer"
assert isinstance(idx_end, int), "idx_end must be an integer"
assert isinstance(max_length, int), "max_length must be an integer"
### Check that the path from idx_start to idx_end exists
assert predecessors[idx_end, idx_start] != -9999, f"Possibly no path exists. Found that path from {idx_start} to {idx_end} has value -9999 (predecessors[idx_end, idx_start] == {predecessors[idx_end, idx_start]}). This value is assumed to be a placeholder for no path."
## Initialize path
path = []
idx_current = int(idx_start)
path.append(idx_current)
## Traverse the predecessors matrix to find the shortest path
while idx_current != idx_end:
if len(path) > max_length:
raise ValueError("Path length exceeds max_length")
idx_current = int(predecessors[idx_end, idx_current])
path.append(idx_current)
return path
######################################################################################################################################
########################################################## TESTING ###################################################################
######################################################################################################################################
[docs]
class Equivalence_checker():
"""
Class for checking if all items are equivalent or allclose (almost equal) in
two complex data structures. Can check nested lists, dicts, and other data
structures. Can also optionally assert (raise errors) if all items are not
equivalent.
RH 2023
Attributes:
_kwargs_allclose (Optional[dict]):
Keyword arguments for the `numpy.allclose` function.
_assert_mode (bool):
Whether to raise an assertion error if items are not close.
Args:
kwargs_allclose (Optional[dict]):
Keyword arguments for the `numpy.allclose` function. (Default is
``{'rtol': 1e-7, 'equal_nan': True}``)
assert_mode (bool):
Whether to raise an assertion error if items are not close.
verbose (bool):
How much information to print out:
* ``False`` / ``0``: No information printed out.
* ``True`` / ``1``: Mismatched items only.
* ``2``: All items printed out.
"""
def __init__(
self,
kwargs_allclose: Optional[dict] = {'rtol': 1e-7, 'equal_nan': True},
assert_mode=False,
verbose=False,
) -> None:
"""
Initializes the Allclose_checker.
"""
self._kwargs_allclose = kwargs_allclose
self._assert_mode = assert_mode
self._verbose = verbose
def _normalize_leaf(self, value: Any) -> Any:
"""
Normalize common array-like leaf objects before comparison.
Torch tensors are converted to CPU NumPy arrays so they can be compared
with the same logic as serialized payloads loaded from RichFile.
"""
## This is intentionally a tiny normalization layer. The integration
## tests now do higher-level tree normalization before they call into
## Equivalence_checker, so this helper only needs to smooth over
## end-value array/tensor representation differences.
if torch.is_tensor(value):
return value.detach().cpu().numpy()
return value
def _compare_sparse(
self,
test: Any,
true: Any,
) -> Tuple[bool, str]:
"""
Compare sparse arrays without densifying and while correctly handling
nonzero ``atol`` for implicit zeros.
"""
## Canonicalize first so explicit zeros and storage format differences
## do not create false mismatches before we even inspect values.
test_csr = scipy.sparse.csr_array(test, copy=True)
true_csr = scipy.sparse.csr_array(true, copy=True)
test_csr.eliminate_zeros()
true_csr.eliminate_zeros()
if true_csr.shape != test_csr.shape:
return (False, f'shape mismatch: test={test_csr.shape}, true={true_csr.shape}')
if true_csr.dtype != test_csr.dtype:
return (False, f'dtype mismatch: test={test_csr.dtype}, true={true_csr.dtype}')
## Build the union of nonzero coordinates from both arrays. This is the
## key sparse-specific step: values that are not stored explicitly are
## still treated as zeros, and those zeros can still matter for atol.
## So we compare every position where either array stores a value.
test_mask = test_csr.copy()
true_mask = true_csr.copy()
test_mask.data = np.ones_like(test_mask.data, dtype=np.int8)
true_mask.data = np.ones_like(true_mask.data, dtype=np.int8)
union_mask = (test_mask + true_mask).astype(bool).tocoo()
if union_mask.nnz == 0:
return (True, 'equivalence')
rows, cols = union_mask.row, union_mask.col
## Pull only the positions we actually need instead of converting the
## whole sparse matrix into a dense one.
test_vals = np.asarray(test_csr[rows, cols]).reshape(-1)
true_vals = np.asarray(true_csr[rows, cols]).reshape(-1)
kwargs = {
'rtol': self._kwargs_allclose.get('rtol', 1e-7),
'atol': self._kwargs_allclose.get('atol', 0),
'equal_nan': self._kwargs_allclose.get('equal_nan', False),
}
is_close = np.isclose(test_vals, true_vals, **kwargs)
if np.all(is_close):
return (True, 'equivalence')
mismatched = ~is_close
## Report the amount by which the worst offending entry exceeded the
## allowed tolerance, not just the raw absolute difference.
abs_diff = np.abs(test_vals[mismatched] - true_vals[mismatched])
tol = kwargs['atol'] + kwargs['rtol'] * np.abs(true_vals[mismatched])
max_violation = float(np.max(abs_diff - tol))
n_mismatches = int(np.sum(mismatched))
n_elements = true_csr.shape[0] * true_csr.shape[1]
return (
False,
f"sparse allclose failed: n_mismatches={n_mismatches}, "
f"max_violation={max_violation:.3e}, n_elements={n_elements}",
)
def _checker(
self,
test: Any,
true: Any,
path: Optional[List[str]] = None,
) -> bool:
"""
Compares the test and true values using numpy's allclose function.
Args:
test (Union[dict, list, tuple, set, np.ndarray, int, float, complex,
str, bool, None]):
Test value to compare.
true (Union[dict, list, tuple, set, np.ndarray, int, float, complex,
str, bool, None]):
True value to compare.
path (Optional[List[str]]):
The path of the data structure that is currently being compared.
(Default is ``None``)
Returns:
(bool):
result (bool):
Returns True if all elements in test and true are close.
Otherwise, returns False.
"""
test = self._normalize_leaf(test)
true = self._normalize_leaf(true)
try:
## If the dtype is a kind of string (or byte string) or object, then allclose will raise an error. In this case, just check if the values are equal.
if np.issubdtype(test.dtype, np.str_) or np.issubdtype(test.dtype, np.bytes_) or test.dtype == np.object_:
out = bool(np.all(test == true))
else:
out = np.allclose(test, true, **self._kwargs_allclose)
except Exception as e:
out = None ## This is not False because sometimes allclose will raise an error if the arrays have a weird dtype among other reasons.
warnings.warn(f"WARNING. Equivalence check failed. Path: {path}. Error: {e}") if self._verbose else None
if out == False:
if self._assert_mode:
raise AssertionError(f"Equivalence check failed. Path: {path}.")
if self._verbose:
## Come up with a way to describe the difference between the two values. Something like the following:
### IF the arrays are numeric, then calculate the relative difference
dtypes_numeric = (np.number, bool, np.integer, np.floating, np.complexfloating)
if any([np.issubdtype(test.dtype, dtype) and np.issubdtype(true.dtype, dtype) for dtype in dtypes_numeric]):
## Cast booleans to int to avoid TypeError on subtraction in NumPy 2.0
test_num = test.astype(int) if np.issubdtype(test.dtype, bool) else test
true_num = true.astype(int) if np.issubdtype(true.dtype, bool) else true
diff = np.abs(test_num - true_num)
at = np.abs(true)
r_diff = diff / at if np.all(at != 0) else np.inf
r_diff_mean, r_diff_max, any_nan = np.nanmean(r_diff), np.nanmax(r_diff), np.any(np.isnan(r_diff))
## fraction of mismatches
n_elements = np.prod(test.shape)
n_mismatches = np.sum(diff > 0)
frac_mismatches = n_mismatches / n_elements
## Use scientific notation and round to 3 decimal places
reason = f"Equivalence: Relative difference: mean={r_diff_mean:.3e}, max={r_diff_max:.3e}, any_nan={any_nan}, n_elements={n_elements}, n_mismatches={n_mismatches}, frac_mismatches={frac_mismatches:.3e}"
else:
reason = f"Values are not numpy numeric types. types: {test.dtype}, {true.dtype}"
else:
reason = "allclose failed"
elif out is None:
reason = "check skipped (exception)"
else:
reason = "equivalence"
return out, reason
def __call__(
self,
test: Union[dict, list, tuple, set, np.ndarray, int, float, complex, str, bool, None],
true: Union[dict, list, tuple, set, np.ndarray, int, float, complex, str, bool, None],
path: Optional[List[str]] = None,
) -> Dict[str, Tuple[bool, str]]:
"""
Compares the test and true values and returns the comparison result.
Handles various data types including dictionaries, iterables,
np.ndarray, scalars, strings, numbers, bool, and None.
Args:
test (Union[dict, list, tuple, set, np.ndarray, int, float, complex,
str, bool, None]):
Test value to compare.
true (Union[dict, list, tuple, set, np.ndarray, int, float, complex,
str, bool, None]):
True value to compare.
path (Optional[List[str]]):
The path of the data structure that is currently being compared.
(Default is ``None``)
Returns:
Dict[Tuple[bool, str]]:
result Dict[Tuple[bool, str]]:
The comparison result as a dictionary or a tuple depending
on the data types of test and true.
"""
if path is None:
path = ['']
test = self._normalize_leaf(test)
true = self._normalize_leaf(true)
if len(path) > 0:
if path[-1].startswith('_'):
return (None, 'excluded from testing')
## TYPE MISMATCH: one side is None and the other is not
if (test is None) != (true is None):
result = (False, f'type mismatch: test is {type(test).__name__}, true is {type(true).__name__}')
## SCIPY SPARSE MATRIX
elif scipy.sparse.issparse(true):
if not scipy.sparse.issparse(test):
result = (False, f'type mismatch: test is {type(test).__name__}, true is {type(true).__name__}')
else:
## Keep sparse comparison on the dedicated path so we never
## accidentally densify large matrices through np.allclose.
result = self._compare_sparse(test=test, true=true)
## NP.NDARRAY
elif isinstance(true, np.ndarray):
result = self._checker(test, true, path)
## NP.SCALAR
elif np.isscalar(true):
if isinstance(test, (int, float, complex, np.number)):
result = self._checker(np.array(test), np.array(true), path)
else:
result = (test == true, 'equivalence')
## NUMBER
elif isinstance(true, (int, float, complex)):
result = self._checker(test, true, path)
## DICT
elif isinstance(true, dict):
if not isinstance(test, dict):
## test is not a dict — try to compare via __dict__ if available,
## otherwise report type mismatch. This is primarily to support
## comparing live helper objects against their serialized dict
## form when the meaningful saved data is still the same.
if hasattr(test, '__dict__'):
test = {k: v for k, v in test.__dict__.items() if not k.startswith('_')}
else:
return (False, f'type mismatch: test is {type(test).__name__}, true is dict')
result = {}
for key in true:
if key not in test:
result[str(key)] = (False, 'key not found')
else:
result[str(key)] = self.__call__(test[key], true[key], path=path + [str(key)])
## ITERATABLE
elif isinstance(true, (list, tuple, set)):
if len(true) != len(test):
result = (False, 'length_mismatch')
else:
if all([isinstance(i, (int, float, complex, np.number)) for i in true]):
result = self._checker(np.array(test), np.array(true), path)
else:
result = {}
for idx, (i, j) in enumerate(zip(test, true)):
result[str(idx)] = self.__call__(i, j, path=path + [str(idx)])
## STRING
elif isinstance(true, str):
result = (test == true, 'equivalence')
## BOOL
elif isinstance(true, bool):
result = (test == true, 'equivalence')
## NONE
elif true is None:
result = (test is None, 'equivalence')
## OBJECT with __dict__
elif hasattr(true, '__dict__'):
true_dict = {k: v for k, v in true.__dict__.items() if not k.startswith('_')}
if isinstance(test, dict):
## When the serialized side is a dict and the live side is an
## object, recurse back through the dict branch so both paths
## follow the same key-by-key comparison logic.
result = self.__call__(test, true_dict, path=path)
else:
result = {}
for key in true_dict:
if not hasattr(test, key):
result[str(key)] = (False, 'key not found')
else:
result[str(key)] = self.__call__(getattr(test, key), true_dict[key], path=path + [str(key)])
## N/A
else:
result = (None, 'not tested')
if isinstance(result, tuple):
if self._assert_mode:
assert (result[0] != False), f"Equivalence check failed. Path: {path}. Reason: {result[1]}"
if self._verbose > 0:
## Print False results
if result[0] == False:
print(f"Equivalence check failed. Path: {path}. Reason: {result[1]}")
if self._verbose > 1:
## Print True results
if result[0] == True:
print(f"Equivalence check passed. Path: {path}. Reason: {result[1]}")
elif result[0] is None:
print(f"Equivalence check not tested. Path: {path}. Reason: {result[1]}")
return result
######################################################################################################################################
###################################################### OTHER FUNCTIONS ###############################################################
######################################################################################################################################
[docs]
def get_balanced_class_weights(
labels: np.ndarray
) -> np.ndarray:
"""
Balances the weights for classifier training.
RH, JZ 2022
Args:
labels (np.ndarray):
Array that includes a list of labels to balance the weights for
classifier training. *shape: (n,)*
Returns:
(np.ndarray):
weights (np.ndarray):
Weights by samples. *shape: (n,)*
"""
labels = labels.astype(np.int64)
vals, counts = np.unique(labels, return_counts=True)
weights = len(labels) / counts
return weights
[docs]
def get_balanced_sample_weights(
labels: np.ndarray,
class_weights: Optional[np.ndarray] = None,
) -> np.ndarray:
"""
Balances the weights for classifier training.
RH/JZ 2022
Args:
labels (np.ndarray):
Array that includes a list of labels to balance the weights for
classifier training. *shape: (n,)*
class_weights (np.ndarray, Optional):
Optional parameter which includes an array of pre-fit class weights.
If ``None``, weights will be calculated using the provided sample
labels. (Default is ``None``)
Returns:
(np.ndarray):
sample_weights (np.ndarray):
Sample weights by labels. *shape: (n,)*
"""
if type(class_weights) is not np.ndarray and type(class_weights) is not np.array:
print('Warning: Class weights not pre-fit. Using provided sample labels.')
weights = get_balanced_class_weights(labels)
else:
weights = class_weights
sample_weights = weights[labels]
return sample_weights
[docs]
def safe_set_attr(
obj: Any,
attr: str,
value: Any,
overwrite: bool = False,
) -> None:
"""
Safely sets an attribute on an object. If the attribute is not present, it
will be created. If the attribute is present, it will only be overwritten if
``overwrite`` is set to ``True``.
RH 2024
Args:
obj (Any):
Object to set the attribute on.
attr (str):
Attribute name.
value (Any):
Value to set the attribute to.
overwrite (bool):
Whether to overwrite the attribute if it already exists.
(Default is ``False``)
"""
if not hasattr(obj, attr):
setattr(obj, attr, value)
elif overwrite:
setattr(obj, attr, value)
[docs]
def reshape_coo_manual(coo, new_shape):
"""
Manually reshape a COO matrix using 64-bit arithmetic.
This function is only needed because windows does a bad
job of gracefully switching from int32 to int64 when the
values of idx need to be greater than the int32 max value
(2147483648).
Andrew helped figure this one out. It was an issue for
when there are > around 30k ROIs.
RH 2025
Parameters:
coo : scipy.sparse.coo_array
The input sparse matrix in COO format.
new_shape : tuple of ints
The desired shape, e.g. (1, -1) expanded to a complete tuple.
Returns:
new_coo : scipy.sparse.coo_array
The reshaped COO matrix.
"""
# Ensure new_shape is fully specified.
# If one of the dimensions is -1, compute it from the other.
if -1 in new_shape:
known_dim = [d for d in new_shape if d != -1][0]
total_elements = np.int64(coo.shape[0]) * np.int64(coo.shape[1])
new_dim = total_elements // np.int64(known_dim)
new_shape = tuple(new_dim if d == -1 else d for d in new_shape)
# Check that the total number of elements is the same.
old_size = np.int64(coo.shape[0]) * np.int64(coo.shape[1])
new_size = np.int64(new_shape[0]) * np.int64(new_shape[1])
if old_size != new_size:
raise ValueError("Cannot reshape matrix of total size {} into shape {}".format(old_size, new_shape))
# Convert row and col to int64
old_row = coo.row.astype(np.int64, copy=False)
old_col = coo.col.astype(np.int64, copy=False)
# Calculate flattened indices using C-order: flat_index = row * n_cols + col
n_cols_old = np.int64(coo.shape[1])
flat_idx = old_row * n_cols_old + old_col
# Now compute new row and col from the flat index:
n_cols_new = np.int64(new_shape[1])
new_row = flat_idx // n_cols_new
new_col = flat_idx % n_cols_new
# Create a new COO matrix with the new indices.
return scipy.sparse.coo_array((coo.data, (new_row, new_col)), shape=new_shape)