from typing import List, Tuple, Union, Optional, Dict, Any, Callable, Iterable
import copy
import os
import copy
import os
import warnings
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import scipy.sparse
import torch
import pandas as pd
from . import util, helpers
[docs]
def display_toggle_image_stack(
    images: Union[List[np.ndarray], List[torch.Tensor]],
    image_size: Optional[Union[Tuple[int, int], int, float]] = None,
    clim: Optional[Tuple[float, float]] = None,
    interpolation: str = 'nearest',
) -> None:
    """
    Displays images in a slider using Jupyter Notebook.
    RH 2023

    Args:
        images (Union[List[np.ndarray], List[torch.Tensor]]):
            Non-empty list of images as numpy arrays or PyTorch tensors. The
            first two dimensions of the first image define the native size.
        image_size (Optional[Union[Tuple[int, int], int, float]]):
            Display size in pixels as *(height, width)*, i.e. the same order
            as ``image.shape[:2]``.\n
            If ``None``, images are shown at the size of the first image.\n
            If a single integer or float is provided, the size of the first
            image is multiplied by that factor.\n
            (Default is ``None``)
        clim (Optional[Tuple[float, float]]):
            Tuple of *(min, max)* values for scaling pixel intensities. If
            ``None``, min and max are computed separately for each image.
            (Default is ``None``)
        interpolation (str):
            String specifying the interpolation method for resizing. Options are
            'nearest', 'box', 'bilinear', 'hamming', 'bicubic', 'lanczos'. Uses
            the Image.Resampling.* methods from PIL. (Default is 'nearest')

    Raises:
        ValueError:
            If ``images`` is empty, or ``image_size`` / ``interpolation`` is
            not one of the accepted forms.
        RuntimeError:
            If 'ipykernel' is not loaded (not running inside Jupyter).
    """
    import base64
    import json
    import sys
    import uuid
    from io import BytesIO

    ## Validate the arguments first so that bad input fails fast, before the
    ## notebook check and the third-party imports.
    valid_interpolations = ('nearest', 'box', 'bilinear', 'hamming', 'bicubic', 'lanczos')
    if interpolation not in valid_interpolations:
        raise ValueError("Invalid interpolation method. Choose from 'nearest', 'box', 'bilinear', 'hamming', 'bicubic', or 'lanczos'.")
    if len(images) == 0:
        raise ValueError("images must contain at least one image.")

    ## Resolve the display size as (height, width), numpy's axis order
    native_hw = tuple(int(s) for s in images[0].shape[:2])
    if image_size is None:
        size_hw = native_hw
    elif isinstance(image_size, (int, float, np.integer, np.floating)):
        size_hw = tuple(int(s * image_size) for s in native_hw)
    elif isinstance(image_size, (tuple, list)):
        size_hw = (int(image_size[0]), int(image_size[1]))
    else:
        raise ValueError("Invalid image size. Must be a tuple of (height, width) or a single integer or float.")

    # Check if being called from a Jupyter notebook
    if 'ipykernel' not in sys.modules:
        raise RuntimeError("This function must be called from a Jupyter notebook.")

    from IPython.display import display, HTML
    from PIL import Image

    ## The validated name maps directly onto the PIL resampling constant
    resample = getattr(Image.Resampling, interpolation.upper())

    def to_uint8(image):
        """Min-max scale an image into uint8, using clim when it is given."""
        if isinstance(image, torch.Tensor):
            image = image.detach().cpu().numpy()
        image = np.asarray(image, dtype=np.float64)
        lo, hi = (np.min(image), np.max(image)) if clim is None else clim
        span = hi - lo
        if span == 0:
            ## Constant image (or degenerate clim): avoid dividing by zero
            return np.zeros(image.shape, dtype=np.uint8)
        scaled = np.clip((image - lo) / span, 0, 1)
        return (scaled * 255).astype(np.uint8)

    def encode_png(image):
        """Normalize, resize, and convert an image to a base64 PNG string."""
        pil_image = Image.fromarray(to_uint8(image))
        ## PIL sizes are (width, height): the reverse of size_hw
        pil_image = pil_image.resize((size_hw[1], size_hw[0]), resample=resample)
        buffer = BytesIO()
        pil_image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("ascii")

    # Unique suffix so several sliders can coexist in one notebook page
    slider_id = uuid.uuid4().hex

    # Process all images in the input list
    base64_images = [encode_png(img) for img in images]
    idx_last = len(base64_images) - 1

    # Generate the HTML code for the slider
    html_code = f"""
    <div>
        <input type="range" id="imageSlider_{slider_id}" min="0" max="{idx_last}" value="0">
        <img id="displayedImage_{slider_id}" src="data:image/png;base64,{base64_images[0]}" style="width: {size_hw[1]}px; height: {size_hw[0]}px;">
        <span id="imageNumber_{slider_id}">Image 0/{idx_last}</span>
    </div>
    <script>
        (function() {{
            let base64_images = {json.dumps(base64_images)};
            let current_image = 0;
            function updateImage() {{
                let slider = document.getElementById("imageSlider_{slider_id}");
                current_image = parseInt(slider.value);
                let displayedImage = document.getElementById("displayedImage_{slider_id}");
                displayedImage.src = "data:image/png;base64," + base64_images[current_image];
                let imageNumber = document.getElementById("imageNumber_{slider_id}");
                imageNumber.innerHTML = "Image " + current_image + "/{idx_last}";
            }}
            document.getElementById("imageSlider_{slider_id}").addEventListener("input", updateImage);
        }})();
    </script>
    """
    display(HTML(html_code))
[docs]
def compute_colored_FOV(
    spatialFootprints: List[scipy.sparse.csr_array],
    FOV_height: int,
    FOV_width: int,
    labels: Optional[Union[List[np.ndarray], np.ndarray]] = None,
    cmap: Union[str, object] = 'random',
    alphas_labels: Optional[np.ndarray] = None,
    alphas_sf: Optional[Union[List[np.ndarray], np.ndarray]] = None,
    color_unlabeled: Optional[List[float]] = None,
) -> List[np.ndarray]:
    """
    Computes a set of images of fields of view (FOV) of spatial footprints,
    colored by the predicted class.
    RH 2023

    Args:
        spatialFootprints (List[scipy.sparse.csr_array]):
            Each element is all the spatial footprints for a given session
            (one row per footprint, ``FOV_height * FOV_width`` columns). A
            single sparse matrix or 2D numpy array is treated as one session.
        FOV_height (int):
            Height of the field of view.
        FOV_width (int):
            Width of the field of view.
        labels (Optional[Union[List[np.ndarray], np.ndarray]]):
            Label (will be a unique color) for each spatial footprint. Each
            element is all the labels for a given session. If -1, then the
            spatial footprint will be black / transparent. Can either be a list
            of integer labels for each session, or a single array with all the
            labels concatenated. Optional, if None, then all labels are set to
            random integers.
        cmap (Union[str, object]):
            Colormap to use for the labels. If 'random', then a random colormap
            is generated. Any other string is looked up with
            ``matplotlib.pyplot.get_cmap``. Otherwise it must be callable on an
            array of values in [0, 1] and return RGBA rows. (Default is
            'random')
        alphas_labels (Optional[np.ndarray]):
            Alpha value for each label. shape: *(n_labels,)* which is the same
            as the number of unique labels len(np.unique(labels)). (Default is
            ``None``)
        alphas_sf (Optional[Union[List[np.ndarray], np.ndarray]]):
            Alpha value for each spatial footprint. Can either be a list of
            alphas for each session, or a single array with all the alphas
            concatenated. (Default is ``None``)
        color_unlabeled (Optional[List[float]]):
            RGBA color (4 values) written into the first color slot when -1 is
            present in the labels. If ``None``, ``[0, 0, 0, 0]`` is used.
            (Default is ``None``)

    Returns:
        (List[np.ndarray]):
            rois_c_bySession_FOV (List[np.ndarray]):
                List of images of fields of view (FOV) of spatial footprints,
                colored by the predicted class. One array of shape
                *(FOV_height, FOV_width, 3)* per session.
    """
    ## Accept a single session passed without the enclosing list
    if scipy.sparse.issparse(spatialFootprints):
        spatialFootprints = [spatialFootprints]
    elif isinstance(spatialFootprints, np.ndarray):
        spatialFootprints = [scipy.sparse.csr_array(spatialFootprints)]

    ## Check inputs
    assert all([scipy.sparse.issparse(sf) for sf in spatialFootprints]), "spatialFootprints must be a list of scipy.sparse.csr_array"

    n_roi = np.array([sf.shape[0] for sf in spatialFootprints], dtype=np.int64)
    n_roi_cumsum = np.concatenate([[0], np.cumsum(n_roi)]).astype(np.int64)

    def _fix_list_of_arrays(v):
        ## A flat array / flat list is split into one chunk per session
        is_flat = isinstance(v, np.ndarray) or (isinstance(v, list) and not isinstance(v[0], (np.ndarray, list)))
        if is_flat:
            v = [v[b_l: b_u] for b_l, b_u in zip(n_roi_cumsum[:-1], n_roi_cumsum[1:])]
        assert (isinstance(v, (list, util.JSON_List)) and isinstance(v[0], (np.ndarray, list, util.JSON_List))), "input must be a list of arrays or a single array of integers"
        return v

    if labels is None:
        labels = [np.random.randint(0, 255, size=n) for n in n_roi]
    labels = _fix_list_of_arrays(labels)
    alphas_sf = _fix_list_of_arrays(alphas_sf) if alphas_sf is not None else None

    labels_cat = np.concatenate(labels)
    n_c = len(np.unique(labels_cat))

    if alphas_labels is None:
        alphas_labels = np.ones(n_c)
    alphas_labels = np.clip(alphas_labels, a_min=0, a_max=1)
    assert len(alphas_labels) == n_c, f"len(alphas_labels)={len(alphas_labels)} != n_c={n_c}"

    if alphas_sf is None:
        alphas_sf = np.ones(len(labels_cat))
    if isinstance(alphas_sf, list):
        alphas_sf = np.concatenate(alphas_sf)
    alphas_sf = np.clip(alphas_sf, a_min=0, a_max=1)
    assert len(alphas_sf) == len(labels_cat), f"len(alphas_sf)={len(alphas_sf)} != len(labels_cat)={len(labels_cat)}"

    h, w = FOV_height, FOV_width

    ## Scale every footprint so that its maximum is 1. Rows whose max is 0 are
    ## divided by NaN and then zeroed. Cast to float so integer footprints can
    ## hold the NaN sentinel.
    rois = scipy.sparse.vstack(spatialFootprints)
    rois_max = rois.max(axis=1).toarray().reshape(-1, 1).astype(np.float64)
    rois_max[rois_max == 0] = np.nan
    rois = rois.multiply(1.0/rois_max).power(1)  ## exponent of 1 kept from the original; it leaves the values unchanged
    rois.data[np.isnan(rois.data)] = 0

    ## One RGBA row per unique label
    if n_c > 1:
        positions = np.linspace(0., 1., n_c, endpoint=True)
        if isinstance(cmap, str) and cmap == 'random':
            colormap = helpers.rand_cmap(nlabels=n_c, verbose=False)
        elif isinstance(cmap, str):
            colormap = plt.get_cmap(cmap)
        else:
            colormap = cmap
        colors = np.asarray(colormap(positions), dtype=np.float64)
        colors = colors / colors.max(1, keepdims=True)
    else:
        ## NOTE(review): with a single unique label every footprint gets
        ## [0, 0, 0, 0] (black) unless it is -1 and color_unlabeled is set;
        ## confirm this is intended. Float dtype so color_unlabeled is not
        ## truncated to integers.
        colors = np.zeros((1, 4), dtype=np.float64)

    if np.isin(-1, labels_cat):
        ## presumably -1 is the smallest label, so it occupies slot 0 after squeezing - verify against callers
        colors[0] = [0, 0, 0, 0] if color_unlabeled is None else color_unlabeled

    labels_squeezed = helpers.squeeze_integers(labels_cat)
    labels_squeezed -= labels_squeezed.min()

    ## Four channel blocks side by side: [R | G | B | A], each h*w columns wide
    rois_c = scipy.sparse.hstack([rois.multiply(colors[labels_squeezed, ii][:,None]) for ii in range(4)]).tocsr()
    rois_c.data = np.minimum(rois_c.data, 1)

    ## apply alpha
    rois_c = rois_c.multiply(alphas_labels[labels_squeezed][:,None] * alphas_sf[:,None]).tocsr()

    ## make session_bool
    session_bool = util.make_session_bool(n_roi)

    ## Max-project the footprints of each session, unpack the channel blocks
    ## into (h, w, 4) and keep RGB only
    rois_c_bySessions = [rois_c[idx] for idx in session_bool.T]
    rois_c_bySessions_FOV = [r.max(0).toarray().reshape(4, h, w).transpose(1,2,0)[:,:,:3] for r in rois_c_bySessions]

    return rois_c_bySessions_FOV
[docs]
def crop_cluster_ims(ims: np.ndarray) -> np.ndarray:
    """
    Crops the images to the bounding box of all non-zero pixels, padded by one
    pixel on every side (where the image allows), and draws a frame of ones on
    the outermost rows and columns of the crop.
    RH 2022

    Args:
        ims (np.ndarray):
            Images to crop. (shape: *(n, H, W)*) The input is not modified.

    Returns:
        (np.ndarray):
            cropped_ims (np.ndarray):
                Cropped copy of the images. (shape: *(n, H', W')*) If no pixel
                is greater than zero, the full frame is returned (with the
                frame of ones).
    """
    ims = np.asarray(ims)
    n_rows, n_cols = ims.shape[1], ims.shape[2]

    ## Footprint of all images combined
    footprint = np.max(ims, axis=0) > 0
    rows = np.flatnonzero(footprint.any(axis=1))
    cols = np.flatnonzero(footprint.any(axis=0))

    if rows.size == 0:
        ## Nothing to crop around: keep the whole frame
        r0, r1, c0, c1 = 0, n_rows, 0, n_cols
    else:
        ## Slice stops are exclusive, so '+2' gives the same one-pixel margin
        ## at the bottom/right as '-1' gives at the top/left. Without it the
        ## frame below would overwrite the last non-zero row and column.
        r0, r1 = max(rows[0] - 1, 0), min(rows[-1] + 2, n_rows)
        c0, c1 = max(cols[0] - 1, 0), min(cols[-1] + 2, n_cols)

    ## Copy only the cropped region so the caller's array is untouched
    im_out = ims[:, r0:r1, c0:c1].copy()

    ## Frame of ones around each crop
    im_out[:, (0, -1), :] = 1
    im_out[:, :, (0, -1)] = 1
    return im_out
[docs]
def display_cropped_cluster_ims(
    spatialFootprints: List[np.ndarray],
    labels: np.ndarray,
    FOV_height: int = 512,
    FOV_width: int = 1024,
    n_labels_to_display: int = 100,
) -> None:
    """
    Displays the cropped cluster images: one figure per label, showing all the
    footprints with that label side by side.
    RH 2023

    Args:
        spatialFootprints (List[np.ndarray]):
            List of spatial footprint matrices that are stacked vertically
            with ``scipy.sparse.vstack``. Each row is one flattened footprint
            of length ``FOV_height * FOV_width``.
        labels (np.ndarray):
            Labels for each region of interest (ROI). (shape: *(n,)*) Labels
            of -1 or below are not displayed.
        FOV_height (int):
            Height of the field of view. (Default is *512*)
        FOV_width (int):
            Width of the field of view. (Default is *1024*)
        n_labels_to_display (int):
            Number of labels to display; the smallest label values are shown
            first. (Default is *100*)
    """
    labels = np.asarray(labels)

    ## np.unique returns sorted values, so slicing keeps the smallest labels.
    ## Selecting them up front avoids building masks for labels never shown.
    labels_shown = np.unique(labels[labels > -1])[:n_labels_to_display]
    if labels_shown.size == 0:
        return

    ## Scale every footprint by its own maximum, floored to avoid dividing by zero
    ROI_ims_sparse = scipy.sparse.vstack(spatialFootprints)
    row_max = ROI_ims_sparse.max(axis=1).toarray().reshape(-1, 1)
    ROI_ims_sparse = ROI_ims_sparse.multiply(1.0 / np.maximum(row_max, util.SPARSE_NORMALIZATION_FLOOR)).tocsr()

    for label in labels_shown:
        idx = np.flatnonzero(labels == label)
        cluster_ims = ROI_ims_sparse[idx].toarray().reshape(len(idx), FOV_height, FOV_width)
        ## Crop all footprints of the cluster to their common bounding box and
        ## lay them out left to right
        strip = np.concatenate(list(crop_cluster_ims(cluster_ims)), axis=1)

        plt.figure(figsize=(40,1))
        plt.imshow(strip, cmap='gray')
        plt.axis('off')
[docs]
def select_region_scatterPlot(
    data: np.ndarray,
    images_overlay: Optional[np.ndarray] = None,
    idx_images_overlay: Optional[np.ndarray] = None,
    size_images_overlay: Optional[float] = None,
    frac_overlap_allowed: float = 0.5,
    image_overlay_raster_size: Optional[Tuple[int, int]] = None,
    path: Optional[str] = None,
    figsize: Tuple[int, int] = (300, 300),
    alpha_points: float = 0.5,
    size_points: float = 1,
    color_points: Union[str, List[str]] = 'k',
) -> Tuple[Callable, object, str]:
    """
    Selects a region of a scatter plot and returns the indices of the points in
    that region.

    Args:
        data (np.ndarray):
            Input data to create a scatterplot. The shape must be *(n_samples,
            2)*.
        images_overlay (np.ndarray, optional):
            A 3D array of grayscale images or a 4D array of RGB images, where
            the first dimension is the number of images. (Default is ``None``)
        idx_images_overlay (np.ndarray, optional):
            A vector of data indices corresponding to each image in
            images_overlay. The shape must be *(n_images,)*. Required when
            images_overlay is given. (Default is ``None``)
        size_images_overlay (float, optional):
            Size of each overlay image. The unit is relative to each axis. This
            value scales the resolution of the overlay raster. If ``None``, it
            is derived from the smallest distance between overlay points.
            (Default is ``None``)
        frac_overlap_allowed (float, optional):
            Fraction of overlap allowed between neighboring overlay images.
            This is only used when size_images_overlay is ``None``. (Default
            is 0.5)
        image_overlay_raster_size (Tuple[int, int], optional):
            Size of the rasterized image overlay in pixels. If ``None``, the
            size will be set to figsize. (Default is ``None``)
        path (str, optional):
            File path to save the selected indices. If ``None``, a file named
            'indices.csv' in the system temporary directory is used. (Default
            is ``None``)
        figsize (Tuple[int, int], optional):
            Size of the figure in pixels. (Default is (300, 300))
        alpha_points (float, optional):
            Alpha value of the scatter plot points. (Default is 0.5)
        size_points (float, optional):
            Size of the scatter plot points. (Default is 1)
        color_points (Union[str, List[str]], optional):
            Color of the scatter plot points. Single color only.

    Returns:
        (Tuple[Callable, object, str]): tuple containing:
            fn_get_indices (Callable):
                Function that returns the indices of the selected points, read
                from the file at path_tempfile (``None`` if that file does
                not exist).
            layout (object):
                Holoviews layout object.
            path_tempfile (str):
                Path to the temporary file that saves the selected indices.

        If IPython cannot be imported, a warning is printed and
        ``(None, None, None)`` is returned.

    Example:
        .. highlight:: python
        .. code-block:: python

            fn_get_indices, layout, path_tempfile = select_region_scatterPlot(data)
    """
    import tempfile

    ## Validate inputs first so that bad arguments fail before the plotting
    ## backend is loaded
    assert isinstance(data, np.ndarray), 'data must be a numpy array'
    assert data.ndim == 2, 'data must have 2 dimensions'
    assert data.shape[1] == 2, 'data must have 2 columns'
    if images_overlay is not None:
        assert isinstance(images_overlay, np.ndarray), 'images_overlay must be a numpy array'
        assert (images_overlay.ndim == 3) or (images_overlay.ndim == 4), 'images_overlay must have 3 or 4 dimensions'
        assert idx_images_overlay is not None, 'idx_images_overlay must be provided when images_overlay is provided'
        assert images_overlay.shape[0] == idx_images_overlay.shape[0], 'images_overlay must have the same number of images as idx_images_overlay'

    import holoviews as hv
    try:
        from IPython.display import display
    except ImportError:
        print('Warning: Could not import IPython.display. Cannot display plot.')
        return None, None, None

    hv.extension('bokeh')

    if image_overlay_raster_size is None:
        image_overlay_raster_size = figsize

    # Declare some points, set alpha, size, color
    points = hv.Points(data)
    points.opts(
        alpha=alpha_points,
        size=size_points,
        color=color_points,
    )

    # Declare points as source of selection stream
    selection = hv.streams.Selection1D(source=points)

    path_tempFile = os.path.join(tempfile.gettempdir(), 'indices.csv') if path is None else path

    def callback(index):
        ## Persist the current selection as one comma-separated line.
        ## Mode 'w' truncates any previous file.
        with open(path_tempFile, 'w') as f:
            f.write(','.join(str(i) for i in index))
        return points

    selection.param.watch_values(callback, 'index')

    layout = points.opts(
        tools=['lasso_select', 'box_select'],
        width=figsize[0],
        height=figsize[1],
    )

    def norm_img(image):
        """
        Min-max normalize a 2D grayscale image into [0, 1]. A constant image
        maps to all zeros.
        """
        lo = np.min(image)
        span = np.max(image) - lo
        if span == 0:
            return np.zeros_like(image, dtype=np.float64)
        return (image - lo) / span

    # If images are provided, overlay them on the points
    if images_overlay is not None and idx_images_overlay is not None:
        ## 'import scipy.sparse' at file level does not guarantee that
        ## scipy.interpolate is loaded
        import scipy.interpolate

        min_emb = np.nanmin(data, axis=0)  ## shape (2,)
        max_emb = np.nanmax(data, axis=0)  ## shape (2,)
        range_emb = max_emb - min_emb  ## shape (2,)
        aspect_ratio_ims = (range_emb[1] / range_emb[0])
        ## Canvas extends 5% of the data range beyond the data on every side
        lims_canvas = ((min_emb - range_emb*0.05), (max_emb + range_emb*0.05))  ## ( shape (2,)(mins), shape (2,)(maxs) )
        range_canvas = lims_canvas[1] - lims_canvas[0]  ## shape (2,)

        if size_images_overlay is None:
            ## 'import sklearn' alone does not load the neighbors submodule
            from sklearn.neighbors import NearestNeighbors
            nn_graph = NearestNeighbors(
                n_neighbors=2,
                algorithm='auto',
                metric='euclidean'
            ).fit(
                data[idx_images_overlay]
            ).kneighbors_graph(
                data[idx_images_overlay],
                n_neighbors=2,
                mode='distance'
            )
            ## Drop the zero self-distances; the smallest remaining entry is
            ## the closest pair of overlay points
            nn_graph.eliminate_zeros()
            min_image_distance = np.nanmin(nn_graph.data)
            size_images_overlay = float(min_image_distance) * (1 + frac_overlap_allowed)
            print(f'Using size_images_overlay = {size_images_overlay}')

        assert isinstance(size_images_overlay, (int, float, np.integer, np.floating, np.ndarray)), 'size_images_overlay must be an int, float, or shape (2,) numpy array'
        if isinstance(size_images_overlay, (int, float, np.integer, np.floating)):
            size_images_overlay = np.array([size_images_overlay / aspect_ratio_ims, size_images_overlay])
        assert size_images_overlay.shape == (2,), 'size_images_overlay must be an int, float, or shape (2,) numpy array'

        # Create a large canvas to hold all the images
        ## NOTE(review): x offsets are mapped into [0, iors[0]) but index the
        ## second canvas axis (iors[1] wide), and vice versa for y. This is
        ## only consistent for square rasters - confirm before using a
        ## non-square image_overlay_raster_size.
        iors = image_overlay_raster_size
        canvas = np.zeros((iors[0], iors[1],4))

        ## Map data coordinates to raster pixel coordinates
        interp_0 = scipy.interpolate.interp1d(
            x=np.linspace(lims_canvas[0][0], lims_canvas[1][0], num=iors[0], endpoint=False),
            y=np.linspace(0,iors[0],num=iors[0], endpoint=False),
        )
        interp_1 = scipy.interpolate.interp1d(
            x=np.linspace(lims_canvas[0][1], lims_canvas[1][1], num=iors[1], endpoint=False),
            y=np.linspace(0,iors[1],num=iors[1], endpoint=False),
        )

        ## Everything below depends only on the overlay size, not on the
        ## individual image, so it is computed once
        sz_im_0 = int((size_images_overlay[0] / range_canvas[0]) * iors[0])
        sz_im_1 = int((size_images_overlay[1] / range_canvas[1]) * iors[1])
        grid_source = (
            np.linspace(0, images_overlay.shape[1], num=images_overlay.shape[1], endpoint=False),
            np.linspace(0, images_overlay.shape[2], num=images_overlay.shape[2], endpoint=False),
        )
        grid_target = np.stack(np.meshgrid(
            np.linspace(0, images_overlay.shape[1], num=sz_im_0, endpoint=False),
            np.linspace(0, images_overlay.shape[2], num=sz_im_1, endpoint=False),
        ), axis=-1)

        for image, idx in zip(images_overlay, idx_images_overlay):
            ## Resample the image to the overlay size in raster pixels
            im_interp = scipy.interpolate.RegularGridInterpolator(
                points=grid_source,
                values=image,
                bounds_error=False,
                fill_value=0,
            )(grid_target)

            if im_interp.size == 0:
                warnings.warn(f'Image {idx} is empty after interpolation. Skipping. Increase size_images_overlay.')
                continue

            if im_interp.ndim == 2:
                ## Grayscale: replicate the normalized image into R, G and B
                im_norm = norm_img(im_interp)
                image_rgb = np.stack([im_norm, im_norm, im_norm], axis=-1)
            else:
                image_rgb = im_interp

            ## Patch centered on the data point
            x1 = int(interp_0(data[idx,0]) - sz_im_0 / 2)
            y1 = int(interp_1(data[idx,1]) - sz_im_1 / 2)
            x2 = int(interp_0(data[idx,0]) + sz_im_0 / 2)
            y2 = int(interp_1(data[idx,1]) + sz_im_1 / 2)

            assert x1 >= 0 and x2 <= iors[0] and y1 >= 0 and y2 <= iors[1], f'Image is out of bounds of canvas: y1={y1}, y2={y2}, x1={x1}, x2={x2}, sz_im_0={sz_im_0}, sz_im_1={sz_im_1}, iors={iors}'

            canvas[y1:y2, x1:x2,:3] = image_rgb
            canvas[y1:y2, x1:x2,3] = 1  ## opaque where an image was drawn

        ## Raster row 0 is drawn at the top, data y grows upward
        canvas = np.flipud(canvas)

        # Now create a single hv.RGB object
        imo = hv.RGB(canvas, bounds=(lims_canvas[0][0], lims_canvas[0][1], lims_canvas[1][0], lims_canvas[1][1]))

        ## Set bounds of the plot
        layout = layout.redim.range(x=(lims_canvas[0][0], lims_canvas[1][0]), y=(lims_canvas[0][1], lims_canvas[1][1]))
        layout *= imo

    ## start layout with lasso tool active
    layout = layout.opts(
        active_tools=[
            'lasso_select',
            'wheel_zoom',
        ]
    )

    # Display plot
    display(layout)

    def fn_get_indices():
        """Read back the indices written by the selection callback."""
        if not os.path.exists(path_tempFile):
            return None
        with open(path_tempFile, 'r') as f:
            tokens = f.read().split(',')
        ## An empty selection is stored as an empty file -> empty list
        return [int(t) for t in tokens if t != '']

    return fn_get_indices, layout, path_tempFile
[docs]
def get_spread_out_points(
    data: np.ndarray,
    n_ims: int = 1000,
    dist_im_to_point: float = 0.3,
    border_frac: float = 0.05,
    device: str = 'cpu',
) -> np.ndarray:
    """
    Given a set of points, returns the indices of a subset of points that are
    spread out. Intended to be used to overlay images on a scatter plot of
    points.
    RH 2023

    A square grid of ``int(sqrt(n_ims))`` points per axis is laid over the
    (bordered) data range. For every grid node the closest data point is
    found, and it is kept if it lies close enough to the node.

    Args:
        data (np.ndarray):
            Array containing the points to be spread out. Shape: *(N, 2)*
        n_ims (int):
            Number of grid nodes, i.e. the maximum number of indices returned.
            (Default is *1000*)
        dist_im_to_point (float):
            Maximum distance between a grid node and its nearest point, as a
            fraction of the grid spacing (``min(range) / sqrt(n_ims)``). Nodes
            whose nearest point is farther away are discarded. (Default is
            *0.3*)
        border_frac (float):
            Fraction of the range of the data to add as a border around the
            points. (Default is *0.05*)
        device (str):
            Device to use for torch operations. (Default is 'cpu')

    Returns:
        (np.ndarray):
            idx_images_overlay (np.ndarray):
                Integer array of indices into ``data`` of the points to
                overlay images on. Shape: *(n_kept,)* with ``n_kept`` at most
                ``int(sqrt(n_ims))**2``. The same index can appear more than
                once if one point is nearest to several grid nodes.
    """
    import torch

    min_data = np.nanmin(data, axis=0)  ## shape (2,)
    max_data = np.nanmax(data, axis=0)  ## shape (2,)
    range_data = max_data - min_data  ## shape (2,)
    lims_canvas = ((min_data - range_data*border_frac), (max_data + range_data*border_frac))  ## ( shape (2,)(mins), shape (2,)(maxs) )

    n_side = int(n_ims**0.5)
    sz_im = (range_data / (n_ims**0.5))  ## nominal spacing between grid nodes, per axis

    grid_canvas = np.meshgrid(
        np.linspace(lims_canvas[0][0], lims_canvas[1][0], n_side),
        np.linspace(lims_canvas[0][1], lims_canvas[1][1], n_side),
        indexing='xy',
    )
    grid_canvas_flat = np.vstack([g.reshape(-1) for g in grid_canvas]).T  ## shape (n_grid, 2)

    ## Pairwise differences, shape (N, n_grid, 2) -> distances, shape (N, n_grid)
    data_t = torch.as_tensor(data, device=device, dtype=torch.float32)
    grid_t = torch.as_tensor(grid_canvas_flat, device=device, dtype=torch.float32)
    dist_grid_to_imIdx = data_t[:, None, :] - grid_t[None, :, :]
    distNorm_grid_to_imIdx = torch.linalg.norm(dist_grid_to_imIdx, dim=2)

    ## Nearest data point for every grid node
    distMin_grid_to_imIdx = torch.min(distNorm_grid_to_imIdx, dim=0)

    max_dist = (np.min(sz_im))*dist_im_to_point
    idx_good = distMin_grid_to_imIdx.values < max_dist

    ## Return a host numpy array, as annotated, whatever the compute device
    idx_images_overlay = distMin_grid_to_imIdx.indices[idx_good].cpu().numpy()
    return idx_images_overlay
[docs]
def display_labeled_ROIs(
    images: np.ndarray,
    labels: Union[np.ndarray, Dict[str, Any]],
    max_images_per_label: int = 10,
    figsize: Tuple[int, int] = (10, 3),
    fontsize: int = 25,
    shuffle: bool = True,
) -> None:
    """
    Displays a grid of images, each row corresponding to a label, and each image
    is a randomly selected image from that label.
    RH 2023

    Args:
        images (np.ndarray):
            Array of images. Shape: *(num_images, height, width)* or
            *(num_images, height, width, num_channels)*
        labels (Union[np.ndarray, Dict[str, Any]]):
            If dict, it must contain keys 'index' and 'label', where 'index' is
            an array (or list) of indices corresponding to the indices of the
            images, and 'label' is an array (or list) of labels with the same
            length as 'index'. If ndarray or list, it must be a 1D sequence of
            labels corresponding to each image. A pandas DataFrame is also
            accepted: its index gives the image indices and its 'label' column
            the labels.
        max_images_per_label (int):
            Maximum number of images to display per label. (Default is *10*)
        figsize (Tuple[int, int]):
            Size of the figure. (Default is *(10, 3)*)
        fontsize (int):
            Font size of the labels. (Default is *25*)
        shuffle (bool):
            If ``True``, the order of the images will be shuffled. (Default is
            ``True``)

    Raises:
        TypeError:
            If ``labels`` is none of the accepted types.
    """
    import random

    ## Normalize every accepted input into two parallel arrays
    if isinstance(labels, (np.ndarray, list)):
        print(f'labels is a {type(labels)}. Converting to a labels_dict by assuming that image indices are the same as the indices in labels.')
        index = np.arange(len(labels))
        ## asarray so that a plain list is compared element-wise below
        label_values = np.asarray(labels)
    elif isinstance(labels, dict):
        index = np.array(labels['index'], dtype=np.int64)
        label_values = np.array(labels['label'])
    elif isinstance(labels, pd.DataFrame):
        index = np.array(labels.index, dtype=np.int64)
        label_values = np.array(labels['label'])
    else:
        raise TypeError(f'labels must be a list, np.ndarray, dict, or pd.DataFrame. Got {type(labels)}.')

    for l in np.unique(label_values):
        positions = np.where(label_values == l)[0]
        if shuffle:
            positions = np.array(random.sample(list(positions), len(positions)), dtype=np.int64)

        ## Only pass as many images as there are cells in the one-row grid
        n_l = min(len(positions), max_images_per_label)
        idx_shown = index[positions[:n_l]]

        fig, axs = helpers.plot_image_grid(
            images=images[idx_shown],
            labels=idx_shown,
            grid_shape=(1, n_l),
            kwargs_subplots={'figsize': figsize}
        )
        ## Label of the row, written at the left edge of the figure
        fig.text(0, 0.4, str(l), fontdict={'size': fontsize})
[docs]
def plot_confusion_matrix(
    confusion_matrix,
    class_names: List[str] = None,
    figsize: Tuple[int, int] = (4, 4),
    n_decimals: int = 2,
):
    """
    Draws a confusion matrix as an annotated grayscale seaborn heatmap in a
    new figure. Values are colored on a fixed scale from 0 to 1 and the
    colorbar is removed.
    RH 2023

    Args:
        confusion_matrix (np.ndarray):
            Array containing the confusion matrix. Shape: *(num_classes,
            num_classes)*
        class_names (list):
            Tick labels for both axes. Length: *num_classes*. If ``None``,
            ``0 .. confusion_matrix.shape[0] - 1`` is used.
        figsize (Tuple[int, int]):
            Size of the figure. The first element also scales all font sizes.
        n_decimals (int):
            Number of decimals the plotted values are rounded to.
    """
    import seaborn as sns
    import matplotlib.pyplot as plt

    ## All font sizes are multiples of the figure width
    size_unit = figsize[0]
    size_annotation = size_unit*4
    size_ticks = size_unit*3
    size_axis_label = size_unit*2

    ## Make plot
    fig = plt.figure(figsize=figsize)
    values_rounded = np.round(confusion_matrix, decimals=n_decimals)
    heatmap = sns.heatmap(
        values_rounded,
        annot=True,
        annot_kws={"size": size_annotation},
        vmin=0.,
        vmax=1.,
        cmap=plt.get_cmap('gray'),
    )

    ## Remove colormap
    plt.gca().collections[0].colorbar.remove()

    ## Set tick labels: horizontal on the y axis, slanted on the x axis
    if class_names is None:
        class_names = np.arange(confusion_matrix.shape[0])
    for axis, angle in ((heatmap.yaxis, 0), (heatmap.xaxis, 45)):
        axis.set_ticklabels(
            class_names,
            rotation=angle,
            ha='right',
            fontsize=size_ticks
        )

    ## Axis titles
    plt.ylabel('True label', fontdict={'size': size_axis_label})
    plt.xlabel('Predicted label', fontdict={'size': size_axis_label})