# Licensed under a 3-clause BSD style license - see LICENSE.rst

"""
This module tests some of the methods related to YAML serialization.
"""

from io import StringIO

import pytest
import numpy as np

from astropy.coordinates import (SkyCoord, EarthLocation, Angle, Longitude, Latitude,
                                 SphericalRepresentation, UnitSphericalRepresentation,
                                 CartesianRepresentation, SphericalCosLatDifferential,
                                 SphericalDifferential, CartesianDifferential)
from astropy import units as u
from astropy.time import Time
from astropy.table import QTable, SerializedColumn
from astropy.coordinates.tests.test_representation import representation_equal

from astropy.io.misc.yaml import load, load_all, dump  # noqa


# NOTE: ``np.complex128`` is used instead of the ``np.complex_`` alias, which
# was removed in NumPy 2.0; on NumPy 1.x both names refer to the same type.
@pytest.mark.parametrize('c', [True, np.uint8(8), np.int16(4),
                               np.int32(1), np.int64(3), np.int64(2**63 - 1),
                               2.0, np.float64(),
                               3+4j, np.complex128(3 + 4j),
                               np.complex64(3 + 4j),
                               np.complex128(1. - 2**-52 + 1j * (1. - 2**-52))])
def test_numpy_types(c):
    """Check that Python and numpy scalars compare equal after a round trip."""
    cy = load(dump(c))
    assert c == cy


@pytest.mark.parametrize('c', [u.m, u.m / u.s, u.hPa, u.dimensionless_unscaled,
                               u.Unit('m, (cm, um)')])
def test_unit(c):
    """Check that units round-trip through ``dump``/``load``.

    Composite and structured units are only required to compare equal;
    every other unit must come back as the very same object.
    """
    cy = load(dump(c))
    if isinstance(c, (u.CompositeUnit, u.StructuredUnit)):
        assert c == cy
    else:
        assert c is cy


@pytest.mark.parametrize('c', [u.Unit('bakers_dozen', 13*u.one),
                               u.def_unit('magic')])
def test_custom_unit(c):
    """Check loading of a custom unit with and without it being enabled."""
    s = dump(c)
    # Without the unit enabled, loading must warn exactly once and yield an
    # UnrecognizedUnit with the same string form.
    with pytest.warns(u.UnitsWarning, match=f"'{c!s}' did not parse") as w:
        cy = load(s)
    assert len(w) == 1
    assert isinstance(cy, u.UnrecognizedUnit)
    assert str(cy) == str(c)

    # With the unit enabled, the identical unit object must be returned.
    with u.add_enabled_units(c):
        cy2 = load(s)
        assert cy2 is c


@pytest.mark.parametrize('c', [
    Angle('1 2 3', unit='deg'),
    Longitude('1 2 3', unit='deg'),
    Latitude('1 2 3', unit='deg'),
    [[1], [3]] * u.m,
    np.array([[1, 2], [3, 4]], order='F'),
    np.array([[1, 2], [3, 4]], order='C'),
    np.array([1, 2, 3, 4])[::2],
    np.array([(1., 2), (3., 4)], dtype='f8,i4'),  # array with structured dtype.
    np.array((1., 2), dtype='f8,i4'),  # array scalar with structured dtype.
    np.array((1., 2), dtype='f8,i4')[()],  # numpy void.
    np.array((1., 2.), dtype='f8,f8') * u.s,  # Quantity structured scalar.
    [((1., 2., 3.), (4., 5., 6.)),  # Quantity with structured unit.
     ((11., 12., 13.), (14., 15., 16.))] * u.Unit('m, m/s'),
    np.array([((1., 2., 3.), (4., 5., 6.)),
              ((11., 12., 13.), (14., 15., 16.))],
             dtype=[('p', '3f8'), ('v', '3f8')]) * u.Unit('m, m/s')
])
def test_ndarray_subclasses(c):
    """Check values, shape, dtype, type, memory layout and unit after a round trip."""
    cy = load(dump(c))

    assert np.all(c == cy)
    assert c.shape == cy.shape
    assert c.dtype == cy.dtype
    assert type(c) is type(cy)

    cc = 'C_CONTIGUOUS'
    fc = 'F_CONTIGUOUS'
    if c.flags[cc] or c.flags[fc]:
        assert c.flags[cc] == cy.flags[cc]
        assert c.flags[fc] == cy.flags[fc]
    else:
        # Original was not contiguous but round-trip version
        # should be c-contig.
        assert cy.flags[cc]

    if hasattr(c, 'unit'):
        assert c.unit == cy.unit


def compare_coord(c, cy):
    """Assert that coordinate ``c`` and its round-tripped copy ``cy`` agree.

    Compares the shape, the frame name, the frame attribute names and values,
    and the representation component names and values.
    """
    assert c.shape == cy.shape
    assert c.frame.name == cy.frame.name

    assert list(c.get_frame_attr_names()) == list(cy.get_frame_attr_names())
    for attr in c.get_frame_attr_names():
        assert getattr(c, attr) == getattr(cy, attr)

    assert (list(c.representation_component_names) ==
            list(cy.representation_component_names))
    for name in c.representation_component_names:
        # Look up the component by ``name``; reusing the stale ``attr`` from
        # the loop above would leave the components themselves unchecked.
        assert np.all(getattr(c, name) == getattr(cy, name))


@pytest.mark.parametrize('frame', ['fk4', 'altaz'])
def test_skycoord(frame):
    """Check that a 2x2 SkyCoord with obstime and location round-trips."""
    c = SkyCoord([[1, 2], [3, 4]], [[5, 6], [7, 8]],
                 unit='deg', frame=frame,
                 obstime=Time('2016-01-02'),
                 location=EarthLocation(1000, 2000, 3000, unit=u.km))
    cy = load(dump(c))
    compare_coord(c, cy)


@pytest.mark.parametrize('rep', [
    CartesianRepresentation(1*u.m, 2.*u.m, 3.*u.m),
    SphericalRepresentation(
        [[1, 2], [3, 4]]*u.deg, [[5, 6], [7, 8]]*u.deg, 10*u.pc),
    UnitSphericalRepresentation(0*u.deg, 10*u.deg),
    SphericalCosLatDifferential(
        [[1.], [2.]]*u.mas/u.yr, [4., 5.]*u.mas/u.yr,
        [[[10]], [[20]]]*u.km/u.s),
    CartesianDifferential([10, 20, 30]*u.km/u.s),
    CartesianRepresentation(
        [1, 2, 3]*u.m,
        differentials=CartesianDifferential([10, 20, 30]*u.km/u.s)),
    SphericalRepresentation(
        [[1, 2], [3, 4]]*u.deg, [[5, 6], [7, 8]]*u.deg, 10*u.pc,
        differentials={
            's': SphericalDifferential(
                [[0., 1.], [2., 3.]]*u.mas/u.yr,
                [[4., 5.], [6., 7.]]*u.mas/u.yr,
                10*u.km/u.s)})])
def test_representations(rep):
    """Check that representations and differentials round-trip."""
    rrep = load(dump(rep))
    assert np.all(representation_equal(rrep, rep))


def _get_time():
    """Build the 2x1 ``Time`` object shared by the time-related tests.

    The format, precision, UT1-UTC and TDB-TT offsets and output sub-format
    are all set explicitly so the tests can check that each one survives.
    """
    t = Time([[1], [2]], format='cxcsec',
             location=EarthLocation(1000, 2000, 3000, unit=u.km))
    t.format = 'iso'
    t.precision = 5
    t.delta_ut1_utc = np.array([[3.0], [4.0]])
    t.delta_tdb_tt = np.array([[5.0], [6.0]])
    t.out_subfmt = 'date_hm'

    return t


def compare_time(t, ty):
    """Assert that time ``t`` and its round-tripped copy ``ty`` agree.

    Compares the types, the values and a fixed list of attributes.
    """
    assert type(t) is type(ty)
    assert np.all(t == ty)
    for attr in ('shape', 'jd1', 'jd2', 'format', 'scale', 'precision',
                 'in_subfmt', 'out_subfmt', 'location',
                 'delta_ut1_utc', 'delta_tdb_tt'):
        assert np.all(getattr(t, attr) == getattr(ty, attr))


def test_time():
    """Check that the shared ``Time`` object round-trips."""
    t = _get_time()
    ty = load(dump(t))
    compare_time(t, ty)


def test_timedelta():
    """Check that a time difference built from the shared ``Time`` round-trips."""
    t = _get_time()
    dt = t - t + 0.1234556 * u.s
    dty = load(dump(dt))

    assert type(dt) is type(dty)
    for attr in ('shape', 'jd1', 'jd2', 'format', 'scale'):
        assert np.all(getattr(dt, attr) == getattr(dty, attr))


def test_serialized_column():
    """Check that a ``SerializedColumn`` compares equal after a round trip."""
    sc = SerializedColumn({'name': 'hello', 'other': 1, 'other2': 2.0})
    scy = load(dump(sc))

    assert sc == scy


def test_load_all():
    """Check that ``load_all`` reads each document of a multi-document stream."""
    t = _get_time()
    unit = u.m / u.s
    c = SkyCoord([[1, 2], [3, 4]], [[5, 6], [7, 8]],
                 unit='deg', frame='fk4',
                 obstime=Time('2016-01-02'),
                 location=EarthLocation(1000, 2000, 3000, unit=u.km))

    # Make a multi-document stream
    out = ('---\n' + dump(t)
           + '---\n' + dump(unit)
           + '---\n' + dump(c))

    ty, unity, cy = list(load_all(out))

    compare_time(t, ty)
    compare_coord(c, cy)
    assert unity == unit


def test_ecsv_astropy_objects_in_meta():
    """
    Test that astropy core objects in ``meta`` are serialized.
    """
    t = QTable([[1, 2] * u.m, [4, 5]], names=['a', 'b'])
    tm = _get_time()
    c = SkyCoord([[1, 2], [3, 4]], [[5, 6], [7, 8]],
                 unit='deg', frame='fk4',
                 obstime=Time('2016-01-02'),
                 location=EarthLocation(1000, 2000, 3000, unit=u.km))
    unit = u.m / u.s

    t.meta = {'tm': tm, 'c': c, 'unit': unit}
    out = StringIO()
    t.write(out, format='ascii.ecsv')
    t2 = QTable.read(out.getvalue(), format='ascii.ecsv')

    compare_time(tm, t2.meta['tm'])
    compare_coord(c, t2.meta['c'])
    assert t2.meta['unit'] == unit