# Licensed under a 3-clause BSD style license - see LICENSE.rst """ Tests that relate to evaluating models with quantity parameters """ # pylint: disable=invalid-name, no-member import numpy as np import pytest from numpy.testing import assert_allclose from astropy.modeling.core import Model from astropy.modeling.models import Gaussian1D, Shift, Scale, Pix2Sky_TAN from astropy import units as u from astropy.units import UnitsError from astropy.tests.helper import assert_quantity_allclose # We start off by taking some simple cases where the units are defined by # whatever the model is initialized with, and we check that the model evaluation # returns quantities. def test_evaluate_with_quantities(): """ Test evaluation of a single model with Quantity parameters that do not explicitly require units. """ # We create two models here - one with quantities, and one without. The one # without is used to create the reference values for comparison. g = Gaussian1D(1, 1, 0.1) gq = Gaussian1D(1 * u.J, 1 * u.m, 0.1 * u.m) # We first check that calling the Gaussian with quantities returns the # expected result assert_quantity_allclose(gq(1 * u.m), g(1) * u.J) # Units have to be specified for the Gaussian with quantities - if not, an # error is raised with pytest.raises(UnitsError) as exc: gq(1) assert exc.value.args[0] == ("Gaussian1D: Units of input 'x', (dimensionless), could not be " "converted to required input units of m (length)") # However, zero is a special case assert_quantity_allclose(gq(0), g(0) * u.J) # We can also evaluate models with equivalent units assert_allclose(gq(0.0005 * u.km).value, g(0.5)) # But not with incompatible units with pytest.raises(UnitsError) as exc: gq(3 * u.s) assert exc.value.args[0] == ("Gaussian1D: Units of input 'x', s (time), could not be " "converted to required input units of m (length)") # We also can't evaluate the model without quantities with a quantity with pytest.raises(UnitsError) as exc: g(3 * u.m) # TODO: determine what error message should be here # assert exc.value.args[0] == ("Units of input 'x', m (length), could not be " # "converted to required dimensionless input") def test_evaluate_with_quantities_and_equivalencies(): """ We now make sure that equivalencies are correctly taken into account """ g = Gaussian1D(1 * u.Jy, 10 * u.nm, 2 * u.nm) # We aren't setting the equivalencies, so this won't work with pytest.raises(UnitsError) as exc: g(30 * u.PHz) assert exc.value.args[0] == ("Gaussian1D: Units of input 'x', PHz (frequency), could " "not be converted to required input units of " "nm (length)") # But it should now work if we pass equivalencies when evaluating assert_quantity_allclose(g(30 * u.PHz, equivalencies={'x': u.spectral()}), g(9.993081933333332 * u.nm)) class MyTestModel(Model): n_inputs = 2 n_outputs = 1 def evaluate(self, a, b): print('a', a) print('b', b) return a * b class TestInputUnits(): def setup_method(self, method): self.model = MyTestModel() def test_evaluate(self): # We should be able to evaluate with anything assert_quantity_allclose(self.model(3, 5), 15) assert_quantity_allclose(self.model(4 * u.m, 5), 20 * u.m) assert_quantity_allclose(self.model(3 * u.deg, 5), 15 * u.deg) def test_input_units(self): self.model._input_units = {'x': u.deg} assert_quantity_allclose(self.model(3 * u.deg, 4), 12 * u.deg) assert_quantity_allclose(self.model(4 * u.rad, 2), 8 * u.rad) assert_quantity_allclose(self.model(4 * u.rad, 2 * u.s), 8 * u.rad * u.s) with pytest.raises(UnitsError) as exc: self.model(4 * u.s, 3) assert exc.value.args[0] == ("MyTestModel: Units of input 'x', s (time), could not be " "converted to required input units of deg (angle)") with pytest.raises(UnitsError) as exc: self.model(3, 3) assert exc.value.args[0] == ("MyTestModel: Units of input 'x', (dimensionless), could " "not be converted to required input units of deg (angle)") def test_input_units_allow_dimensionless(self): self.model._input_units = {'x': u.deg} self.model._input_units_allow_dimensionless = True assert_quantity_allclose(self.model(3 * u.deg, 4), 12 * u.deg) assert_quantity_allclose(self.model(4 * u.rad, 2), 8 * u.rad) with pytest.raises(UnitsError) as exc: self.model(4 * u.s, 3) assert exc.value.args[0] == ("MyTestModel: Units of input 'x', s (time), could not be " "converted to required input units of deg (angle)") assert_quantity_allclose(self.model(3, 3), 9) def test_input_units_strict(self): self.model._input_units = {'x': u.deg} self.model._input_units_strict = True assert_quantity_allclose(self.model(3 * u.deg, 4), 12 * u.deg) result = self.model(np.pi * u.rad, 2) assert_quantity_allclose(result, 360 * u.deg) assert result.unit is u.deg def test_input_units_equivalencies(self): self.model._input_units = {'x': u.micron} with pytest.raises(UnitsError) as exc: self.model(3 * u.PHz, 3) assert exc.value.args[0] == ("MyTestModel: Units of input 'x', PHz (frequency), could " "not be converted to required input units of " "micron (length)") self.model.input_units_equivalencies = {'x': u.spectral()} assert_quantity_allclose(self.model(3 * u.PHz, 3), 3 * (3 * u.PHz).to(u.micron, equivalencies=u.spectral())) def test_return_units(self): self.model._input_units = {'z': u.deg} self.model._return_units = {'z': u.rad} result = self.model(3 * u.deg, 4) assert_quantity_allclose(result, 12 * u.deg) assert result.unit is u.rad def test_return_units_scalar(self): # Check that return_units also works when giving a single unit since # there is only one output, so is unambiguous. self.model._input_units = {'x': u.deg} self.model._return_units = u.rad result = self.model(3 * u.deg, 4) assert_quantity_allclose(result, 12 * u.deg) assert result.unit is u.rad def test_and_input_units(): """ Test units to first model in chain. """ s1 = Shift(10 * u.deg) s2 = Shift(10 * u.deg) cs = s1 & s2 out = cs(10 * u.arcsecond, 20 * u.arcsecond) assert_quantity_allclose(out[0], 10 * u.deg + 10 * u.arcsec) assert_quantity_allclose(out[1], 10 * u.deg + 20 * u.arcsec) def test_plus_input_units(): """ Test units to first model in chain. """ s1 = Shift(10 * u.deg) s2 = Shift(10 * u.deg) cs = s1 + s2 out = cs(10 * u.arcsecond) assert_quantity_allclose(out, 20 * u.deg + 20 * u.arcsec) def test_compound_input_units(): """ Test units to first model in chain. """ s1 = Shift(10 * u.deg) s2 = Shift(10 * u.deg) cs = s1 | s2 out = cs(10 * u.arcsecond) assert_quantity_allclose(out, 20 * u.deg + 10 * u.arcsec) def test_compound_input_units_fail(): """ Test incompatible units to first model in chain. """ s1 = Shift(10 * u.deg) s2 = Shift(10 * u.deg) cs = s1 | s2 with pytest.raises(UnitsError): cs(10 * u.pix) def test_compound_incompatible_units_fail(): """ Test incompatible model units in chain. """ s1 = Shift(10 * u.pix) s2 = Shift(10 * u.deg) cs = s1 | s2 with pytest.raises(UnitsError): cs(10 * u.pix) def test_compound_pipe_equiv_call(): """ Check that equivalencies work when passed to evaluate, for a chained model (which has one input). """ s1 = Shift(10 * u.deg) s2 = Shift(10 * u.deg) cs = s1 | s2 out = cs(10 * u.pix, equivalencies={'x': u.pixel_scale(0.5 * u.deg / u.pix)}) assert_quantity_allclose(out, 25 * u.deg) def test_compound_and_equiv_call(): """ Check that equivalencies work when passed to evaluate, for a composite model with two inputs. """ s1 = Shift(10 * u.deg) s2 = Shift(10 * u.deg) cs = s1 & s2 out = cs(10 * u.pix, 10 * u.pix, equivalencies={'x0': u.pixel_scale(0.5 * u.deg / u.pix), 'x1': u.pixel_scale(0.5 * u.deg / u.pix)}) assert_quantity_allclose(out[0], 15 * u.deg) assert_quantity_allclose(out[1], 15 * u.deg) def test_compound_input_units_equivalencies(): """ Test setting input_units_equivalencies on one of the models. """ s1 = Shift(10 * u.deg) s1.input_units_equivalencies = {'x': u.pixel_scale(0.5 * u.deg / u.pix)} s2 = Shift(10 * u.deg) sp = Shift(10 * u.pix) cs = s1 | s2 assert cs.input_units_equivalencies == {'x': u.pixel_scale(0.5 * u.deg / u.pix)} out = cs(10 * u.pix) assert_quantity_allclose(out, 25 * u.deg) cs = sp | s1 assert cs.input_units_equivalencies is None out = cs(10 * u.pix) assert_quantity_allclose(out, 20 * u.deg) cs = s1 & s2 assert cs.input_units_equivalencies == {'x0': u.pixel_scale(0.5 * u.deg / u.pix)} cs = cs.rename('TestModel') out = cs(20 * u.pix, 10 * u.deg) assert_quantity_allclose(out, 20 * u.deg) with pytest.raises(UnitsError) as exc: out = cs(20 * u.pix, 10 * u.pix) assert exc.value.args[0] == "Shift: Units of input 'x', pix (unknown), could not be converted to required input units of deg (angle)" def test_compound_input_units_strict(): """ Test setting input_units_strict on one of the models. """ class ScaleDegrees(Scale): input_units = {'x': u.deg} s1 = ScaleDegrees(2) s2 = Scale(2) cs = s1 | s2 out = cs(10 * u.arcsec) assert_quantity_allclose(out, 40 * u.arcsec) assert out.unit is u.deg # important since this tests input_units_strict cs = s2 | s1 out = cs(10 * u.arcsec) assert_quantity_allclose(out, 40 * u.arcsec) assert out.unit is u.deg # important since this tests input_units_strict cs = s1 & s2 out = cs(10 * u.arcsec, 10 * u.arcsec) assert_quantity_allclose(out, 20 * u.arcsec) assert out[0].unit is u.deg assert out[1].unit is u.arcsec def test_compound_input_units_allow_dimensionless(): """ Test setting input_units_allow_dimensionless on one of the models. """ class ScaleDegrees(Scale): input_units = {'x': u.deg} s1 = ScaleDegrees(2) s1._input_units_allow_dimensionless = True s2 = Scale(2) cs = s1 | s2 cs = cs.rename('TestModel') out = cs(10) assert_quantity_allclose(out, 40 * u.one) out = cs(10 * u.arcsec) assert_quantity_allclose(out, 40 * u.arcsec) with pytest.raises(UnitsError) as exc: out = cs(10 * u.m) assert exc.value.args[0] == ("ScaleDegrees: Units of input 'x', m (length), " "could not be converted to required input units of deg (angle)") s1._input_units_allow_dimensionless = False cs = s1 | s2 cs = cs.rename('TestModel') with pytest.raises(UnitsError) as exc: out = cs(10) assert exc.value.args[0] == ("ScaleDegrees: Units of input 'x', (dimensionless), " "could not be converted to required input units of deg (angle)") s1._input_units_allow_dimensionless = True cs = s2 | s1 cs = cs.rename('TestModel') out = cs(10) assert_quantity_allclose(out, 40 * u.one) out = cs(10 * u.arcsec) assert_quantity_allclose(out, 40 * u.arcsec) with pytest.raises(UnitsError) as exc: out = cs(10 * u.m) assert exc.value.args[0] == ("ScaleDegrees: Units of input 'x', m (length), " "could not be converted to required input units of deg (angle)") s1._input_units_allow_dimensionless = False cs = s2 | s1 with pytest.raises(UnitsError) as exc: out = cs(10) assert exc.value.args[0] == ("ScaleDegrees: Units of input 'x', (dimensionless), " "could not be converted to required input units of deg (angle)") s1._input_units_allow_dimensionless = True s1 = ScaleDegrees(2) s1._input_units_allow_dimensionless = True s2 = ScaleDegrees(2) s2._input_units_allow_dimensionless = False cs = s1 & s2 cs = cs.rename('TestModel') out = cs(10, 10 * u.arcsec) assert_quantity_allclose(out[0], 20 * u.one) assert_quantity_allclose(out[1], 20 * u.arcsec) with pytest.raises(UnitsError) as exc: out = cs(10, 10) assert exc.value.args[0] == ("ScaleDegrees: Units of input 'x', (dimensionless), " "could not be converted to required input units of deg (angle)") def test_compound_return_units(): """ Test that return_units on the first model in the chain is respected for the input to the second. """ class PassModel(Model): n_inputs = 2 n_outputs = 2 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @property def input_units(self): """ Input units. """ return {'x0': u.deg, 'x1': u.deg} @property def return_units(self): """ Output units. """ return {'x0': u.deg, 'x1': u.deg} def evaluate(self, x, y): return x.value, y.value cs = Pix2Sky_TAN() | PassModel() assert_quantity_allclose(cs(0*u.deg, 0*u.deg), (0, 90)*u.deg)