"""Module containing objects useful for correcting images wrt rotations
"""
from __future__ import annotations
import itertools
from pathlib import Path
from typing import Union
import numpy as np
from scipy.spatial.transform import Rotation
import darsia
[docs]
class RotationCorrection(darsia.BaseCorrection):
    """Rotation correction.

    Rotations are defined as combination of multiple basic rotations. In 2d, a
    single basic rotation is sufficient. In 3d, although three are available, two
    are sufficient.

    Attributes:
        dim (int): ambient dimension, taken from the length of the anchor
        anchor (array): voxel coordinates of the fixed point of the rotation
        rotation (array): rotation matrix of shape (dim, dim)
        rotation_inv (array): inverse of the rotation matrix

    """

    def __init__(
        self,
        anchor: Union[list[int], np.ndarray],
        **kwargs,
    ) -> None:
        """Set up the rotation matrix and its inverse.

        Args:
            anchor (array or list): voxel coordinates of the rotation centre; its
                length fixes the ambient dimension (2 or 3).
            **kwargs: one of the following two ways to define the rotation:
                - rotation_from_isometry (bool): if True, ``pts_src`` and
                  ``pts_dst`` are forwarded to ``darsia.AffineTransformation``
                  and its ``rotation`` attribute is used. Defaults to False.
                - rotations: in 2d, a sequence whose first entry is the rotation
                  angle; in 3d, an iterable of ``(angle, cartesian_axis)`` pairs
                  which are composed in the given order. Angles are handed to
                  ``Rotation.from_rotvec`` unchanged, i.e. they are interpreted
                  in radians. ``cartesian_axis`` is interpreted by
                  ``darsia.interpret_indexing`` relative to "xyz" indexing.

        Raises:
            ValueError: if the dimension is neither 2 nor 3, or if neither an
                isometry nor explicit rotations are provided.

        """
        # Cache anchor of rotation
        self.anchor = np.array(anchor)

        # Cache dimension (obtained from anchor)
        dim = len(self.anchor)
        self.dim = dim
        if dim not in (2, 3):
            raise ValueError(
                f"Rotation correction only supports 2d and 3d; got dim={dim}."
            )

        # Variant 1: Extract the rotation from an affine map fitted to point pairs.
        if kwargs.get("rotation_from_isometry", False):
            affine_map = darsia.AffineTransformation(
                kwargs.get("pts_src"),
                kwargs.get("pts_dst"),
            )
            self.rotation = affine_map.rotation
            self.rotation_inv = np.linalg.inv(affine_map.rotation)
            return

        # Variant 2: Define rotation as combination of basic rotations.
        rotations = kwargs.get("rotations")
        if rotations is None:
            raise ValueError("No means provided to determine rotations.")

        if dim == 2:
            # A planar rotation is the upper-left 2x2 block of a rotation about
            # the (virtual) z-axis.
            angle = rotations[0]
            z_axis = np.array([0, 0, 1])
            self.rotation = Rotation.from_rotvec(angle * z_axis).as_matrix()[:2, :2]
            self.rotation_inv = Rotation.from_rotvec(-angle * z_axis).as_matrix()[
                :2, :2
            ]
        else:
            self.rotation = np.eye(dim)
            self.rotation_inv = np.eye(dim)
            indexing = "xyz"[:dim]

            for angle, cartesian_axis in rotations:
                matrix_axis, reverted = darsia.interpret_indexing(
                    cartesian_axis, indexing
                )
                sign = -1 if reverted else 1
                basic = Rotation.from_rotvec(
                    sign * angle * np.eye(dim)[matrix_axis]
                ).as_matrix()

                # Forward rotation: R = R_1 R_2 ... R_n.
                self.rotation = np.matmul(self.rotation, basic)

                # Inverse: R^-1 = R_n^-1 ... R_2^-1 R_1^-1, hence multiply from
                # the left. The inverse of a rotation matrix is its transpose,
                # which automatically carries the same sign as the forward
                # rotation.
                self.rotation_inv = np.matmul(basic.T, self.rotation_inv)

    def correct_array(self, img: np.ndarray) -> np.ndarray:
        """Main routine: Application of inherent rotation to provided image.

        Each voxel of the output is filled with the value of the input voxel
        obtained by applying the inverse rotation about the anchor. The source
        coordinates are truncated to integers and clamped to the image, i.e. no
        interpolation is performed and voxels mapped from outside the image
        replicate the nearest boundary value. The output has the same shape as
        the input; axes beyond the first ``dim`` ones (e.g. colour channels) are
        carried along unchanged.

        Args:
            img (array): image with at least ``dim`` axes

        Returns:
            array: rotated image (float array of the same shape as ``img``)

        Raises:
            ValueError: if ``img`` has fewer than ``dim`` axes.

        """
        dim = self.dim
        if img.ndim < dim:
            raise ValueError(
                f"Image with {img.ndim} axes cannot be rotated in {dim} dimensions."
            )

        # Only the leading axes are spatial; trailing axes are treated as data.
        spatial_shape = img.shape[:dim]

        # Collect all voxels in dim x num_voxels format (C ordering).
        target_voxels = np.indices(spatial_shape).reshape(dim, -1)

        # Find corresponding voxels in the original image by rotating back
        # around the anchor.
        anchor = np.asarray(self.anchor).reshape(dim, 1)
        src_voxels = anchor + self.rotation_inv.dot(target_voxels - anchor)

        # Truncate to integer voxel indices and clamp to the spatial extent of
        # the image. The bound has to be restricted to the spatial axes,
        # otherwise images with trailing axes cannot be broadcast against the
        # dim x num_voxels index array.
        upper_bound = (np.array(spatial_shape) - 1).reshape(dim, 1)
        src_voxels = np.clip(src_voxels.astype(int), 0, upper_bound)

        # Implicitly assume the mapped image is of same size as the input image.
        rotated_img = np.zeros(img.shape)
        rotated_img[tuple(target_voxels)] = img[tuple(src_voxels)]

        return rotated_img

    # ! ---- I/O ----

    def save(self, path: Path) -> None:
        """Save the correction to file (not implemented yet).

        Args:
            path (Path): path to the file

        Raises:
            NotImplementedError: always.

        """
        raise NotImplementedError("Not implemented yet.")

    def load(self, path: Path) -> None:
        """Load the correction from file (not implemented yet).

        Args:
            path (Path): path to the file

        Raises:
            NotImplementedError: always.

        """
        raise NotImplementedError("Not implemented yet.")