# coding: utf-8
# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Test Structured units and quantities.
"""
import copy

import pytest
import numpy as np
from numpy.testing import assert_array_equal

from astropy import units as u
from astropy.units import StructuredUnit, Unit, UnitBase, Quantity
# pickle_protocol is consumed as a pytest fixture by test_pickle's argument.
from astropy.tests.helper import pickle_protocol, check_pickling_recovery  # noqa: F401
from astropy.utils.compat import NUMPY_LT_1_21_1
from astropy.utils.masked import Masked


class StructuredTestBase:
    """Provide structured dtypes, plain structured arrays and per-field units."""

    @classmethod
    def setup_class(cls):
        # A (position, velocity) pair, and that pair nested with a time field.
        cls.pv_dtype = np.dtype([('p', 'f8'), ('v', 'f8')])
        cls.pv_t_dtype = np.dtype([('pv', cls.pv_dtype), ('t', 'f8')])
        cls.p_unit = u.km
        cls.v_unit = u.km / u.s
        cls.t_unit = u.s
        cls.pv = np.array([(1., 0.25), (2., 0.5), (3., 0.75)],
                          cls.pv_dtype)
        cls.pv_t = np.array([((4., 2.5), 0.),
                             ((5., 5.0), 1.),
                             ((6., 7.5), 2.)], cls.pv_t_dtype)


class StructuredTestBaseWithUnits(StructuredTestBase):
    """Add StructuredUnit instances matching the base-class dtypes."""

    @classmethod
    def setup_class(cls):
        super().setup_class()
        cls.pv_unit = StructuredUnit((cls.p_unit, cls.v_unit), ('p', 'v'))
        cls.pv_t_unit = StructuredUnit((cls.pv_unit, cls.t_unit),
                                       ('pv', 't'))


class TestStructuredUnitBasics(StructuredTestBase):
    """Construction, parsing, comparison and representation of StructuredUnit."""

    def test_initialization_and_keying(self):
        su = StructuredUnit((self.p_unit, self.v_unit), ('p', 'v'))
        assert su['p'] is self.p_unit
        assert su['v'] is self.v_unit
        su2 = StructuredUnit((su, self.t_unit), ('pv', 't'))
        assert isinstance(su2['pv'], StructuredUnit)
        assert su2['pv']['p'] is self.p_unit
        assert su2['pv']['v'] is self.v_unit
        assert su2['t'] is self.t_unit
        assert su2['pv'] == su
        su3 = StructuredUnit(('AU', 'AU/day'), ('p', 'v'))
        assert isinstance(su3['p'], UnitBase)
        assert isinstance(su3['v'], UnitBase)
        su4 = StructuredUnit('AU, AU/day', ('p', 'v'))
        assert su4['p'] == u.AU
        assert su4['v'] == u.AU / u.day
        su5 = StructuredUnit(('AU', 'AU/day'))
        assert su5.field_names == ('f0', 'f1')
        assert su5['f0'] == u.AU
        assert su5['f1'] == u.AU / u.day

    def test_recursive_initialization(self):
        su = StructuredUnit(((self.p_unit, self.v_unit), self.t_unit),
                            (('p', 'v'), 't'))
        assert isinstance(su['pv'], StructuredUnit)
        assert su['pv']['p'] is self.p_unit
        assert su['pv']['v'] is self.v_unit
        assert su['t'] is self.t_unit
        su2 = StructuredUnit(((self.p_unit, self.v_unit), self.t_unit),
                             (['p_v', ('p', 'v')], 't'))
        assert isinstance(su2['p_v'], StructuredUnit)
        assert su2['p_v']['p'] is self.p_unit
        assert su2['p_v']['v'] is self.v_unit
        assert su2['t'] is self.t_unit
        su3 = StructuredUnit((('AU', 'AU/day'), 'yr'),
                             (['p_v', ('p', 'v')], 't'))
        assert isinstance(su3['p_v'], StructuredUnit)
        assert su3['p_v']['p'] == u.AU
        assert su3['p_v']['v'] == u.AU / u.day
        assert su3['t'] == u.yr
        su4 = StructuredUnit('(AU, AU/day), yr', (('p', 'v'), 't'))
        assert isinstance(su4['pv'], StructuredUnit)
        assert su4['pv']['p'] == u.AU
        assert su4['pv']['v'] == u.AU / u.day
        assert su4['t'] == u.yr

    def test_extreme_recursive_initialization(self):
        su = StructuredUnit('(yr,(AU,AU/day,(km,(day,day))),m)',
                            ('t', ('p', 'v', ('h', ('d1', 'd2'))), 'l'))
        assert su.field_names == (
            't',
            ['pvhd1d2', ('p', 'v', ['hd1d2', ('h', ['d1d2', ('d1', 'd2')])])],
            'l')

    @pytest.mark.parametrize('names, invalid', [
        [('t', ['p', 'v']), "['p', 'v']"],
        [('t', ['pv', 'p', 'v']), "['pv', 'p', 'v']"],
        # Expected text is a prefix of the full "['pv', ['p', 'v']]" entry.
        [('t', ['pv', ['p', 'v']]), "['pv', ['p', 'v']"],
        [('t', ()), "()"],
        [('t', ('p', None)), "None"],
        [('t', ['pv', ('p', '')]), "''"]])
    def test_initialization_names_invalid_list_errors(self, names, invalid):
        # The unit string is well-formed, so the only possible ValueError
        # source is the invalid ``names`` argument.
        with pytest.raises(ValueError) as exc:
            StructuredUnit('(yr,(AU,AU/day))', names)
        # Check the exception message itself, not the ExceptionInfo repr.
        assert f'invalid entry {invalid}' in str(exc.value)

    def test_looks_like_unit(self):
        su = StructuredUnit((self.p_unit, self.v_unit), ('p', 'v'))
        assert Unit(su) is su

    def test_initialize_with_float_dtype(self):
        su = StructuredUnit(('AU', 'AU/d'), self.pv_dtype)
        assert isinstance(su['p'], UnitBase)
        assert isinstance(su['v'], UnitBase)
        assert su['p'] == u.AU
        assert su['v'] == u.AU / u.day
        su = StructuredUnit((('km', 'km/s'), 'yr'), self.pv_t_dtype)
        assert isinstance(su['pv'], StructuredUnit)
        assert isinstance(su['pv']['p'], UnitBase)
        assert isinstance(su['t'], UnitBase)
        assert su['pv']['v'] == u.km / u.s
        su = StructuredUnit('(km, km/s), yr', self.pv_t_dtype)
        assert isinstance(su['pv'], StructuredUnit)
        assert isinstance(su['pv']['p'], UnitBase)
        assert isinstance(su['t'], UnitBase)
        assert su['pv']['v'] == u.km / u.s

    def test_initialize_with_structured_unit_for_names(self):
        su = StructuredUnit(('AU', 'AU/d'), names=('p', 'v'))
        su2 = StructuredUnit(('km', 'km/s'), names=su)
        assert su2.field_names == ('p', 'v')
        assert su2['p'] == u.km
        assert su2['v'] == u.km / u.s

    def test_initialize_single_field(self):
        su = StructuredUnit('AU', 'p')
        assert isinstance(su, StructuredUnit)
        assert isinstance(su['p'], UnitBase)
        assert su['p'] == u.AU
        su = StructuredUnit('AU')
        assert isinstance(su, StructuredUnit)
        assert isinstance(su['f0'], UnitBase)
        assert su['f0'] == u.AU

    def test_equality(self):
        su = StructuredUnit(('AU', 'AU/d'), self.pv_dtype)
        assert su == StructuredUnit(('AU', 'AU/d'), self.pv_dtype)
        assert su != StructuredUnit(('m', 'AU/d'), self.pv_dtype)
        # Names should be ignored.
        assert su == StructuredUnit(('AU', 'AU/d'))
        assert su == StructuredUnit(('AU', 'AU/d'), names=('q', 'w'))
        assert su != StructuredUnit(('m', 'm/s'))

    def test_parsing(self):
        su = Unit('AU, AU/d')
        assert isinstance(su, StructuredUnit)
        assert isinstance(su['f0'], UnitBase)
        assert isinstance(su['f1'], UnitBase)
        assert su['f0'] == u.AU
        assert su['f1'] == u.AU/u.day
        su2 = Unit('AU, AU/d, yr')
        assert isinstance(su2, StructuredUnit)
        assert su2 == StructuredUnit(('AU', 'AU/d', 'yr'))
        su2a = Unit('(AU, AU/d, yr)')
        assert isinstance(su2a, StructuredUnit)
        assert su2a == su2
        su3 = Unit('(km, km/s), yr')
        assert isinstance(su3, StructuredUnit)
        assert su3 == StructuredUnit((('km', 'km/s'), 'yr'))
        su4 = Unit('km,')
        assert isinstance(su4, StructuredUnit)
        assert su4 == StructuredUnit((u.km,))
        su5 = Unit('(m,s),')
        assert isinstance(su5, StructuredUnit)
        assert su5 == StructuredUnit(((u.m, u.s),))
        ldbody_unit = Unit('Msun, 0.5rad^2, (au, au/day)')
        assert ldbody_unit == StructuredUnit(
            (u.Msun, Unit(u.rad**2 / 2), (u.AU, u.AU / u.day)))

    def test_str(self):
        su = StructuredUnit(((u.km, u.km/u.s), u.yr))
        assert str(su) == '((km, km / s), yr)'
        assert Unit(str(su)) == su

    def test_repr(self):
        su = StructuredUnit(((u.km, u.km/u.s), u.yr))
        assert repr(su) == 'Unit("((km, km / s), yr)")'
        assert eval(repr(su)) == su


class TestStructuredUnitsCopyPickle(StructuredTestBaseWithUnits):
    """Copying and pickling of StructuredUnit."""

    def test_copy(self):
        su_copy = copy.copy(self.pv_t_unit)
        assert su_copy is not self.pv_t_unit
        assert su_copy == self.pv_t_unit
        assert su_copy._units is self.pv_t_unit._units

    def test_deepcopy(self):
        su_copy = copy.deepcopy(self.pv_t_unit)
        assert su_copy is not self.pv_t_unit
        assert su_copy == self.pv_t_unit
        assert su_copy._units is not self.pv_t_unit._units

    @pytest.mark.skipif(NUMPY_LT_1_21_1,
                        reason="https://stackoverflow.com/q/69571643")
    def test_pickle(self, pickle_protocol):
        check_pickling_recovery(self.pv_t_unit, pickle_protocol)


class TestStructuredUnitAsMapping(StructuredTestBaseWithUnits):
    """StructuredUnit behaves like a read-only mapping of field units."""

    def test_len(self):
        assert len(self.pv_unit) == 2
        assert len(self.pv_t_unit) == 2

    def test_keys(self):
        slv = list(self.pv_t_unit.keys())
        assert slv == ['pv', 't']

    def test_values(self):
        values = self.pv_t_unit.values()
        assert values == (self.pv_unit, self.t_unit)

    def test_field_names(self):
        field_names = self.pv_t_unit.field_names
        assert isinstance(field_names, tuple)
        assert field_names == (['pv', ('p', 'v')], 't')

    @pytest.mark.parametrize('iterable', [list, set])
    def test_as_iterable(self, iterable):
        sl = iterable(self.pv_unit)
        assert isinstance(sl, iterable)
        assert sl == iterable(['p', 'v'])

    def test_as_dict(self):
        sd = dict(self.pv_t_unit)
        assert sd == {'pv': self.pv_unit, 't': self.t_unit}

    def test_contains(self):
        assert 'p' in self.pv_unit
        assert 'v' in self.pv_unit
        assert 't' not in self.pv_unit

    def test_setitem_fails(self):
        with pytest.raises(TypeError, match='item assignment'):
            self.pv_t_unit['t'] = u.Gyr


class TestStructuredUnitMethods(StructuredTestBaseWithUnits):
    """Physical types, system conversion, equivalence and value conversion."""

    def test_physical_type_id(self):
        pv_ptid = self.pv_unit._get_physical_type_id()
        assert len(pv_ptid) == 2
        assert pv_ptid.dtype.names == ('p', 'v')
        p_ptid = self.pv_unit['p']._get_physical_type_id()
        v_ptid = self.pv_unit['v']._get_physical_type_id()
        # Expected should be (subclass of) void, with structured object dtype.
        expected = np.array((p_ptid, v_ptid), [('p', 'O'), ('v', 'O')])[()]
        assert pv_ptid == expected
        # Names should be ignored in comparison.
        assert pv_ptid == np.array((p_ptid, v_ptid), 'O,O')[()]
        # Should be possible to address by field and by number.
        assert pv_ptid['p'] == p_ptid
        assert pv_ptid['v'] == v_ptid
        assert pv_ptid[0] == p_ptid
        assert pv_ptid[1] == v_ptid
        # More complicated version.
        pv_t_ptid = self.pv_t_unit._get_physical_type_id()
        t_ptid = self.t_unit._get_physical_type_id()
        assert pv_t_ptid == np.array((pv_ptid, t_ptid), 'O,O')[()]
        assert pv_t_ptid['pv'] == pv_ptid
        assert pv_t_ptid['t'] == t_ptid
        assert pv_t_ptid['pv'][1] == v_ptid

    def test_physical_type(self):
        pv_pt = self.pv_unit.physical_type
        assert pv_pt == np.array(('length', 'speed'), 'O,O')[()]

        pv_t_pt = self.pv_t_unit.physical_type
        assert pv_t_pt == np.array((pv_pt, 'time'), 'O,O')[()]

    def test_si(self):
        pv_t_si = self.pv_t_unit.si
        assert pv_t_si == self.pv_t_unit
        assert pv_t_si['pv']['v'].scale == 1000

    def test_cgs(self):
        pv_t_cgs = self.pv_t_unit.cgs
        assert pv_t_cgs == self.pv_t_unit
        assert pv_t_cgs['pv']['v'].scale == 100000

    def test_decompose(self):
        pv_t_decompose = self.pv_t_unit.decompose()
        assert pv_t_decompose['pv']['v'].scale == 1000

    def test_is_equivalent(self):
        assert self.pv_unit.is_equivalent(('AU', 'AU/day'))
        assert not self.pv_unit.is_equivalent('m')
        assert not self.pv_unit.is_equivalent(('AU', 'AU'))
        # Names should be ignored.
        pv_alt = StructuredUnit('m,m/s', names=('q', 'w'))
        assert pv_alt.field_names != self.pv_unit.field_names
        assert self.pv_unit.is_equivalent(pv_alt)
        # Regular units should work too.
        assert not u.m.is_equivalent(self.pv_unit)

    def test_conversion(self):
        pv1 = self.pv_unit.to(('AU', 'AU/day'), self.pv)
        assert isinstance(pv1, np.ndarray)
        assert pv1.dtype == self.pv.dtype
        assert np.all(pv1['p'] * u.AU == self.pv['p'] * self.p_unit)
        assert np.all(pv1['v'] * u.AU / u.day == self.pv['v'] * self.v_unit)
        # Names should be from value.
        su2 = StructuredUnit((self.p_unit, self.v_unit),
                             ('position', 'velocity'))
        pv2 = su2.to(('Mm', 'mm/s'), self.pv)
        assert pv2.dtype.names == ('p', 'v')
        assert pv2.dtype == self.pv.dtype
        # Check recursion.
        pv_t1 = self.pv_t_unit.to((('AU', 'AU/day'), 'Myr'), self.pv_t)
        assert isinstance(pv_t1, np.ndarray)
        assert pv_t1.dtype == self.pv_t.dtype
        assert np.all(pv_t1['pv']['p'] * u.AU
                      == self.pv_t['pv']['p'] * self.p_unit)
        assert np.all(pv_t1['pv']['v'] * u.AU / u.day
                      == self.pv_t['pv']['v'] * self.v_unit)
        assert np.all(pv_t1['t'] * u.Myr == self.pv_t['t'] * self.t_unit)
        # Passing in tuples should work.
        pv_t2 = self.pv_t_unit.to((('AU', 'AU/day'), 'Myr'),
                                  ((1., 0.1), 10.))
        assert pv_t2['pv']['p'] == self.p_unit.to('AU', 1.)
        assert pv_t2['pv']['v'] == self.v_unit.to('AU/day', 0.1)
        assert pv_t2['t'] == self.t_unit.to('Myr', 10.)
        pv_t3 = self.pv_t_unit.to((('AU', 'AU/day'), 'Myr'),
                                  [((1., 0.1), 10.),
                                   ((2., 0.2), 20.)])
        assert np.all(pv_t3['pv']['p'] == self.p_unit.to('AU', [1., 2.]))
        assert np.all(pv_t3['pv']['v'] == self.v_unit.to('AU/day', [0.1, 0.2]))
        assert np.all(pv_t3['t'] == self.t_unit.to('Myr', [10., 20.]))


class TestStructuredUnitArithmatic(StructuredTestBaseWithUnits):
    """Multiplication and division of StructuredUnit by regular units."""

    def test_multiplication(self):
        pv_times_au = self.pv_unit * u.au
        assert isinstance(pv_times_au, StructuredUnit)
        assert pv_times_au.field_names == ('p', 'v')
        assert pv_times_au['p'] == self.p_unit * u.AU
        assert pv_times_au['v'] == self.v_unit * u.AU
        au_times_pv = u.au * self.pv_unit
        assert au_times_pv == pv_times_au
        pv_times_au2 = self.pv_unit * 'au'
        assert pv_times_au2 == pv_times_au
        au_times_pv2 = 'AU' * self.pv_unit
        assert au_times_pv2 == pv_times_au
        with pytest.raises(TypeError):
            self.pv_unit * self.pv_unit
        with pytest.raises(TypeError):
            's,s' * self.pv_unit

    def test_division(self):
        pv_by_s = self.pv_unit / u.s
        assert isinstance(pv_by_s, StructuredUnit)
        assert pv_by_s.field_names == ('p', 'v')
        assert pv_by_s['p'] == self.p_unit / u.s
        assert pv_by_s['v'] == self.v_unit / u.s
        pv_by_s2 = self.pv_unit / 's'
        assert pv_by_s2 == pv_by_s
        with pytest.raises(TypeError):
            1. / self.pv_unit
        with pytest.raises(TypeError):
            u.s / self.pv_unit


class TestStructuredQuantity(StructuredTestBaseWithUnits):
    """Creation, indexing, conversion, comparison and assignment."""

    def test_initialization_and_keying(self):
        q_pv = Quantity(self.pv, self.pv_unit)
        q_p = q_pv['p']
        assert isinstance(q_p, Quantity)
        assert isinstance(q_p.unit, UnitBase)
        assert np.all(q_p == self.pv['p'] * self.pv_unit['p'])
        q_v = q_pv['v']
        assert isinstance(q_v, Quantity)
        assert isinstance(q_v.unit, UnitBase)
        assert np.all(q_v == self.pv['v'] * self.pv_unit['v'])
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q_t = q_pv_t['t']
        assert np.all(q_t == self.pv_t['t'] * self.pv_t_unit['t'])
        q_pv2 = q_pv_t['pv']
        assert isinstance(q_pv2, Quantity)
        assert q_pv2.unit == self.pv_unit
        with pytest.raises(ValueError):
            Quantity(self.pv, self.pv_t_unit)
        with pytest.raises(ValueError):
            Quantity(self.pv_t, self.pv_unit)

    def test_initialization_with_unit_tuples(self):
        q_pv_t = Quantity(self.pv_t, (('km', 'km/s'), 's'))
        assert isinstance(q_pv_t.unit, StructuredUnit)
        assert q_pv_t.unit == self.pv_t_unit

    def test_initialization_with_string(self):
        q_pv_t = Quantity(self.pv_t, '(km, km/s), s')
        assert isinstance(q_pv_t.unit, StructuredUnit)
        assert q_pv_t.unit == self.pv_t_unit

    def test_initialization_by_multiplication_with_unit(self):
        q_pv_t = self.pv_t * self.pv_t_unit
        assert q_pv_t.unit is self.pv_t_unit
        assert np.all(q_pv_t.value == self.pv_t)
        assert not np.may_share_memory(q_pv_t, self.pv_t)
        q_pv_t2 = self.pv_t_unit * self.pv_t
        # Check the reflected product itself (previously re-checked q_pv_t).
        assert q_pv_t2.unit is self.pv_t_unit
        # Not testing equality of structured Quantity here.
        assert np.all(q_pv_t2.value == q_pv_t.value)

    def test_initialization_by_shifting_to_unit(self):
        q_pv_t = self.pv_t << self.pv_t_unit
        assert q_pv_t.unit is self.pv_t_unit
        assert np.all(q_pv_t.value == self.pv_t)
        assert np.may_share_memory(q_pv_t, self.pv_t)

    def test_getitem(self):
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q_pv_t01 = q_pv_t[:2]
        assert isinstance(q_pv_t01, Quantity)
        assert q_pv_t01.unit == q_pv_t.unit
        assert np.all(q_pv_t01['t'] == q_pv_t['t'][:2])
        q_pv_t1 = q_pv_t[1]
        assert isinstance(q_pv_t1, Quantity)
        assert q_pv_t1.unit == q_pv_t.unit
        assert q_pv_t1.shape == ()
        assert q_pv_t1['t'] == q_pv_t['t'][1]

    def test_value(self):
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        value = q_pv_t.value
        assert type(value) is np.ndarray
        assert np.all(value == self.pv_t)
        value1 = q_pv_t[1].value
        assert type(value1) is np.void
        assert np.all(value1 == self.pv_t[1])

    def test_conversion(self):
        q_pv = Quantity(self.pv, self.pv_unit)
        q1 = q_pv.to(('AU', 'AU/day'))
        assert isinstance(q1, Quantity)
        assert q1['p'].unit == u.AU
        assert q1['v'].unit == u.AU / u.day
        assert np.all(q1['p'] == q_pv['p'].to(u.AU))
        assert np.all(q1['v'] == q_pv['v'].to(u.AU/u.day))
        q2 = q_pv.to(self.pv_unit)
        assert q2['p'].unit == self.p_unit
        assert q2['v'].unit == self.v_unit
        assert np.all(q2['p'].value == self.pv['p'])
        assert np.all(q2['v'].value == self.pv['v'])
        assert not np.may_share_memory(q2, q_pv)
        pv1 = q_pv.to_value(('AU', 'AU/day'))
        assert type(pv1) is np.ndarray
        assert np.all(pv1['p'] == q_pv['p'].to_value(u.AU))
        assert np.all(pv1['v'] == q_pv['v'].to_value(u.AU/u.day))
        pv11 = q_pv[1].to_value(('AU', 'AU/day'))
        assert type(pv11) is np.void
        assert pv11 == pv1[1]
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q2 = q_pv_t.to((('kpc', 'kpc/Myr'), 'Myr'))
        assert q2['pv']['p'].unit == u.kpc
        assert q2['pv']['v'].unit == u.kpc / u.Myr
        assert q2['t'].unit == u.Myr
        assert np.all(q2['pv']['p'] == q_pv_t['pv']['p'].to(u.kpc))
        assert np.all(q2['pv']['v'] == q_pv_t['pv']['v'].to(u.kpc/u.Myr))
        assert np.all(q2['t'] == q_pv_t['t'].to(u.Myr))

    def test_conversion_via_lshift(self):
        q_pv = Quantity(self.pv, self.pv_unit)
        q1 = q_pv << StructuredUnit(('AU', 'AU/day'))
        assert isinstance(q1, Quantity)
        assert q1['p'].unit == u.AU
        assert q1['v'].unit == u.AU / u.day
        assert np.all(q1['p'] == q_pv['p'].to(u.AU))
        assert np.all(q1['v'] == q_pv['v'].to(u.AU/u.day))
        q2 = q_pv << self.pv_unit
        assert q2['p'].unit == self.p_unit
        assert q2['v'].unit == self.v_unit
        assert np.all(q2['p'].value == self.pv['p'])
        assert np.all(q2['v'].value == self.pv['v'])
        assert np.may_share_memory(q2, q_pv)
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q2 = q_pv_t << '(kpc,kpc/Myr),Myr'
        assert q2['pv']['p'].unit == u.kpc
        assert q2['pv']['v'].unit == u.kpc / u.Myr
        assert q2['t'].unit == u.Myr
        assert np.all(q2['pv']['p'] == q_pv_t['pv']['p'].to(u.kpc))
        assert np.all(q2['pv']['v'] == q_pv_t['pv']['v'].to(u.kpc/u.Myr))
        assert np.all(q2['t'] == q_pv_t['t'].to(u.Myr))

    def test_inplace_conversion(self):
        q_pv = Quantity(self.pv, self.pv_unit)
        q1 = q_pv.copy()
        q_link = q1
        q1 <<= StructuredUnit(('AU', 'AU/day'))
        assert q1 is q_link
        assert q1['p'].unit == u.AU
        assert q1['v'].unit == u.AU / u.day
        assert np.all(q1['p'] == q_pv['p'].to(u.AU))
        assert np.all(q1['v'] == q_pv['v'].to(u.AU/u.day))
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q2 = q_pv_t.copy()
        q_link = q2
        q2 <<= '(kpc,kpc/Myr),Myr'
        assert q2 is q_link
        assert q2['pv']['p'].unit == u.kpc
        assert q2['pv']['v'].unit == u.kpc / u.Myr
        assert q2['t'].unit == u.Myr
        assert np.all(q2['pv']['p'] == q_pv_t['pv']['p'].to(u.kpc))
        assert np.all(q2['pv']['v'] == q_pv_t['pv']['v'].to(u.kpc/u.Myr))
        assert np.all(q2['t'] == q_pv_t['t'].to(u.Myr))

    def test_si(self):
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q_pv_t_si = q_pv_t.si
        assert_array_equal(q_pv_t_si, q_pv_t.to('(m,m/s),s'))

    def test_cgs(self):
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q_pv_t_cgs = q_pv_t.cgs
        assert_array_equal(q_pv_t_cgs, q_pv_t.to('(cm,cm/s),s'))

    def test_equality(self):
        q_pv = Quantity(self.pv, self.pv_unit)
        equal = q_pv == q_pv
        not_equal = q_pv != q_pv
        assert np.all(equal)
        assert not np.any(not_equal)
        equal2 = q_pv == q_pv[1]
        not_equal2 = q_pv != q_pv[1]
        assert np.all(equal2 == [False, True, False])
        assert np.all(not_equal2 != equal2)
        q1 = q_pv.to(('AU', 'AU/day'))
        # Ensure same conversion is done, by placing q1 first.
        assert np.all(q1 == q_pv)
        assert not np.any(q1 != q_pv)
        # Check different names in dtype.
        assert np.all(q1.value * u.Unit('AU, AU/day') == q_pv)
        assert not np.any(q1.value * u.Unit('AU, AU/day') != q_pv)
        assert (q_pv == 'b') is False
        assert ('b' != q_pv) is True
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        assert np.all((q_pv_t[2] == q_pv_t) == [False, False, True])
        assert np.all((q_pv_t[2] != q_pv_t) != [False, False, True])
        assert (q_pv == q_pv_t) is False
        assert (q_pv_t != q_pv) is True

    def test_setitem(self):
        q_pv = Quantity(self.pv, self.pv_unit)
        q_pv[1] = (2., 2.) * self.pv_unit
        assert q_pv[1].value == np.array((2., 2.), self.pv_dtype)
        q_pv[1:2] = (1., 0.5) * u.Unit('AU, AU/day')
        assert q_pv['p'][1] == 1. * u.AU
        assert q_pv['v'][1] == 0.5 * u.AU / u.day
        q_pv['v'] = 1. * u.km / u.s
        assert np.all(q_pv['v'] == 1. * u.km / u.s)
        with pytest.raises(u.UnitsError):
            q_pv[1] = (1., 1.) * u.Unit('AU, AU')
        with pytest.raises(u.UnitsError):
            q_pv['v'] = 1. * u.km
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q_pv_t[1] = ((2., 2.), 3.) * self.pv_t_unit
        assert q_pv_t[1].value == np.array(((2., 2.), 3.), self.pv_t_dtype)
        q_pv_t[1:2] = ((1., 0.5), 5.) * u.Unit('(AU, AU/day), yr')
        assert q_pv_t['pv'][1] == (1., 0.5) * u.Unit('AU, AU/day')
        assert q_pv_t['t'][1] == 5. * u.yr
        q_pv_t['pv'] = (1., 0.5) * self.pv_unit
        assert np.all(q_pv_t['pv'] == (1., 0.5) * self.pv_unit)


class TestStructuredQuantityFunctions(StructuredTestBaseWithUnits):
    """NumPy array-creation functions applied to structured quantities."""

    @classmethod
    def setup_class(cls):
        super().setup_class()
        cls.q_pv = cls.pv << cls.pv_unit
        cls.q_pv_t = cls.pv_t << cls.pv_t_unit

    def test_empty_like(self):
        z = np.empty_like(self.q_pv)
        assert z.dtype == self.pv_dtype
        assert z.unit == self.pv_unit
        assert z.shape == self.pv.shape

    @pytest.mark.parametrize('func', [np.zeros_like, np.ones_like])
    def test_zeros_ones_like(self, func):
        z = func(self.q_pv)
        assert z.dtype == self.pv_dtype
        assert z.unit == self.pv_unit
        assert z.shape == self.pv.shape
        assert_array_equal(z, func(self.pv) << self.pv_unit)


class TestStructuredSpecificTypeQuantity(StructuredTestBaseWithUnits):
    """A SpecificTypeQuantity subclass whose equivalent unit is structured."""

    @classmethod
    def setup_class(cls):
        super().setup_class()

        class PositionVelocity(u.SpecificTypeQuantity):
            _equivalent_unit = cls.pv_unit

        cls.PositionVelocity = PositionVelocity

    def test_init(self):
        pv = self.PositionVelocity(self.pv, self.pv_unit)
        assert isinstance(pv, self.PositionVelocity)
        assert type(pv['p']) is u.Quantity
        assert_array_equal(pv['p'], self.pv['p'] << self.pv_unit['p'])

        pv2 = self.PositionVelocity(self.pv, 'AU,AU/day')
        assert_array_equal(pv2['p'], self.pv['p'] << u.AU)

    def test_error_on_non_equivalent_unit(self):
        with pytest.raises(u.UnitsError):
            self.PositionVelocity(self.pv, 'AU')
        with pytest.raises(u.UnitsError):
            self.PositionVelocity(self.pv, 'AU,yr')


class TestStructuredLogUnit:
    """Structured units with a logarithmic (magnitude) field."""

    @classmethod
    def setup_class(cls):
        cls.mag_time_dtype = np.dtype([('mag', 'f8'), ('t', 'f8')])
        cls.mag_time = np.array([(20., 10.), (25., 100.)],
                                cls.mag_time_dtype)

    def test_unit_initialization(self):
        mag_time_unit = StructuredUnit((u.STmag, u.s), self.mag_time_dtype)
        assert mag_time_unit['mag'] == u.STmag
        assert mag_time_unit['t'] == u.s

        mag_time_unit2 = u.Unit('mag(ST),s')
        assert mag_time_unit2 == mag_time_unit

    def test_quantity_initialization(self):
        su = u.Unit('mag(ST),s')
        mag_time = self.mag_time << su
        assert isinstance(mag_time['mag'], u.Magnitude)
        assert isinstance(mag_time['t'], u.Quantity)
        assert mag_time.unit == su
        assert_array_equal(mag_time['mag'], self.mag_time['mag'] << u.STmag)
        assert_array_equal(mag_time['t'], self.mag_time['t'] << u.s)

    def test_quantity_si(self):
        mag_time = self.mag_time << u.Unit('mag(ST),yr')
        mag_time_si = mag_time.si
        assert_array_equal(mag_time_si['mag'], mag_time['mag'].si)
        assert_array_equal(mag_time_si['t'], mag_time['t'].si)


class TestStructuredMaskedQuantity(StructuredTestBaseWithUnits):
    """Somewhat minimal tests.  Conversion is most stringent."""

    @classmethod
    def setup_class(cls):
        super().setup_class()
        cls.qpv = cls.pv << cls.pv_unit
        cls.pv_mask = np.array([(True, False),
                                (False, False),
                                (False, True)], [('p', bool), ('v', bool)])
        cls.mpv = Masked(cls.qpv, mask=cls.pv_mask)

    def test_init(self):
        assert isinstance(self.mpv, Masked)
        assert isinstance(self.mpv, Quantity)
        assert_array_equal(self.mpv.unmasked, self.qpv)
        assert_array_equal(self.mpv.mask, self.pv_mask)

    def test_slicing(self):
        mp = self.mpv['p']
        assert isinstance(mp, Masked)
        assert isinstance(mp, Quantity)
        assert_array_equal(mp.unmasked, self.qpv['p'])
        assert_array_equal(mp.mask, self.pv_mask['p'])

    def test_conversion(self):
        mpv = self.mpv.to('AU,AU/day')
        assert isinstance(mpv, Masked)
        assert isinstance(mpv, Quantity)
        assert_array_equal(mpv.unmasked, self.qpv.to('AU,AU/day'))
        assert_array_equal(mpv.mask, self.pv_mask)
        assert np.all(mpv == self.mpv)

    def test_si(self):
        mpv = self.mpv.si
        assert isinstance(mpv, Masked)
        assert isinstance(mpv, Quantity)
        assert_array_equal(mpv.unmasked, self.qpv.si)
        assert_array_equal(mpv.mask, self.pv_mask)
        assert np.all(mpv == self.mpv)