Source code for darsia.multi_image_analysis.imageregistration

"""Module containing a diffeomorphic image registration tool.

"""

from __future__ import annotations

import time
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import skimage

import darsia


[docs] class DiffeomorphicImageRegistration: """ Class to detect the deformation between different images. After all, DiffeomorphicImageRegistration is a wrapper using TranslationAnalysis. """ def __init__(self, img_dst: darsia.Image, **kwargs) -> None: """Constructor for DiffeomorphicImageRegistration. Args: dst (darsia.Image): reference image which is supposed to be fixed in the analysis, serves as destination object. optional keyword arguments: N_patches (list of two int): number of patches in x and y direction rel_overlap (float): relative overlap in each direction, related to the patch size max_features (int) maximal number of features in thefeature detection tol (float): tolerance mask (np.ndarray, optional): roi in which features are considered. """ # Create translation estimator max_features = kwargs.get("max_features", 200) tol = kwargs.get("tol", 0.05) self.translation_estimator = darsia.TranslationEstimator(max_features, tol) # Create translation analysis tool, and use the baseline image as reference point self.N_patches = kwargs.get("N_patches", [1, 1]) self.rel_overlap = kwargs.get("rel_overlap", 0.0) mask_dst: Optional[darsia.Image] = kwargs.get("mask_dst", None) self.translation_analysis = darsia.TranslationAnalysis( img_dst, N_patches=self.N_patches, rel_overlap=self.rel_overlap, translation_estimator=self.translation_estimator, mask=mask_dst, )
[docs] def update_dst(self, img_dst: darsia.Image) -> None: """ Update of dst image. Args: dst (np.ndarray): image array """ self.translation_analysis.update_base(img_dst)
[docs] def deduct( self, diffeomorphic_image_registration ) -> None: #: DiffeomorphicImageRegistration """ Effectviely copy from external DiffeomorphicImageRegistration. Args: diffeomorphic_image_registration (darsia.DiffeomorphicImageRegistration): Diffeomorphic image registration object holding a translation analysis. """ # The displacement is stored in the translation analysis as callable. # Thus, the current translation analysis has to be updated. self.translation_analysis.deduct_translation_analysis( diffeomorphic_image_registration.translation_analysis )
[docs] def add( self, diffeomorphic_image_registration ) -> None: #: DiffeomorphicImageRegistration """ Update the store translation by adding the translation of an external diffeomorphic image registration. Args: diffeomorphic_image_registration (darsia.DiffeomorphicImageRegistration): Diffeomorphic image registraton object holding a translation analysis. """ # The displacement is stored in the translation analysis as callable. # Thus, the current translation analysis has to be updated. self.translation_analysis.add_translation_analysis( diffeomorphic_image_registration.translation_analysis )
def __call__( self, img: darsia.Image, mask: Optional[np.ndarray] = None, return_transformed_dst: bool = False, ) -> darsia.Image: """ Image registration routine. Args: img (darsia.Image): test image mask (np.ndaray): active mask return_transformed_dst (bool): flag whether the transform is also applied to dst (inversely) Returns: darsia.Image: transformed test image darsia.Image: transformed reference image #TODO? """ transformed_img = self.translation_analysis(img, mask=mask) # Return results if return_transformed_dst: transformed_dst = self.apply(self.dst, reverse=True) return transformed_img, transformed_dst else: return transformed_img
[docs] def call_with_output( self, img: darsia.Image, plot_patch_translation: bool = False, return_patch_translation: bool = False, mask: Optional[darsia.Image] = None, ) -> darsia.Image: """ Determine the deformation pattern and apply diffeomorphism to the image aiming at matching the reference/destination (dst) image. This in the end only a wrapper for the translation analysis. Args: img (darsia.Image): test image reverse (bool): flag whether the translation is understood as from the test image to the dst image, or reversed. The default is the former. plot_patch_translation (bool): flag controlling whether the displacement is also visualized as vector field. return_patch_translation (bool): flag controlling whether the displacement in the patch centers is returned in the sense of img to dst, complying to the plot; default is False. """ transformed_img = self.translation_analysis(img, mask=mask) if return_patch_translation: patch_translation = self.translation_analysis.return_patch_translation( False ) if plot_patch_translation: self.translation_analysis.plot_translation(False) if return_patch_translation: return transformed_img, patch_translation else: return transformed_img
[docs] def evaluate( self, coords: Union[np.ndarray, darsia.Patches], reverse: bool = False, units: str = "metric", ) -> np.ndarray: """ Evaluate diffeormorphism in arbitrary points. Args: coords (np.ndarray, or darsia.Patches): coordinate array with shape num_pts x 2, or alternatively num_rows_pts x num_cols_pts x 2, identifying points in a mesh/patched image, or equivalently patch. reverse (bool): flag whether the translation is understood as from the test image to the baseline image, or reversed. The default is the former latter. units (str): input and output units; "metric" default; otherwise assumed to be "pixel". Returns: np.ndarray: deformation vectors for all coordinates. """ # Reshape coords using a num_pts x 2 format. if isinstance(coords, darsia.Patches): coords = coords.global_centers_cartesian assert coords.shape[-1] == 2 coords_shape = coords.shape coords = coords.reshape(-1, 2) # Convert coordinates to pixels with reverse matrix indexing, # if provided in metric units. if units == "metric": base = self.translation_analysis.base pixel_coords = base.coordinatesystem.voxel(coords) pixel_coords = np.fliplr(pixel_coords) # Reverse matrix ind. elif units == "pixel": pixel_coords = coords else: raise ValueError # Interpolate at provided values - expect reverse matrix indexing translation = self.translation_analysis.translation(pixel_coords) # Flip, if required if reverse: translation *= -1.0 # Collect results, use ordering of components consistent with matrix # indexing (i.e. flip of components needed) displacement = np.transpose(np.vstack((translation[1], translation[0]))) # Convert to metric units if required; for pixels, use matrix indexing. if units == "metric": displacement = base.coordinatesystem.coordinate_vector(displacement) # Reshape to format used at input return displacement.reshape(coords_shape)
[docs] def apply(self, img: darsia.Image, reverse: bool = False) -> darsia.Image: """ Apply computed transformation onto arbitrary image. Args: img (np.ndarray or darsia.Image): image reverse (bool): flag whether the translation is understood as from the test image to the baseline image, or reversed. The default is the latter. Returns: np.ndarray, optional: transformed image, if input is array; no output otherwise """ # Load the image into translation_analysis self.translation_analysis.load_image(img) # Apply the transformation, stored in translation_analysis return self.translation_analysis.translate_image(reverse)
[docs] def plot( self, scaling: float = 1.0, mask: Optional[darsia.Image] = None, ) -> None: """ Plots diffeomorphism. Args: scaling (float): scaling for vectors. mask (darsia.Image, optional): active set. """ # Wrapper for translation_analysis. self.translation_analysis.plot_translation( reverse=False, scaling=scaling, mask=mask )
[docs] def displacement(self) -> np.ndarray: """ Return displacement in metric units on all pixels. """ # Define coordinates for each pixel Ny, Nx = self.translation_analysis.base.img.shape[:2] # Nz = 1 x = np.arange(Nx) y = np.arange(Ny) X_pixel, Y_pixel = np.meshgrid(x, y) # Transform coordinates into the right format (vector) pixel_vector = np.transpose(np.vstack((np.ravel(X_pixel), np.ravel(Y_pixel)))) # Reshape coords using a num_pts x 2 format. pixel_coords = pixel_vector.reshape(-1, 2) # Interpolate at provided values - expect reverse matrix indexing tic = time.time() translation = self.translation_analysis.translation(pixel_coords) print(f"translation evaluation: {time.time() - tic}") # # Flip, if required # if reverse: # translation *= -1.0 # results, use ardering of components consistent with matrix indexing displacement = np.transpose(np.vstack((translation[1], translation[0]))) # Convert to metric units tic = time.time() displacement = ( self.translation_analysis.base.coordinatesystem.coordinate_vector( displacement ) ) print(f"coordinate system: {time.time() - tic}") # Cache self.displacement = displacement return displacement
[docs] class MultiscaleDiffeomorphicImageRegistration: """ Class for multiscale diffeomorphic image registration being capable of tracking larger deformations. """ def __init__( self, img_dst: darsia.Image, config: Union[dict, list[dict]], mask_dst: Optional[np.ndarray] = None, total_config: Optional[dict] = None, **kwargs, ) -> None: """ Args: img_dst (darsia.Image): reference image which is supposed to be fixed in the analysis, serves as destination object. config (list of config): hierrachy of config dictionaries. mask_dst (np.ndarray): active mask total_config (dict): parameters for image registration for the overall image registration. """ # Cache inputs self.img_dst = img_dst self.mask_dst = mask_dst assert isinstance(config, list) and all( [isinstance(config[i], dict) for i in range(len(config))] ) self.config = config self.num_levels = len(self.config) # Make sure to have a fine config for the final image registration self.total_config = self.config[-1] if total_config is None else total_config # Cache verbosity self.verbosity = kwargs.get("verbosity", 0) def __call__( self, img: darsia.Image, mask: Optional[np.ndarray] = None, return_transformed_dst: bool = False, ) -> darsia.Image: # TODO typing. """ Image registration routine. Args: img (Image): test image mask (np.ndaray): active mask return_transformed_dst (bool): flag whether the transform is also applied to dst (inversely) Returns: Image: transformed test image """ # Store inputs transformed_img = img.copy() transformed_mask = ( np.ones(img.img.shape[:2], dtype=bool) if mask is None else mask.copy() ) # Initialize combined image registration self.combined_image_registration = DiffeomorphicImageRegistration( transformed_img, **self.total_config ) # Multi level approach succesively updating the combined image registration for level in range(self.num_levels): # Determine deformation for current level _, image_registration = self._single_level_iteration( transformed_img, transformed_mask, self.config[level] ) # Update inputs if level == 0: self.combined_image_registration.deduct(image_registration) else: self.combined_image_registration.add(image_registration) transformed_mask = image_registration.apply(transformed_mask) transformed_img = self.combined_image_registration.apply(img) # Return results if return_transformed_dst: transformed_dst = self.combined_image_registration.apply( self.dst, reverse=True ) return transformed_img, transformed_dst else: return transformed_img def _single_level_iteration( self, img: darsia.Image, mask: np.ndarray, config: dict, ) -> tuple[darsia.Image, DiffeomorphicImageRegistration]: # TODO typing """One iteration of multiscale image registration. Args: img (darsia.Image): test image mask (np.ndarray): active mask config (dict): parameters for image registration Returns: darsia.Image: transformed image darsia.DiffeomorphicImageRegistration: resulting image registration """ # Find image registration image_registration = DiffeomorphicImageRegistration( self.img_dst, mask=self.mask_dst, **config ) plot = self.verbosity >= 2 transformed_img, patch_translation = image_registration.call_with_output( img, plot_patch_translation=plot, return_patch_translation=True, mask=mask ) if self.verbosity >= 2: plt.figure("comparison") plt.imshow( skimage.util.compare_images( self.img_dst.img, transformed_img.img, method="blend" ) ) plt.show() return transformed_img, image_registration
[docs] def apply(self, img: darsia.Image, reverse: bool = False) -> darsia.Image: """ Apply computed transformation onto arbitrary image. Args: img (np.ndarray or darsia.Image): image reverse (bool): flag whether the translation is understood as from the test image to the baseline image, or reversed. The default is the latter. Returns: np.ndarray, optional: transformed image, if input is array; no output otherwise """ if not hasattr(self, "combined_image_registration"): raise ValueError("Construct the deformation first.") return self.combined_image_registration.apply(img, reverse)
[docs] def plot( self, scaling: float, mask: Optional[darsia.Image] = None, ) -> None: """ Plot the dislacement stored in the current image registration. Args: scaling (float): scaling parameter to controll the length of the arrows. mask (np.ndarray): active mask """ if not hasattr(self, "combined_image_registration"): raise ValueError("Construct the deformation first.") self.combined_image_registration.plot(scaling=scaling, mask=mask)
[docs] def evaluate( self, coords: Union[np.ndarray, darsia.Patches], reverse: bool = False, units: str = "metric", ) -> np.ndarray: """See evaluate in DiffeomorphicImageRegistration.""" return self.combined_image_registration.evaluate(coords, reverse, units)
# ! ---- Administrator of Image Registration algorithms
[docs] class ImageRegistration: def __init__( self, img_dst: darsia.Image, method: Optional[str] = None, **kwargs, ) -> None: """Constructor for DiffeomorphicImageRegistration. Args: dst (darsia.Image): reference image which is supposed to be fixed in the analysis, serves as destination object. optional keyword arguments: N_patches (list of two int, or lits of such): number of patches in x and y direction rel_overlap (float, or list of such): relative overlap in each direction, related to the patch size max_features (int, or list of such) maximal number of features in the feature detection tol (float, or list of such): tolerance mask (np.ndarray, optional): roi in which features are considered. """ assert method in [None, "multilevel", "onelevel"] # Fetch keyword arguments N_patches = kwargs.get("N_patches", [1, 1]) max_features = kwargs.get("max_features", 200) tol = kwargs.get("tol", 0.05) rel_overlap = kwargs.get("rel_overlap", 0.0) mask_dst: Optional[darsia.Image] = kwargs.get("mask_dst", None) verbosity = kwargs.get("verbosity", 0) # Method provided through N_patches if method is None: method = ( "multilevel" if isinstance(N_patches, list) and all([isinstance(N_patches[i], list) for i in range(len(N_patches))]) else "onelevel" ) if method == "multilevel": if not isinstance(N_patches, list): N_patches = [N_patches] num_levels = len(N_patches) # Check compatibility compatibility = True if isinstance(max_features, int): max_features = num_levels * [max_features] else: compatibility = ( compatibility and isinstance(max_features, list) and len(max_features) == num_levels ) if isinstance(tol, float): tol = num_levels * [tol] else: compatibility = ( compatibility and isinstance(tol, list) and len(tol) == num_levels ) if isinstance(rel_overlap, float): rel_overlap = num_levels * [rel_overlap] else: compatibility = ( compatibility and isinstance(rel_overlap, list) and len(rel_overlap) == num_levels ) if not compatibility: raise ValueError( "Input for the multilevel image registration is not compatible." ) # Prepare hierarchy of config dictionaries config = [ { "N_patches": N_patches[i], "max_features": max_features[i], "tol": tol[i], "rel_overlap": rel_overlap[i], } for i in range(num_levels) ] self.image_registration = MultiscaleDiffeomorphicImageRegistration( img_dst, config, mask_dst, verbosity=verbosity ) elif method == "onelevel": # Extract possibly the first value from list, if list is provided if isinstance(N_patches, list) and all( [isinstance(N_patches[i], list) for i in range(len(N_patches))] ): N_patches = N_patches[0] if isinstance(max_features, list): max_features = max_features[0] if isinstance(tol, list): tol = tol[0] if isinstance(rel_overlap, list): rel_overlap = rel_overlap[0] # Construct feasible config config = { "N_patches": N_patches, "max_features": max_features, "tol": tol, "rel_overlap": rel_overlap, "mask": mask_dst, "verbosity": verbosity, } self.image_registration = DiffeomorphicImageRegistration(img_dst, **config) else: raise NotImplementedError( f"Method {method} is not implemented for ImageRegistration." ) def __call__( self, img: darsia.Image, mask: Optional[np.ndarray] = None, return_transformed_dst: bool = False, ) -> darsia.Image: return self.image_registration(img, mask, return_transformed_dst)
[docs] def apply(self, img: darsia.Image, reverse: bool = False) -> darsia.Image: """ Apply computed transformation onto arbitrary image. Args: img (np.ndarray or darsia.Image): image reverse (bool): flag whether the translation is understood as from the test image to the baseline image, or reversed. The default is the latter. Returns: np.ndarray, optional: transformed image, if input is array; no output otherwise """ return self.image_registration.apply(img, reverse)
[docs] def plot(self, scaling: float, mask: Optional[darsia.Image] = None) -> None: """ Plot the dislacement stored in the current image registration. Args: scaling (float): scaling parameter to controll the length of the arrows. mask (np.ndarray): active mask """ self.image_registration.plot(scaling=scaling, mask=mask)
[docs] def evaluate( self, coords: Union[np.ndarray, darsia.Patches], reverse: bool = False, units: str = "metric", ) -> np.ndarray: """See evaluate in DiffeomorphicImageRegistration.""" return self.image_registration.evaluate(coords, reverse, units)