# Source code for darsia.signals.models.combinedmodel

"""Combination of models.

NOTE: Combining models is experimental and the responsibility
lies with the user.

"""

from __future__ import annotations

from typing import Literal, Optional, Union

import numpy as np

import darsia


class CombinedModel(darsia.Model):
    """Sequential composition of several models.

    Calling the combined model feeds the input through ``models[0]``, then
    feeds that result through ``models[1]``, and so on. The calibration
    parameters of all sub-models are addressed through a single flat array,
    in the order in which the models are given.

    NOTE: Combining models is experimental; it is up to the user to ensure
    that the individual models are compatible with each other.

    """

    def __init__(self, models: list[darsia.Model]) -> None:
        """
        Args:
            models (list[darsia.Model]): models to be applied in sequence.

        """
        # Cache models
        self.models = models

        # Determine the number of parameters needed for calibration; models
        # that do not declare ``num_parameters`` are treated as parameter-free.
        self.num_parameters = sum(
            self._num_parameters(model) for model in self.models
        )

    @staticmethod
    def _num_parameters(model: darsia.Model) -> int:
        """Return the declared number of parameters of a model (0 if undeclared)."""
        return getattr(model, "num_parameters", 0)

    def __call__(self, img: np.ndarray, *args) -> np.ndarray:
        """Apply all models in sequence.

        Args:
            img (np.ndarray): input image.
            *args: additional positional arguments. Each model receives only
                as many of them (taken from the front) as its ``__call__``
                accepts positionally after the image.

        Returns:
            np.ndarray: response of the last model applied to the output of
                the previous ones (a copy of ``img`` if there are no models).

        """
        result = img.copy()
        for model in self.models:
            # Determine the number of positional parameters of the model's
            # __call__ (this count includes ``self`` and the image) and pass
            # only a suitable amount of extra arguments. Note: there is no
            # guarantee that the models use the same positional arguments!
            # Clamping at 0 avoids a negative slice (which would silently
            # drop trailing arguments) for signatures like ``(self, *args)``.
            num_extra_args = max(model.__call__.__code__.co_argcount - 2, 0)
            result = model(result, *args[:num_extra_args])
        return result

    def update_model_parameters(
        self,
        parameters: np.ndarray,
        dofs: Optional[Union[list[tuple[int, str]], Literal["all"]]] = None,
    ) -> None:
        """Wrapper of update routines of single models.

        The parameter array is consumed front to back. After a model has been
        updated, its ``num_parameters`` leading entries are dropped before the
        next model is addressed.

        Args:
            parameters (np.ndarray): flat array of parameter values.
            dofs (list of (int, str) tuples or "all", optional): degrees of
                freedom to update. ``None`` or ``"all"`` updates all models
                (that declare ``num_parameters``) in order. Otherwise each
                tuple ``(pos_model, pos_parameter)`` addresses the model
                ``self.models[pos_model]``, whose ``update_model_parameters``
                is called with ``pos_parameter`` as its ``dofs`` argument.

        """
        # Work on a copy so that the models never modify the caller's array.
        parameters_cache = parameters.copy()

        if dofs is None or dofs == "all":
            # If no degrees of freedom are specified, update all parameters.
            for model in self.models:
                if not hasattr(model, "num_parameters"):
                    # Consistent with __init__: such models contribute no
                    # parameters, hence there is nothing to update.
                    continue
                model.update_model_parameters(parameters_cache)
                # Remove the updated parameters from the cache
                parameters_cache = parameters_cache[model.num_parameters :]
        else:
            # Analogously when only a subset of parameters is to be updated.
            for pos_model, pos_parameter in dofs:
                model = self.models[pos_model]
                model.update_model_parameters(parameters_cache, pos_parameter)
                # NOTE(review): this drops all ``num_parameters`` entries of
                # the model although only ``pos_parameter`` was addressed; if
                # several dofs of the same model are listed, later ones may
                # see too few entries. Confirm the intended layout.
                parameters_cache = parameters_cache[
                    self._num_parameters(model) :
                ]

    def __getitem__(self, pos_model: int) -> darsia.Model:
        """Access single models.

        Args:
            pos_model (int): position index addressing a single model.

        Returns:
            darsia.Model: single model

        """
        return self.models[pos_model]