# Licensed under a 3-clause BSD style license - see LICENSE.rst

"""
Functions to determine if a model is separable, i.e.
if the model outputs are independent.

It analyzes ``n_inputs``, ``n_outputs`` and the operators
in a compound model by stepping through the transforms
and creating a ``coord_matrix`` of shape (``n_outputs``, ``n_inputs``).


Each modeling operator is represented by a function which
takes two simple models (or two ``coord_matrix`` arrays) and
returns an array of shape (``n_outputs``, ``n_inputs``).

"""

import numpy as np

from .core import Model, ModelDefinitionError, CompoundModel
from .mappings import Mapping


__all__ = ["is_separable", "separability_matrix"]


def is_separable(transform):
    """
    A separability test for the outputs of a transform.

    Parameters
    ----------
    transform : `~astropy.modeling.core.Model`
        A (compound) model.

    Returns
    -------
    is_separable : ndarray
        A boolean array with size ``transform.n_outputs`` where
        each element indicates whether the output is independent
        and the result of a separable transform.

    Examples
    --------
    >>> from astropy.modeling.models import Shift, Scale, Rotation2D, Polynomial2D
    >>> is_separable(Shift(1) & Shift(2) | Scale(1) & Scale(2))
        array([ True,  True]...)
    >>> is_separable(Shift(1) & Shift(2) | Rotation2D(2))
        array([False, False]...)
    >>> is_separable(Shift(1) & Shift(2) | Mapping([0, 1, 0, 1]) | \
        Polynomial2D(1) & Polynomial2D(2))
        array([False, False]...)
    >>> is_separable(Shift(1) & Shift(2) | Mapping([0, 1, 0, 1]))
        array([ True,  True,  True,  True]...)

    """
    # A single input feeding several outputs: every output is reported as
    # non-separable without building the matrix.
    if transform.n_inputs == 1 and transform.n_outputs > 1:
        return np.array([False] * transform.n_outputs).T

    separable_matrix = _separable(transform)
    # An output is separable when its row sums to exactly 1.
    row_sums = separable_matrix.sum(1)
    return np.where(row_sums != 1, False, True)


def separability_matrix(transform):
    """
    Compute the correlation between outputs and inputs.

    Parameters
    ----------
    transform : `~astropy.modeling.core.Model`
        A (compound) model.

    Returns
    -------
    separable_matrix : ndarray
        A boolean correlation matrix of shape (n_outputs, n_inputs).
        Indicates the dependence of outputs on inputs. For completely
        independent outputs, the diagonal elements are True and
        off-diagonal elements are False.

    Examples
    --------
    >>> from astropy.modeling.models import Shift, Scale, Rotation2D, Polynomial2D
    >>> separability_matrix(Shift(1) & Shift(2) | Scale(1) & Scale(2))
        array([[ True, False], [False,  True]]...)
    >>> separability_matrix(Shift(1) & Shift(2) | Rotation2D(2))
        array([[ True,  True], [ True,  True]]...)
    >>> separability_matrix(Shift(1) & Shift(2) | Mapping([0, 1, 0, 1]) | \
        Polynomial2D(1) & Polynomial2D(2))
        array([[ True,  True], [ True,  True]]...)
    >>> separability_matrix(Shift(1) & Shift(2) | Mapping([0, 1, 0, 1]))
        array([[ True, False], [False,  True], [ True, False], [False,  True]]...)

    """
    # A single input feeding several outputs: every output depends on it.
    if transform.n_inputs == 1 and transform.n_outputs > 1:
        return np.ones((transform.n_outputs, transform.n_inputs),
                       dtype=np.bool_)

    separable_matrix = _separable(transform)
    # Any non-zero entry means "this output depends on this input".
    return np.where(separable_matrix != 0, True, False)


def _compute_n_outputs(left, right):
    """
    Compute the number of outputs of two models.

    The two models are the left and right model to an operation in
    the expression tree of a compound model.

    Parameters
    ----------
    left, right : `astropy.modeling.Model` or ndarray
        If input is of an array, it is the output of `coord_matrix`.

    Returns
    -------
    noutp : int
        Sum of the number of outputs of ``left`` and ``right``. For an
        array operand the number of rows is used.

    """
    if isinstance(left, Model):
        lnout = left.n_outputs
    else:
        lnout = left.shape[0]

    if isinstance(right, Model):
        rnout = right.n_outputs
    else:
        rnout = right.shape[0]

    return lnout + rnout


def _arith_oper(left, right):
    """
    Function corresponding to one of the arithmetic operators
    ['+', '-', '*', '/', '**'].

    This always returns a nonseparable output.

    Parameters
    ----------
    left, right : `astropy.modeling.Model` or ndarray
        If input is of an array, it is the output of `coord_matrix`.

    Returns
    -------
    result : ndarray
        Array of ones with shape (n_outputs, n_inputs).

    Raises
    ------
    ModelDefinitionError
        If the two operands do not have the same number of inputs and
        the same number of outputs.

    """
    def _n_inputs_outputs(operand):
        # Return (n_inputs, n_outputs) for a model or a coord_matrix whose
        # shape is (n_outputs, n_inputs).
        if isinstance(operand, Model):
            n_outputs, n_inputs = operand.n_outputs, operand.n_inputs
        else:
            n_outputs, n_inputs = operand.shape
        return n_inputs, n_outputs

    left_inputs, left_outputs = _n_inputs_outputs(left)
    right_inputs, right_outputs = _n_inputs_outputs(right)

    # Arithmetic operators require matching numbers of inputs and outputs.
    if left_inputs != right_inputs or left_outputs != right_outputs:
        raise ModelDefinitionError(
            "Unsupported operands for arithmetic operator: left (n_inputs={}, "
            "n_outputs={}) and right (n_inputs={}, n_outputs={}); "
            "models must have the same n_inputs and the same "
            "n_outputs for this operator.".format(
                left_inputs, left_outputs, right_inputs, right_outputs))

    return np.ones((left_outputs, left_inputs))


def _coord_matrix(model, pos, noutp):
    """
    Create an array representing inputs and outputs of a simple model.

    The array has a shape (noutp, model.n_inputs).

    Parameters
    ----------
    model : `astropy.modeling.Model`
        model
    pos : str
        Position of this model in the expression tree.
        One of ['left', 'right'].
    noutp : int
        Number of outputs of the compound model of which the input model
        is a left or right child.

    Returns
    -------
    mat : ndarray
        Array of zeros and ones of shape (noutp, model.n_inputs). For
        ``pos == 'left'`` the model occupies the top rows, otherwise the
        bottom rows.

    """
    mat = np.zeros((noutp, model.n_inputs))

    if isinstance(model, Mapping):
        # One row per mapped output, with a single 1 in the column of the
        # input it is taken from.
        axes = []
        for i in model.mapping:
            axis = np.zeros((model.n_inputs,))
            axis[i] = 1
            axes.append(axis)
        m = np.vstack(axes)
        if pos == 'left':
            mat[:model.n_outputs, :model.n_inputs] = m
        else:
            mat[-model.n_outputs:, -model.n_inputs:] = m
        return mat

    if not model.separable:
        # this does not work for more than 2 coordinates
        # Non-separable: every output depends on every input.
        if pos == 'left':
            mat[:model.n_outputs, :model.n_inputs] = 1
        else:
            mat[-model.n_outputs:, -model.n_inputs:] = 1
    else:
        # Separable: output i depends on input i only.
        for i in range(model.n_inputs):
            mat[i, i] = 1
        if pos == 'right':
            # Shift the block down by whole rows. Without ``axis=0`` the
            # flattened array is rolled, which only matches a row shift
            # when the model has a single input.
            mat = np.roll(mat, noutp - model.n_outputs, axis=0)
    return mat


def _cstack(left, right):
    """
    Function corresponding to '&' operation.

    Parameters
    ----------
    left, right : `astropy.modeling.Model` or ndarray
        If input is of an array, it is the output of `coord_matrix`.

    Returns
    -------
    result : ndarray
        Result from this operation: the left block in the top-left corner
        and the right block in the bottom-right corner, stacked
        horizontally.

    """
    noutp = _compute_n_outputs(left, right)

    if isinstance(left, Model):
        cleft = _coord_matrix(left, 'left', noutp)
    else:
        # Pad an existing matrix with zero rows below it.
        cleft = np.zeros((noutp, left.shape[1]))
        cleft[:left.shape[0], :left.shape[1]] = left

    if isinstance(right, Model):
        cright = _coord_matrix(right, 'right', noutp)
    else:
        # Pad an existing matrix with zero rows above it.
        cright = np.zeros((noutp, right.shape[1]))
        cright[-right.shape[0]:, -right.shape[1]:] = right

    return np.hstack([cleft, cright])


def _cdot(left, right):
    """
    Function corresponding to "|" operation.

    Parameters
    ----------
    left, right : `astropy.modeling.Model` or ndarray
        If input is of an array, it is the output of `coord_matrix`.

    Returns
    -------
    result : ndarray
        Result from this operation.

    Raises
    ------
    ModelDefinitionError
        If the two coord matrices cannot be multiplied.

    """
    # In ``left | right`` the left model is applied first, so its matrix
    # is the right-hand factor of the product. Swap the operands.
    left, right = right, left

    def _as_coord_matrix(operand, position):
        # Return the coord_matrix for a model; pass an array through as is.
        if isinstance(operand, Model):
            return _coord_matrix(operand, position, operand.n_outputs)
        return operand

    cleft = _as_coord_matrix(left, 'left')
    cright = _as_coord_matrix(right, 'right')

    try:
        result = np.dot(cleft, cright)
    except ValueError as err:
        # Because of the swap above, ``cright`` belongs to the model written
        # on the left of "|" and ``cleft`` to the one on the right.
        raise ModelDefinitionError(
            'Models cannot be combined with the "|" operator; '
            'left coord_matrix is {}, right coord_matrix is {}'.format(
                cright, cleft)) from err
    return result


def _separable(transform):
    """
    Calculate the separability matrix of a transform.

    Parameters
    ----------
    transform : `astropy.modeling.Model`
        A transform (usually a compound model).

    Returns
    -------
    separable_matrix : ndarray
        The value returned by ``transform._calculate_separability_matrix()``
        when that is not ``NotImplemented``. Otherwise, for a compound
        model, the matrices of the left and right children combined by the
        function registered for ``transform.op``; for a simple model, its
        coord matrix of shape (n_outputs, n_inputs).

    """
    if (transform_matrix := transform._calculate_separability_matrix()) is not NotImplemented:
        return transform_matrix
    elif isinstance(transform, CompoundModel):
        sepleft = _separable(transform.left)
        sepright = _separable(transform.right)
        return _operators[transform.op](sepleft, sepright)
    elif isinstance(transform, Model):
        return _coord_matrix(transform, 'left', transform.n_outputs)


# Maps each modeling operator to the function that computes the
# relationship of axes as an array of 0s and 1s.
_operators = {'&': _cstack, '|': _cdot, '+': _arith_oper, '-': _arith_oper,
              '*': _arith_oper, '/': _arith_oper, '**': _arith_oper}