# Source code for darsia.signals.models.staticthresholdmodel
"""Module converting signals to binary data by applying thresholding.
A distinction between heterogeneous and homogeneous thresholding
is performed automatically.
"""
from __future__ import annotations
from typing import Optional, Union
import numpy as np
import skimage
import darsia
# [docs]
class StaticThresholdModel(darsia.Model):
    """Static thresholding model converting a signal to binary data.

    Depending on whether ``labels`` are passed to the constructor, the
    thresholds are applied homogeneously (a single lower/upper value for the
    entire signal) or heterogeneously (one lower/upper value per unique label).

    """

    def __init__(
        self,
        threshold_lower: Union[float, list[float], np.ndarray] = 0.0,
        threshold_upper: Optional[Union[float, list[float], np.ndarray]] = None,
        labels: Optional[np.ndarray] = None,
        return_float: bool = False,
    ) -> None:
        """Constructor of StaticThresholdModel.

        Args:
            threshold_lower (float, list of float or array): lower threshold
                value(s); a signal value has to be strictly larger to be marked.
            threshold_upper (float, list of float or array, optional): upper
                threshold value(s); a signal value has to be strictly smaller to
                be marked. If None, no upper bound is applied.
            labels (array, optional): labeled domain. If None, the homogeneous
                case is considered and both thresholds have to be scalar floats
                (the upper one may be None). Otherwise, list/array thresholds
                need one entry per unique label, ordered as in
                ``np.unique(labels)``, and scalar floats are used for all labels.
            return_float (bool): flag controlling whether the output is a float
                or boolean array.

        Raises:
            AssertionError: if the thresholds are not floats in the homogeneous
                case, or if list/array thresholds do not have one entry per
                unique label in the heterogeneous case.
            ValueError: if a threshold has an unsupported type in the
                heterogeneous case.

        """
        self.return_float = return_float

        # The argument labels decides whether a homogeneous or heterogeneous
        # treatment is considered.
        if labels is None:
            # Homogeneous case: one scalar lower (and optional upper) threshold.
            self._is_homogeneous = True
            assert isinstance(threshold_lower, float)
            assert isinstance(threshold_upper, float) or threshold_upper is None
            self._threshold_lower = threshold_lower
            self._threshold_upper = threshold_upper
            self.num_parameters = 2
        else:
            # Heterogeneous case: one lower (and optional upper) threshold per
            # unique label.
            self._is_homogeneous = False
            self._labels = labels
            num_labels = len(np.unique(self._labels))
            self._threshold_lower = self._expand_threshold(
                threshold_lower, num_labels
            )
            self._threshold_upper = self._expand_threshold(
                threshold_upper, num_labels, allow_none=True
            )
            self.num_parameters = 2 * num_labels

    @staticmethod
    def _expand_threshold(
        threshold: Optional[Union[float, list[float], np.ndarray]],
        num_labels: int,
        allow_none: bool = False,
    ) -> Optional[np.ndarray]:
        """Convert a user-provided threshold into one value per label.

        Args:
            threshold (float, list of float, array or None): threshold value(s)
            num_labels (int): number of unique labels
            allow_none (bool): whether None is accepted (and returned as is)

        Returns:
            np.ndarray or None: array holding ``num_labels`` threshold values,
                or None if ``threshold`` is None and ``allow_none`` is True.

        Raises:
            AssertionError: if a list/array does not have ``num_labels`` entries.
            ValueError: if the type of ``threshold`` is not supported.

        """
        if isinstance(threshold, (list, np.ndarray)):
            # Allow for heterogeneous initial value.
            assert len(threshold) == num_labels
            return np.array(threshold)
        if isinstance(threshold, float):
            # Or initialize all labels with the same value.
            return threshold * np.ones(num_labels, dtype=float)
        if allow_none and threshold is None:
            return None
        raise ValueError(f"Type {type(threshold)} not supported.")

    def __call__(
        self, img: np.ndarray, mask: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """Convert signal to binary data through thresholding.

        Args:
            img (np.ndarray): signal
            mask (np.ndarray, optional): mask; if provided, the thresholded
                result is combined with it through a logical 'and'.

        Returns:
            np.ndarray: thresholded signal; converted to float32 if the flag
                ``return_float`` is set, boolean otherwise. The flag is honoured
                regardless of whether a mask is provided.

        """
        # Apply thresholding directly to the signal
        if self._is_homogeneous:
            threshold_mask = self._call_homogeneous(img)
        else:
            threshold_mask = self._call_heterogeneous(img)

        # Restrict data to the provided mask
        if mask is not None:
            threshold_mask = np.logical_and(threshold_mask, mask)

        # Convert after masking such that return_float also applies to the
        # masked result.
        if self.return_float:
            return skimage.img_as_float32(threshold_mask)
        return threshold_mask

    def _call_homogeneous(self, img: np.ndarray) -> np.ndarray:
        """Convert signal to binary data through thresholding, tailored for the
        homogeneous case.

        Args:
            img (np.ndarray): signal

        Returns:
            np.ndarray: boolean mask, True where the signal lies strictly above
                the lower and (if set) strictly below the upper threshold.

        """
        above_lower = img > self._threshold_lower
        if self._threshold_upper is None:
            return above_lower
        return np.logical_and(above_lower, img < self._threshold_upper)

    def _call_heterogeneous(self, img: np.ndarray) -> np.ndarray:
        """Convert signal to binary data through thresholding, tailored for the
        heterogeneous case.

        Args:
            img (np.ndarray): signal; presumably of the same shape as the
                labels provided at construction (to be confirmed by callers).

        Returns:
            np.ndarray: boolean mask with the shape of the first two dimensions
                of the labels.

        """
        threshold_mask = np.zeros(self._labels.shape[:2], dtype=bool)

        # The i-th threshold belongs to the i-th entry of np.unique(labels).
        for i, label in enumerate(np.unique(self._labels)):
            threshold_mask_i = img > self._threshold_lower[i]
            if self._threshold_upper is not None:
                threshold_mask_i = np.logical_and(
                    threshold_mask_i, img < self._threshold_upper[i]
                )
            # Only accept the thresholded signal within the current label.
            roi = np.logical_and(threshold_mask_i, self._labels == label)
            threshold_mask[roi] = True

        return threshold_mask

    def update_model_parameters(
        self,
        *args: tuple,
    ) -> None:
        """Not supported, since the thresholds of this model are static.

        Raises:
            NotImplementedError: always.

        """
        raise NotImplementedError(
            "StaticThresholdModel does not support parameter updates."
        )