# Licensed under a 3-clause BSD style license - see LICENSE.rst """ This module is to contain an improved bounding box """ import abc import copy from collections import namedtuple from typing import Dict, List, Tuple, Callable, Any from astropy.utils import isiterable from astropy.units import Quantity import warnings import numpy as np __all__ = ['ModelBoundingBox', 'CompoundBoundingBox'] _BaseInterval = namedtuple('_BaseInterval', "lower upper") class _Interval(_BaseInterval): """ A single input's bounding box interval. Parameters ---------- lower : float The lower bound of the interval upper : float The upper bound of the interval Methods ------- validate : Contructs a valid interval outside : Determine which parts of an input array are outside the interval. domain : Contructs a discretization of the points inside the interval. """ def __repr__(self): return f"Interval(lower={self.lower}, upper={self.upper})" def copy(self): return copy.deepcopy(self) @staticmethod def _validate_shape(interval): """Validate the shape of an interval representation""" MESSAGE = """An interval must be some sort of sequence of length 2""" try: shape = np.shape(interval) except TypeError: try: # np.shape does not work with lists of Quantities if len(interval) == 1: interval = interval[0] shape = np.shape([b.to_value() for b in interval]) except (ValueError, TypeError, AttributeError): raise ValueError(MESSAGE) valid_shape = shape in ((2,), (1, 2), (2, 0)) if not valid_shape: valid_shape = (len(shape) > 0) and (shape[0] == 2) and \ all(isinstance(b, np.ndarray) for b in interval) if not isiterable(interval) or not valid_shape: raise ValueError(MESSAGE) @classmethod def _validate_bounds(cls, lower, upper): """Validate the bounds are reasonable and construct an interval from them.""" if (np.asanyarray(lower) > np.asanyarray(upper)).all(): warnings.warn(f"Invalid interval: upper bound {upper} " f"is strictly less than lower bound {lower}.", RuntimeWarning) return cls(lower, upper) @classmethod def validate(cls, interval): """ Construct and validate an interval Parameters ---------- interval : iterable A representation of the interval. Returns ------- A validated interval. """ cls._validate_shape(interval) if len(interval) == 1: interval = tuple(interval[0]) else: interval = tuple(interval) return cls._validate_bounds(interval[0], interval[1]) def outside(self, _input: np.ndarray): """ Parameters ---------- _input : np.ndarray The evaluation input in the form of an array. Returns ------- Boolean array indicating which parts of _input are outside the interval: True -> position outside interval False -> position inside interval """ return np.logical_or(_input < self.lower, _input > self.upper) def domain(self, resolution): return np.arange(self.lower, self.upper + resolution, resolution) # The interval where all ignored inputs can be found. _ignored_interval = _Interval.validate((-np.inf, np.inf)) class _BoundingDomain(abc.ABC): """ Base class for ModelBoundingBox and CompoundBoundingBox. This is where all the `~astropy.modeling.core.Model` evaluation code for evaluating with a bounding box is because it is common to both types of bounding box. Parameters ---------- model : `~astropy.modeling.Model` The Model this bounding domain is for. prepare_inputs : Generates the necessary input information so that model can be evaluated only for input points entirely inside bounding_box. This needs to be implemented by a subclass. Note that most of the implementation is in ModelBoundingBox. prepare_outputs : Fills the output values in for any input points outside the bounding_box. evaluate : Performs a complete model evaluation while enforcing the bounds on the inputs and returns a complete output. """ def __init__(self, model): self._model = model def __call__(self, *args, **kwargs): raise NotImplementedError( "This bounding box is fixed by the model and does not have " "adjustable parameters.") @abc.abstractmethod def fix_inputs(self, model, fixed_inputs: dict): """ Fix the bounding_box for a `fix_inputs` compound model. Parameters ---------- model : `~astropy.modeling.Model` The new model for which this will be a bounding_box fixed_inputs : dict Dictionary of inputs which have been fixed by this bounding box. """ raise NotImplementedError("This should be implemented by a child class.") @abc.abstractmethod def prepare_inputs(self, input_shape, inputs) -> Tuple[Any, Any, Any]: """ Get prepare the inputs with respect to the bounding box. Parameters ---------- input_shape : tuple The shape that all inputs have be reshaped/broadcasted into inputs : list List of all the model inputs Returns ------- valid_inputs : list The inputs reduced to just those inputs which are all inside their respective bounding box intervals valid_index : array_like array of all indices inside the bounding box all_out: bool if all of the inputs are outside the bounding_box """ raise NotImplementedError("This has not been implemented for BoundingDomain.") @staticmethod def _base_output(input_shape, fill_value): """ Create a baseline output, assuming that the entire input is outside the bounding box Parameters ---------- input_shape : tuple The shape that all inputs have be reshaped/broadcasted into fill_value : float The value which will be assigned to inputs which are outside the bounding box Returns ------- An array of the correct shape containing all fill_value """ return np.zeros(input_shape) + fill_value def _all_out_output(self, input_shape, fill_value): """ Create output if all inputs are outside the domain Parameters ---------- input_shape : tuple The shape that all inputs have be reshaped/broadcasted into fill_value : float The value which will be assigned to inputs which are outside the bounding box Returns ------- A full set of outputs for case that all inputs are outside domain. """ return [self._base_output(input_shape, fill_value) for _ in range(self._model.n_outputs)], None def _modify_output(self, valid_output, valid_index, input_shape, fill_value): """ For a single output fill in all the parts corresponding to inputs outside the bounding box. Parameters ---------- valid_output : numpy array The output from the model corresponding to inputs inside the bounding box valid_index : numpy array array of all indices of inputs inside the bounding box input_shape : tuple The shape that all inputs have be reshaped/broadcasted into fill_value : float The value which will be assigned to inputs which are outside the bounding box Returns ------- An output array with all the indices corresponding to inputs outside the bounding box filled in by fill_value """ output = self._base_output(input_shape, fill_value) if not output.shape: output = np.array(valid_output) else: output[valid_index] = valid_output if np.isscalar(valid_output): output = output.item(0) return output def _prepare_outputs(self, valid_outputs, valid_index, input_shape, fill_value): """ Fill in all the outputs of the model corresponding to inputs outside the bounding_box. Parameters ---------- valid_outputs : list of numpy array The list of outputs from the model corresponding to inputs inside the bounding box valid_index : numpy array array of all indices of inputs inside the bounding box input_shape : tuple The shape that all inputs have be reshaped/broadcasted into fill_value : float The value which will be assigned to inputs which are outside the bounding box Returns ------- List of filled in output arrays. """ outputs = [] for valid_output in valid_outputs: outputs.append(self._modify_output(valid_output, valid_index, input_shape, fill_value)) return outputs def prepare_outputs(self, valid_outputs, valid_index, input_shape, fill_value): """ Fill in all the outputs of the model corresponding to inputs outside the bounding_box, adjusting any single output model so that its output becomes a list of containing that output. Parameters ---------- valid_outputs : list The list of outputs from the model corresponding to inputs inside the bounding box valid_index : array_like array of all indices of inputs inside the bounding box input_shape : tuple The shape that all inputs have be reshaped/broadcasted into fill_value : float The value which will be assigned to inputs which are outside the bounding box """ if self._model.n_outputs == 1: valid_outputs = [valid_outputs] return self._prepare_outputs(valid_outputs, valid_index, input_shape, fill_value) @staticmethod def _get_valid_outputs_unit(valid_outputs, with_units: bool): """ Get the unit for outputs if one is required. Parameters ---------- valid_outputs : list of numpy array The list of outputs from the model corresponding to inputs inside the bounding box with_units : bool whether or not a unit is required """ if with_units: return getattr(valid_outputs, 'unit', None) def _evaluate_model(self, evaluate: Callable, valid_inputs, valid_index, input_shape, fill_value, with_units: bool): """ Evaluate the model using the given evaluate routine Parameters ---------- evaluate : Callable callable which takes in the valid inputs to evaluate model valid_inputs : list of numpy arrays The inputs reduced to just those inputs which are all inside their respective bounding box intervals valid_index : numpy array array of all indices inside the bounding box input_shape : tuple The shape that all inputs have be reshaped/broadcasted into fill_value : float The value which will be assigned to inputs which are outside the bounding box with_units : bool whether or not a unit is required Returns ------- outputs : list containing filled in output values valid_outputs_unit : the unit that will be attached to the outputs """ valid_outputs = evaluate(valid_inputs) valid_outputs_unit = self._get_valid_outputs_unit(valid_outputs, with_units) return self.prepare_outputs(valid_outputs, valid_index, input_shape, fill_value), valid_outputs_unit def _evaluate(self, evaluate: Callable, inputs, input_shape, fill_value, with_units: bool): """ Perform model evaluation steps: prepare_inputs -> evaluate -> prepare_outputs Parameters ---------- evaluate : Callable callable which takes in the valid inputs to evaluate model valid_inputs : list of numpy arrays The inputs reduced to just those inputs which are all inside their respective bounding box intervals valid_index : numpy array array of all indices inside the bounding box input_shape : tuple The shape that all inputs have be reshaped/broadcasted into fill_value : float The value which will be assigned to inputs which are outside the bounding box with_units : bool whether or not a unit is required Returns ------- outputs : list containing filled in output values valid_outputs_unit : the unit that will be attached to the outputs """ valid_inputs, valid_index, all_out = self.prepare_inputs(input_shape, inputs) if all_out: return self._all_out_output(input_shape, fill_value) else: return self._evaluate_model(evaluate, valid_inputs, valid_index, input_shape, fill_value, with_units) @staticmethod def _set_outputs_unit(outputs, valid_outputs_unit): """ Set the units on the outputs prepare_inputs -> evaluate -> prepare_outputs -> set output units Parameters ---------- outputs : list containing filled in output values valid_outputs_unit : the unit that will be attached to the outputs Returns ------- List containing filled in output values and units """ if valid_outputs_unit is not None: return Quantity(outputs, valid_outputs_unit, copy=False) return outputs def evaluate(self, evaluate: Callable, inputs, fill_value): """ Perform full model evaluation steps: prepare_inputs -> evaluate -> prepare_outputs -> set output units Parameters ---------- evaluate : callable callable which takes in the valid inputs to evaluate model valid_inputs : list The inputs reduced to just those inputs which are all inside their respective bounding box intervals valid_index : array_like array of all indices inside the bounding box fill_value : float The value which will be assigned to inputs which are outside the bounding box """ input_shape = self._model.input_shape(inputs) # NOTE: CompoundModel does not currently support units during # evaluation for bounding_box so this feature is turned off # for CompoundModel(s). outputs, valid_outputs_unit = self._evaluate(evaluate, inputs, input_shape, fill_value, self._model.bbox_with_units) return tuple(self._set_outputs_unit(outputs, valid_outputs_unit)) def get_name(model, index: int): """Get the input name corresponding to the input index""" return model.inputs[index] def get_index(model, key) -> int: """ Get the input index corresponding to the given key. Can pass in either: the string name of the input or the input index itself. """ if isinstance(key, str): if key in model.inputs: index = model.inputs.index(key) else: raise ValueError(f"'{key}' is not one of the inputs: {model.inputs}.") elif np.issubdtype(type(key), np.integer): if 0 <= key < len(model.inputs): index = key else: raise IndexError(f"Integer key: {key} must be non-negative and < {len(model.inputs)}.") else: raise ValueError(f"Key value: {key} must be string or integer.") return index class ModelBoundingBox(_BoundingDomain): """ A model's bounding box Parameters ---------- intervals : dict A dictionary containing all the intervals for each model input keys -> input index values -> interval for that index model : `~astropy.modeling.Model` The Model this bounding_box is for. ignored : list A list containing all the inputs (index) which will not be checked for whether or not their elements are in/out of an interval. order : optional, str The ordering that is assumed for the tuple representation of this bounding_box. Options: 'C': C/Python order, e.g. z, y, x. (default), 'F': Fortran/mathematical notation order, e.g. x, y, z. """ def __init__(self, intervals: Dict[int, _Interval], model, ignored: List[int] = None, order: str = 'C'): super().__init__(model) self._order = order self._ignored = self._validate_ignored(ignored) self._intervals = {} if intervals != () and intervals != {}: self._validate(intervals, order=order) def copy(self, ignored=None): intervals = {index: interval.copy() for index, interval in self._intervals.items()} if ignored is None: ignored = self._ignored.copy() return ModelBoundingBox(intervals, self._model, ignored=ignored, order=self._order) @property def intervals(self) -> Dict[int, _Interval]: """Return bounding_box labeled using input positions""" return self._intervals @property def order(self) -> str: return self._order @property def ignored(self) -> List[int]: return self._ignored def _get_name(self, index: int): """Get the input name corresponding to the input index""" return get_name(self._model, index) @property def named_intervals(self) -> Dict[str, _Interval]: """Return bounding_box labeled using input names""" return {self._get_name(index): bbox for index, bbox in self._intervals.items()} @property def ignored_inputs(self) -> List[str]: return [self._get_name(index) for index in self._ignored] def __repr__(self): parts = [ 'ModelBoundingBox(', ' intervals={' ] for name, interval in self.named_intervals.items(): parts.append(f" {name}: {interval}") parts.append(' }') if len(self._ignored) > 0: parts.append(f" ignored={self.ignored_inputs}") parts.append(f' model={self._model.__class__.__name__}(inputs={self._model.inputs})') parts.append(f" order='{self._order}'") parts.append(')') return '\n'.join(parts) def _get_index(self, key) -> int: """ Get the input index corresponding to the given key. Can pass in either: the string name of the input or the input index itself. """ return get_index(self._model, key) def _validate_ignored(self, ignored: list) -> List[int]: if ignored is None: return [] else: return [self._get_index(key) for key in ignored] def __len__(self): return len(self._intervals) def __contains__(self, key): try: return self._get_index(key) in self._intervals or self._ignored except (IndexError, ValueError): return False def has_interval(self, key): return self._get_index(key) in self._intervals def __getitem__(self, key): """Get bounding_box entries by either input name or input index""" index = self._get_index(key) if index in self._ignored: return _ignored_interval else: return self._intervals[self._get_index(key)] def _get_order(self, order: str = None) -> str: """ Get if bounding_box is C/python ordered or Fortran/mathematically ordered """ if order is None: order = self._order if order not in ('C', 'F'): raise ValueError("order must be either 'C' (C/python order) or " f"'F' (Fortran/mathematical order), got: {order}.") return order def bounding_box(self, order: str = None): """ Return the old tuple of tuples representation of the bounding_box order='C' corresponds to the old bounding_box ordering order='F' corresponds to the gwcs bounding_box ordering. """ if len(self._intervals) == 1: return tuple(list(self._intervals.values())[0]) else: order = self._get_order(order) inputs = self._model.inputs if order == 'C': inputs = inputs[::-1] bbox = tuple([tuple(self[input_name]) for input_name in inputs]) if len(bbox) == 1: bbox = bbox[0] return bbox def __eq__(self, value): """Note equality can be either with old representation or new one.""" if isinstance(value, tuple): return self.bounding_box() == value elif isinstance(value, ModelBoundingBox): return (self.intervals == value.intervals) and (self.ignored == value.ignored) else: return False def __setitem__(self, key, value): """Validate and store interval under key (input index or input name).""" index = self._get_index(key) if index in self._ignored: self._ignored.remove(index) self._intervals[index] = _Interval.validate(value) def __delitem__(self, key): """Delete stored interval""" index = self._get_index(key) if index in self._ignored: raise RuntimeError(f"Cannot delete ignored input: {key}!") del self._intervals[index] self._ignored.append(index) def _validate_dict(self, bounding_box: dict): """Validate passing dictionary of intervals and setting them.""" for key, value in bounding_box.items(): self[key] = value def _validate_sequence(self, bounding_box, order: str = None): """Validate passing tuple of tuples representation (or related) and setting them.""" order = self._get_order(order) if order == 'C': # If bounding_box is C/python ordered, it needs to be reversed # to be in Fortran/mathematical/input order. bounding_box = bounding_box[::-1] for index, value in enumerate(bounding_box): self[index] = value @property def _n_inputs(self) -> int: n_inputs = self._model.n_inputs - len(self._ignored) if n_inputs > 0: return n_inputs else: return 0 def _validate_iterable(self, bounding_box, order: str = None): """Validate and set any iterable representation""" if len(bounding_box) != self._n_inputs: raise ValueError(f"Found {len(bounding_box)} intervals, " f"but must have exactly {self._n_inputs}.") if isinstance(bounding_box, dict): self._validate_dict(bounding_box) else: self._validate_sequence(bounding_box, order) def _validate(self, bounding_box, order: str = None): """Validate and set any representation""" if self._n_inputs == 1 and not isinstance(bounding_box, dict): self[0] = bounding_box else: self._validate_iterable(bounding_box, order) @classmethod def validate(cls, model, bounding_box, ignored: list = None, order: str = 'C', **kwargs): """ Construct a valid bounding box for a model. Parameters ---------- model : `~astropy.modeling.Model` The model for which this will be a bounding_box bounding_box : dict, tuple A possible representation of the bounding box order : optional, str The order that a tuple representation will be assumed to be Default: 'C' """ if isinstance(bounding_box, ModelBoundingBox): order = bounding_box.order bounding_box = bounding_box.intervals new = cls({}, model, ignored=ignored, order=order) new._validate(bounding_box) return new def fix_inputs(self, model, fixed_inputs: dict, _keep_ignored=False): """ Fix the bounding_box for a `fix_inputs` compound model. Parameters ---------- model : `~astropy.modeling.Model` The new model for which this will be a bounding_box fixed_inputs : dict Dictionary of inputs which have been fixed by this bounding box. keep_ignored : bool Keep the ignored inputs of the bounding box (internal argument only) """ new = self.copy() for _input in fixed_inputs.keys(): del new[_input] if _keep_ignored: ignored = new.ignored else: ignored = None return ModelBoundingBox.validate(model, new.named_intervals, ignored=ignored, order=new._order) @property def dimension(self): return len(self) def domain(self, resolution, order: str = None): inputs = self._model.inputs order = self._get_order(order) if order == 'C': inputs = inputs[::-1] return [self[input_name].domain(resolution) for input_name in inputs] def _outside(self, input_shape, inputs): """ Get all the input positions which are outside the bounding_box, so that the corresponding outputs can be filled with the fill value (default NaN). Parameters ---------- input_shape : tuple The shape that all inputs have be reshaped/broadcasted into inputs : list List of all the model inputs Returns ------- outside_index : bool-numpy array True -> position outside bounding_box False -> position inside bounding_box all_out : bool if all of the inputs are outside the bounding_box """ all_out = False outside_index = np.zeros(input_shape, dtype=bool) for index, _input in enumerate(inputs): _input = np.asanyarray(_input) outside = np.broadcast_to(self[index].outside(_input), input_shape) outside_index[outside] = True if outside_index.all(): all_out = True break return outside_index, all_out def _valid_index(self, input_shape, inputs): """ Get the indices of all the inputs inside the bounding_box. Parameters ---------- input_shape : tuple The shape that all inputs have be reshaped/broadcasted into inputs : list List of all the model inputs Returns ------- valid_index : numpy array array of all indices inside the bounding box all_out : bool if all of the inputs are outside the bounding_box """ outside_index, all_out = self._outside(input_shape, inputs) valid_index = np.atleast_1d(np.logical_not(outside_index)).nonzero() if len(valid_index[0]) == 0: all_out = True return valid_index, all_out def prepare_inputs(self, input_shape, inputs) -> Tuple[Any, Any, Any]: """ Get prepare the inputs with respect to the bounding box. Parameters ---------- input_shape : tuple The shape that all inputs have be reshaped/broadcasted into inputs : list List of all the model inputs Returns ------- valid_inputs : list The inputs reduced to just those inputs which are all inside their respective bounding box intervals valid_index : array_like array of all indices inside the bounding box all_out: bool if all of the inputs are outside the bounding_box """ valid_index, all_out = self._valid_index(input_shape, inputs) valid_inputs = [] if not all_out: for _input in inputs: if input_shape: valid_input = np.broadcast_to(np.atleast_1d(_input), input_shape)[valid_index] if np.isscalar(_input): valid_input = valid_input.item(0) valid_inputs.append(valid_input) else: valid_inputs.append(_input) return tuple(valid_inputs), valid_index, all_out _BaseSelectorArgument = namedtuple('_BaseSelectorArgument', "index ignore") class _SelectorArgument(_BaseSelectorArgument): """ Contains a single CompoundBoundingBox slicing input. Parameters ---------- index : int The index of the input in the input list ignore : bool Whether or not this input will be ignored by the bounding box. Methods ------- validate : Returns a valid SelectorArgument for a given model. get_selector : Returns the value of the input for use in finding the correct bounding_box. get_fixed_value : Gets the slicing value from a fix_inputs set of values. """ def __new__(cls, index, ignore): self = super().__new__(cls, index, ignore) return self @classmethod def validate(cls, model, argument, ignored: bool = True): """ Construct a valid selector argument for a CompoundBoundingBox. Parameters ---------- model : `~astropy.modeling.Model` The model for which this will be an argument for. argument : int or str A representation of which evaluation input to use ignored : optional, bool Whether or not to ignore this argument in the ModelBoundingBox. Returns ------- Validated selector_argument """ return cls(get_index(model, argument), ignored) def get_selector(self, *inputs): """ Get the selector value corresponding to this argument Parameters ---------- *inputs : All the processed model evaluation inputs. """ _selector = inputs[self.index] if isiterable(_selector): if len(_selector) == 1: return _selector[0] else: return tuple(_selector) return _selector def name(self, model) -> str: """ Get the name of the input described by this selector argument Parameters ---------- model : `~astropy.modeling.Model` The Model this selector argument is for. """ return get_name(model, self.index) def pretty_repr(self, model): """ Get a pretty-print representation of this object Parameters ---------- model : `~astropy.modeling.Model` The Model this selector argument is for. """ return f"Argument(name='{self.name(model)}', ignore={self.ignore})" def get_fixed_value(self, model, values: dict): """ Gets the value fixed input corresponding to this argument Parameters ---------- model : `~astropy.modeling.Model` The Model this selector argument is for. values : dict Dictionary of fixed inputs. """ if self.index in values: return values[self.index] else: if self.name(model) in values: return values[self.name(model)] else: raise RuntimeError(f"{self.pretty_repr(model)} was not found in {values}") def is_argument(self, model, argument) -> bool: """ Determine if passed argument is described by this selector argument Parameters ---------- model : `~astropy.modeling.Model` The Model this selector argument is for. argument : int or str A representation of which evaluation input is being used """ return self.index == get_index(model, argument) def named_tuple(self, model): """ Get a tuple representation of this argument using the input name from the model. Parameters ---------- model : `~astropy.modeling.Model` The Model this selector argument is for. """ return (self.name(model), self.ignore) class _SelectorArguments(tuple): """ Contains the CompoundBoundingBox slicing description Parameters ---------- input_ : The SelectorArgument values Methods ------- validate : Returns a valid SelectorArguments for its model. get_selector : Returns the selector a set of inputs corresponds to. is_selector : Determines if a selector is correctly formatted for this CompoundBoundingBox. get_fixed_value : Gets the selector from a fix_inputs set of values. """ _kept_ignore = None def __new__(cls, input_: Tuple[_SelectorArgument], kept_ignore: List = None): self = super().__new__(cls, input_) if kept_ignore is None: self._kept_ignore = [] else: self._kept_ignore = kept_ignore return self def pretty_repr(self, model): """ Get a pretty-print representation of this object Parameters ---------- model : `~astropy.modeling.Model` The Model these selector arguments are for. """ parts = ['SelectorArguments('] for argument in self: parts.append( f" {argument.pretty_repr(model)}" ) parts.append(')') return '\n'.join(parts) @property def ignore(self): """Get the list of ignored inputs""" ignore = [argument.index for argument in self if argument.ignore] ignore.extend(self._kept_ignore) return ignore @property def kept_ignore(self): """The arguments to persist in ignoring""" return self._kept_ignore @classmethod def validate(cls, model, arguments, kept_ignore: List=None): """ Construct a valid Selector description for a CompoundBoundingBox. Parameters ---------- model : `~astropy.modeling.Model` The Model these selector arguments are for. arguments : The individual argument informations kept_ignore : Arguments to persist as ignored """ inputs = [] for argument in arguments: _input = _SelectorArgument.validate(model, *argument) if _input.index in [this.index for this in inputs]: raise ValueError(f"Input: '{get_name(model, _input.index)}' has been repeated.") inputs.append(_input) if len(inputs) == 0: raise ValueError("There must be at least one selector argument.") if isinstance(arguments, _SelectorArguments): if kept_ignore is None: kept_ignore = [] kept_ignore.extend(arguments.kept_ignore) return cls(tuple(inputs), kept_ignore) def get_selector(self, *inputs): """ Get the selector corresponding to these inputs Parameters ---------- *inputs : All the processed model evaluation inputs. """ return tuple([argument.get_selector(*inputs) for argument in self]) def is_selector(self, _selector): """ Determine if this is a reasonable selector Parameters ---------- _selector : tuple The selector to check """ return isinstance(_selector, tuple) and len(_selector) == len(self) def get_fixed_values(self, model, values: dict): """ Gets the value fixed input corresponding to this argument Parameters ---------- model : `~astropy.modeling.Model` The Model these selector arguments are for. values : dict Dictionary of fixed inputs. """ return tuple([argument.get_fixed_value(model, values) for argument in self]) def is_argument(self, model, argument) -> bool: """ Determine if passed argument is one of the selector arguments Parameters ---------- model : `~astropy.modeling.Model` The Model these selector arguments are for. argument : int or str A representation of which evaluation input is being used """ for selector_arg in self: if selector_arg.is_argument(model, argument): return True else: return False def selector_index(self, model, argument): """ Get the index of the argument passed in the selector tuples Parameters ---------- model : `~astropy.modeling.Model` The Model these selector arguments are for. argument : int or str A representation of which argument is being used """ for index, selector_arg in enumerate(self): if selector_arg.is_argument(model, argument): return index else: raise ValueError(f"{argument} does not correspond to any selector argument.") def reduce(self, model, argument): """ Reduce the selector arguments by the argument given Parameters ---------- model : `~astropy.modeling.Model` The Model these selector arguments are for. argument : int or str A representation of which argument is being used """ arguments = list(self) kept_ignore = [arguments.pop(self.selector_index(model, argument)).index] kept_ignore.extend(self._kept_ignore) return _SelectorArguments.validate(model, tuple(arguments), kept_ignore) def add_ignore(self, model, argument): """ Add argument to the kept_ignore list Parameters ---------- model : `~astropy.modeling.Model` The Model these selector arguments are for. argument : int or str A representation of which argument is being used """ if self.is_argument(model, argument): raise ValueError(f"{argument}: is a selector argument and cannot be ignored.") kept_ignore = [get_index(model, argument)] return _SelectorArguments.validate(model, self, kept_ignore) def named_tuple(self, model): """ Get a tuple of selector argument tuples using input names Parameters ---------- model : `~astropy.modeling.Model` The Model these selector arguments are for. """ return tuple([selector_arg.named_tuple(model) for selector_arg in self]) class CompoundBoundingBox(_BoundingDomain): """ A model's compound bounding box Parameters ---------- bounding_boxes : dict A dictionary containing all the ModelBoundingBoxes that are possible keys -> _selector (extracted from model inputs) values -> ModelBoundingBox model : `~astropy.modeling.Model` The Model this compound bounding_box is for. selector_args : _SelectorArguments A description of how to extract the selectors from model inputs. create_selector : optional A method which takes in the selector and the model to return a valid bounding corresponding to that selector. This can be used to construct new bounding_boxes for previously undefined selectors. These new boxes are then stored for future lookups. order : optional, str The ordering that is assumed for the tuple representation of the bounding_boxes. """ def __init__(self, bounding_boxes: Dict[Any, ModelBoundingBox], model, selector_args: _SelectorArguments, create_selector: Callable = None, order: str = 'C'): super().__init__(model) self._order = order self._create_selector = create_selector self._selector_args = _SelectorArguments.validate(model, selector_args) self._bounding_boxes = {} self._validate(bounding_boxes) def copy(self): bounding_boxes = {selector: bbox.copy(self.selector_args.ignore) for selector, bbox in self._bounding_boxes.items()} return CompoundBoundingBox(bounding_boxes, self._model, selector_args=self._selector_args, create_selector=copy.deepcopy(self._create_selector), order=self._order) def __repr__(self): parts = ['CompoundBoundingBox(', ' bounding_boxes={'] # bounding_boxes for _selector, bbox in self._bounding_boxes.items(): bbox_repr = bbox.__repr__().split('\n') parts.append(f" {_selector} = {bbox_repr.pop(0)}") for part in bbox_repr: parts.append(f" {part}") parts.append(' }') # selector_args selector_args_repr = self.selector_args.pretty_repr(self._model).split('\n') parts.append(f" selector_args = {selector_args_repr.pop(0)}") for part in selector_args_repr: parts.append(f" {part}") parts.append(')') return '\n'.join(parts) @property def bounding_boxes(self) -> Dict[Any, ModelBoundingBox]: return self._bounding_boxes @property def selector_args(self) -> _SelectorArguments: return self._selector_args @selector_args.setter def selector_args(self, value): self._selector_args = _SelectorArguments.validate(self._model, value) warnings.warn("Overriding selector_args may cause problems you should re-validate " "the compound bounding box before use!", RuntimeWarning) @property def named_selector_tuple(self) -> tuple: return self._selector_args.named_tuple(self._model) @property def create_selector(self): return self._create_selector @property def order(self) -> str: return self._order @staticmethod def _get_selector_key(key): if isiterable(key): return tuple(key) else: return (key,) def __setitem__(self, key, value): _selector = self._get_selector_key(key) if not self.selector_args.is_selector(_selector): raise ValueError(f"{_selector} is not a selector!") self._bounding_boxes[_selector] = ModelBoundingBox.validate(self._model, value, self.selector_args.ignore, order=self._order) def _validate(self, bounding_boxes: dict): for _selector, bounding_box in bounding_boxes.items(): self[_selector] = bounding_box def __eq__(self, value): if isinstance(value, CompoundBoundingBox): return (self.bounding_boxes == value.bounding_boxes) and \ (self.selector_args == value.selector_args) and \ (self.create_selector == value.create_selector) else: return False @classmethod def validate(cls, model, bounding_box: dict, selector_args=None, create_selector=None, order: str = 'C'): """ Construct a valid compound bounding box for a model. Parameters ---------- model : `~astropy.modeling.Model` The model for which this will be a bounding_box bounding_box : dict Dictionary of possible bounding_box respresentations selector_args : optional Description of the selector arguments create_selector : optional, callable Method for generating new selectors order : optional, str The order that a tuple representation will be assumed to be Default: 'C' """ if isinstance(bounding_box, CompoundBoundingBox): if selector_args is None: selector_args = bounding_box.selector_args if create_selector is None: create_selector = bounding_box.create_selector order = bounding_box.order bounding_box = bounding_box.bounding_boxes if selector_args is None: raise ValueError("Selector arguments must be provided (can be passed as part of bounding_box argument)!") return cls(bounding_box, model, selector_args, create_selector, order) def __contains__(self, key): return key in self._bounding_boxes def _create_bounding_box(self, _selector): self[_selector] = self._create_selector(_selector, model=self._model) return self[_selector] def __getitem__(self, key): _selector = self._get_selector_key(key) if _selector in self: return self._bounding_boxes[_selector] elif self._create_selector is not None: return self._create_bounding_box(_selector) else: raise RuntimeError(f"No bounding box is defined for selector: {_selector}.") def _select_bounding_box(self, inputs) -> ModelBoundingBox: _selector = self.selector_args.get_selector(*inputs) return self[_selector] def prepare_inputs(self, input_shape, inputs) -> Tuple[Any, Any, Any]: """ Get prepare the inputs with respect to the bounding box. Parameters ---------- input_shape : tuple The shape that all inputs have be reshaped/broadcasted into inputs : list List of all the model inputs Returns ------- valid_inputs : list The inputs reduced to just those inputs which are all inside their respective bounding box intervals valid_index : array_like array of all indices inside the bounding box all_out: bool if all of the inputs are outside the bounding_box """ bounding_box = self._select_bounding_box(inputs) return bounding_box.prepare_inputs(input_shape, inputs) def _matching_bounding_boxes(self, argument, value) -> Dict[Any, ModelBoundingBox]: selector_index = self.selector_args.selector_index(self._model, argument) matching = {} for selector_key, bbox in self._bounding_boxes.items(): if selector_key[selector_index] == value: new_selector_key = list(selector_key) new_selector_key.pop(selector_index) if bbox.has_interval(argument): new_bbox = bbox.fix_inputs(self._model, {argument: value}, _keep_ignored=True) else: new_bbox = bbox.copy() matching[tuple(new_selector_key)] = new_bbox if len(matching) == 0: raise ValueError(f"Attempting to fix input {argument}, but there are no " f"bounding boxes for argument value {value}.") return matching def _fix_input_selector_arg(self, argument, value): matching_bounding_boxes = self._matching_bounding_boxes(argument, value) if len(self.selector_args) == 1: return matching_bounding_boxes[()] else: return CompoundBoundingBox(matching_bounding_boxes, self._model, self.selector_args.reduce(self._model, argument)) def _fix_input_bbox_arg(self, argument, value): bounding_boxes = {} for selector_key, bbox in self._bounding_boxes.items(): bounding_boxes[selector_key] = bbox.fix_inputs(self._model, {argument: value}, _keep_ignored=True) return CompoundBoundingBox(bounding_boxes, self._model, self.selector_args.add_ignore(self._model, argument)) def fix_inputs(self, model, fixed_inputs: dict): """ Fix the bounding_box for a `fix_inputs` compound model. Parameters ---------- model : `~astropy.modeling.Model` The new model for which this will be a bounding_box fixed_inputs : dict Dictionary of inputs which have been fixed by this bounding box. """ fixed_input_keys = list(fixed_inputs.keys()) argument = fixed_input_keys.pop() value = fixed_inputs[argument] if self.selector_args.is_argument(self._model, argument): bbox = self._fix_input_selector_arg(argument, value) else: bbox = self._fix_input_bbox_arg(argument, value) if len(fixed_input_keys) > 0: new_fixed_inputs = fixed_inputs.copy() del new_fixed_inputs[argument] bbox = bbox.fix_inputs(model, new_fixed_inputs) if isinstance(bbox, CompoundBoundingBox): selector_args = bbox.named_selector_tuple bbox_dict = bbox elif isinstance(bbox, ModelBoundingBox): selector_args = None bbox_dict = bbox.named_intervals return bbox.__class__.validate(model, bbox_dict, order=bbox.order, selector_args=selector_args)