# Functions/classes for WCSAxes related to APE14 WCSes

import numpy as np

from astropy.coordinates import SkyCoord, ICRS, BaseCoordinateFrame
from astropy import units as u
from astropy.wcs import WCS
from astropy.wcs.utils import local_partial_pixel_derivatives
from astropy.wcs.wcsapi import SlicedLowLevelWCS

from .frame import RectangularFrame, EllipticalFrame, RectangularFrame1D
from .transforms import CurvedTransform

__all__ = ['transform_coord_meta_from_wcs', 'WCSWorld2PixelTransform',
           'WCSPixel2WorldTransform']

# Two-dimensional FITS WCS with unit scale and zero reference value on both
# axes (CRPIX = 1, CRVAL = 0, CDELT = 1).
IDENTITY = WCS(naxis=2)
IDENTITY.wcs.ctype = ["X", "Y"]
IDENTITY.wcs.crval = [0., 0.]
IDENTITY.wcs.crpix = [1., 1.]
IDENTITY.wcs.cdelt = [1., 1.]


def _set_default_positions(coord_meta, index, position):
    """Set the default axis-label, tick-label and tick position of one world axis."""
    coord_meta['default_axislabel_position'][index] = position
    coord_meta['default_ticklabel_position'][index] = position
    coord_meta['default_ticks_position'][index] = position


def transform_coord_meta_from_wcs(wcs, frame_class, slices=None):
    """
    Build the pixel-to-world transform and the coordinate metadata for a WCS.

    Parameters
    ----------
    wcs : APE14 low-level WCS
        The WCS to describe. Plain `~astropy.wcs.WCS` objects, and
        `~astropy.wcs.wcsapi.SlicedLowLevelWCS` objects wrapping one, also
        get CTYPE-based name aliases.
    frame_class : type
        The frame class in use. `RectangularFrame`, `RectangularFrame1D` and
        `EllipticalFrame` get dedicated default positions; for any other
        class ``frame_class.spine_names`` is used for every visible axis.
    slices : iterable, optional
        One entry per pixel dimension, forwarded to `apply_slices`. Required
        when the WCS has more than two pixel dimensions.

    Returns
    -------
    transform : `WCSPixel2WorldTransform`
        Transform built on the (possibly sliced) WCS.
    coord_meta : dict
        Per-world-axis lists stored under the keys ``name``, ``type``,
        ``wrap``, ``unit``, ``visible``, ``format_unit``,
        ``default_axislabel_position``, ``default_ticklabel_position``,
        ``default_ticks_position`` and ``default_axis_label``.

    Raises
    ------
    ValueError
        If the WCS has more than two pixel dimensions and ``slices`` is
        missing or does not have one element per pixel dimension.
    """
    if slices is not None:
        slices = tuple(slices)

    if wcs.pixel_n_dim > 2:
        if slices is None:
            raise ValueError("WCS has more than 2 pixel dimensions, so "
                             "'slices' should be set")
        elif len(slices) != wcs.pixel_n_dim:
            raise ValueError("'slices' should have as many elements as WCS "
                             "has pixel dimensions (should be {})"
                             .format(wcs.pixel_n_dim))

    is_fits_wcs = isinstance(wcs, WCS) or (isinstance(wcs, SlicedLowLevelWCS)
                                           and isinstance(wcs._wcs, WCS))

    # Read the names once; every use below tolerates an empty/None value.
    axis_names = wcs.world_axis_names

    coord_meta = {}
    coord_meta['name'] = []
    coord_meta['type'] = []
    coord_meta['wrap'] = []
    coord_meta['unit'] = []
    coord_meta['visible'] = []
    coord_meta['format_unit'] = []

    for idx in range(wcs.world_n_dim):

        axis_type = wcs.world_axis_physical_types[idx]
        axis_unit = u.Unit(wcs.world_axis_units[idx])
        coord_wrap = None
        format_unit = axis_unit

        coord_type = 'scalar'

        if axis_type is not None:

            axis_type_split = axis_type.split('.')

            if "pos.helioprojective.lon" in axis_type:
                coord_wrap = 180.
                format_unit = u.arcsec
                coord_type = "longitude"
            elif "pos.helioprojective.lat" in axis_type:
                format_unit = u.arcsec
                coord_type = "latitude"
            elif "pos.heliographic.stonyhurst.lon" in axis_type:
                coord_wrap = 180.
                format_unit = u.deg
                coord_type = "longitude"
            elif "pos.heliographic.stonyhurst.lat" in axis_type:
                format_unit = u.deg
                coord_type = "latitude"
            elif "pos.heliographic.carrington.lon" in axis_type:
                coord_wrap = 360.
                format_unit = u.deg
                coord_type = "longitude"
            elif "pos.heliographic.carrington.lat" in axis_type:
                format_unit = u.deg
                coord_type = "latitude"
            elif "pos" in axis_type_split:
                if "lon" in axis_type_split:
                    coord_type = "longitude"
                elif "lat" in axis_type_split:
                    coord_type = "latitude"
                elif "ra" in axis_type_split:
                    coord_type = "longitude"
                    format_unit = u.hourangle
                elif "dec" in axis_type_split:
                    coord_type = "latitude"
                elif "alt" in axis_type_split:
                    # Altitude is the latitude-like angle. It has to be
                    # tested before "az" because a type such as
                    # "pos.az.alt" contains both tokens.
                    coord_type = "latitude"
                elif "az" in axis_type_split:
                    # Azimuth is the longitude-like angle.
                    coord_type = "longitude"
                elif "long" in axis_type_split:
                    coord_type = "longitude"

        coord_meta['type'].append(coord_type)
        coord_meta['wrap'].append(coord_wrap)
        coord_meta['format_unit'].append(format_unit)
        coord_meta['unit'].append(axis_unit)

        # For FITS-WCS, for backward-compatibility, we need to make sure that we
        # provide aliases based on CTYPE for the name.
        if is_fits_wcs:
            name = []
            if isinstance(wcs, WCS):
                ctype = wcs.wcs.ctype[idx]
                name.append(ctype.lower())
                name.append(ctype[:4].replace('-', '').lower())
            elif isinstance(wcs, SlicedLowLevelWCS):
                ctype = wcs._wcs.wcs.ctype[wcs._world_keep[idx]]
                name.append(ctype.lower())
                name.append(ctype[:4].replace('-', '').lower())
            if name[0] == name[1]:
                name = name[0:1]
            if axis_type:
                if axis_type not in name:
                    name.insert(0, axis_type)
            if axis_names and axis_names[idx]:
                if axis_names[idx] not in name:
                    name.append(axis_names[idx])
            name = tuple(name) if len(name) > 1 else name[0]
        else:
            name = axis_type or ''
            if axis_names:
                name = (name, axis_names[idx]) if axis_names[idx] else name

        coord_meta['name'].append(name)

    coord_meta['default_axislabel_position'] = [''] * wcs.world_n_dim
    coord_meta['default_ticklabel_position'] = [''] * wcs.world_n_dim
    coord_meta['default_ticks_position'] = [''] * wcs.world_n_dim

    # If the world axis has a name use it, else display the world axis
    # physical type (the first alias when several are available).
    fallback_labels = [name[0] if isinstance(name, (list, tuple)) else name
                       for name in coord_meta['name']]
    coord_meta['default_axis_label'] = [
        (axis_names[i] if axis_names else '') or fallback_label
        for i, fallback_label in enumerate(fallback_labels)
    ]

    transform_wcs, invert_xy, world_map = apply_slices(wcs, slices)

    transform = WCSPixel2WorldTransform(transform_wcs, invert_xy=invert_xy)

    # A world axis is visible only if it survives the slicing.
    for i in range(len(coord_meta['type'])):
        coord_meta['visible'].append(i in world_map)

    inv_all_corr = [False] * wcs.world_n_dim
    m = transform_wcs.axis_correlation_matrix.copy()
    if invert_xy:
        inv_all_corr = np.all(m, axis=1)
        m = m[:, ::-1]

    if frame_class is RectangularFrame:

        for i, spine_name in enumerate('bltr'):
            pos = np.nonzero(m[:, i % 2])[0]
            # If all the axes we have are correlated with each other and we
            # have inverted the axes, then we need to reverse the index so we
            # put the 'y' on the left.
            if inv_all_corr[i % 2]:
                pos = pos[::-1]

            if len(pos) > 0:
                index = world_map[pos[0]]
                _set_default_positions(coord_meta, index, spine_name)
                # Clear the row so this world axis is not assigned to
                # another spine.
                m[pos[0], :] = 0

        # In the special and common case where the frame is rectangular and
        # we are dealing with 2-d WCS (after slicing), we show all ticks on
        # all axes for backward-compatibility.
        if len(world_map) == 2:
            for index in world_map:
                coord_meta['default_ticks_position'][index] = 'bltr'

    elif frame_class is RectangularFrame1D:

        derivs = np.abs(local_partial_pixel_derivatives(
            transform_wcs, *[0]*transform_wcs.pixel_n_dim,
            normalize_by_world=False))[:, 0]

        for i, spine_name in enumerate('bt'):
            # Here we are iterating over the correlated axes in world axis order.
            # We want to sort the correlated axes by their partial derivatives,
            # so we put the most rapidly changing world axis on the bottom.
            pos = np.nonzero(m[:, 0])[0]
            order = np.argsort(derivs[pos])[::-1]  # Sort largest to smallest
            pos = pos[order]
            if len(pos) > 0:
                index = world_map[pos[0]]
                _set_default_positions(coord_meta, index, spine_name)
                # Clear the row so this world axis is not assigned to
                # another spine.
                m[pos[0], :] = 0

        # In the special and common case where the frame is a 1-d rectangular
        # frame and we are dealing with a 1-d WCS (after slicing), we show
        # ticks on both spines for backward-compatibility.
        if len(world_map) == 1:
            for index in world_map:
                coord_meta['default_ticks_position'][index] = 'bt'

    elif frame_class is EllipticalFrame:

        # Only the first longitude and the first latitude axis are positioned.
        if 'longitude' in coord_meta['type']:
            lon_idx = coord_meta['type'].index('longitude')
            _set_default_positions(coord_meta, lon_idx, 'h')

        if 'latitude' in coord_meta['type']:
            lat_idx = coord_meta['type'].index('latitude')
            _set_default_positions(coord_meta, lat_idx, 'c')

    else:

        for index in range(len(coord_meta['type'])):
            if index in world_map:
                _set_default_positions(coord_meta, index,
                                       frame_class.spine_names)

    return transform, coord_meta


def apply_slices(wcs, slices):
    """
    Take the input WCS and slices and return a sliced WCS for the transform and
    a mapping of world axes in the sliced WCS to the input WCS.

    Parameters
    ----------
    wcs : APE14 low-level WCS
        The WCS to slice.
    slices : sequence or None
        If `None`, the WCS is returned unsliced. Otherwise the sequence must
        contain ``'x'`` and may contain ``'y'``; both are replaced by
        ``slice(None)`` and the sequence is reversed before being handed to
        `~astropy.wcs.wcsapi.SlicedLowLevelWCS`. The remaining entries are
        passed through unchanged.

    Returns
    -------
    transform_wcs : low-level WCS
        ``wcs`` itself, or a `~astropy.wcs.wcsapi.SlicedLowLevelWCS` around it.
    invert_xy : bool
        `True` if ``'x'`` comes after ``'y'`` in ``slices``.
    world_map : list or tuple of int
        For each world axis of ``transform_wcs``, its index in ``wcs``.

    Raises
    ------
    ValueError
        If ``slices`` is given but does not contain ``'x'``.
    """
    if isinstance(wcs, SlicedLowLevelWCS):
        world_keep = list(wcs._world_keep)
    else:
        world_keep = list(range(wcs.world_n_dim))

    # world_map is the index of the world axis in the input WCS for a given
    # axis in the transform_wcs
    world_map = list(range(wcs.world_n_dim))
    transform_wcs = wcs
    invert_xy = False

    if slices is not None:
        wcs_slice = list(slices)
        if "x" not in wcs_slice:
            raise ValueError("'slices' should contain 'x'")
        wcs_slice[wcs_slice.index("x")] = slice(None)
        if 'y' in slices:
            wcs_slice[wcs_slice.index("y")] = slice(None)
            invert_xy = slices.index('x') > slices.index('y')

        transform_wcs = SlicedLowLevelWCS(wcs, wcs_slice[::-1])
        world_map = tuple(world_keep.index(i) for i in transform_wcs._world_keep)

    return transform_wcs, invert_xy, world_map


def wcsapi_to_celestial_frame(wcs):
    """
    Return the celestial frame described by the world axis object classes.

    The first entry of ``wcs.world_axis_object_classes`` whose class is a
    `~astropy.coordinates.SkyCoord` or a
    `~astropy.coordinates.BaseCoordinateFrame` subclass determines the result.
    For a `~astropy.coordinates.SkyCoord` the ``frame`` keyword is returned
    (defaulting to `~astropy.coordinates.ICRS`); for a frame class, the class
    is instantiated with the stored keyword arguments.

    Returns `None` if no such entry exists.
    """
    for cls, _, kwargs, *_ in wcs.world_axis_object_classes.values():
        if issubclass(cls, SkyCoord):
            return kwargs.get('frame', ICRS())
        elif issubclass(cls, BaseCoordinateFrame):
            return cls(**kwargs)
    return None


class WCSWorld2PixelTransform(CurvedTransform):
    """
    WCS transformation from world to pixel coordinates

    Parameters
    ----------
    wcs : APE14 low-level WCS
        WCS with at most two pixel dimensions.
    invert_xy : bool, optional
        If `True`, the order of the pixel coordinates returned by the WCS is
        reversed.
    """

    has_inverse = True
    frame_in = None

    def __init__(self, wcs, invert_xy=False):

        super().__init__()

        if wcs.pixel_n_dim > 2:
            raise ValueError('Only pixel_n_dim <= 2 is supported')

        self.wcs = wcs
        self.invert_xy = invert_xy

        self.frame_in = wcsapi_to_celestial_frame(wcs)

    def __eq__(self, other):
        # Two transforms are equal only if they share the very same WCS object.
        return (isinstance(other, type(self)) and self.wcs is other.wcs and
                self.invert_xy == other.invert_xy)

    @property
    def input_dims(self):
        return self.wcs.world_n_dim

    def transform(self, world):
        """
        Convert an ``(N, world_n_dim)`` array of world values to pixel values.

        Raises `ValueError` if the number of columns does not match the number
        of world dimensions of the WCS.
        """
        # Convert to a list of arrays
        world = list(world.T)

        if len(world) != self.wcs.world_n_dim:
            raise ValueError(f"Expected {self.wcs.world_n_dim} world coordinates, got {len(world)} ")

        if len(world[0]) == 0:
            pixel = np.zeros((0, 2))
        else:
            pixel = self.wcs.world_to_pixel_values(*world)

        if self.invert_xy:
            pixel = pixel[::-1]

        pixel = np.array(pixel).T

        return pixel

    transform_non_affine = transform

    def inverted(self):
        """
        Return the inverse of the transform
        """
        return WCSPixel2WorldTransform(self.wcs, invert_xy=self.invert_xy)


class WCSPixel2WorldTransform(CurvedTransform):
    """
    WCS transformation from pixel to world coordinates

    Parameters
    ----------
    wcs : APE14 low-level WCS
        WCS with at most two pixel dimensions.
    invert_xy : bool, optional
        If `True`, the order of the input pixel coordinates is reversed before
        they are passed to the WCS.
    """

    has_inverse = True

    def __init__(self, wcs, invert_xy=False):

        super().__init__()

        if wcs.pixel_n_dim > 2:
            raise ValueError('Only pixel_n_dim <= 2 is supported')

        self.wcs = wcs
        self.invert_xy = invert_xy

        self.frame_out = wcsapi_to_celestial_frame(wcs)

    def __eq__(self, other):
        # Two transforms are equal only if they share the very same WCS object.
        return (isinstance(other, type(self)) and self.wcs is other.wcs and
                self.invert_xy == other.invert_xy)

    @property
    def output_dims(self):
        return self.wcs.world_n_dim

    def transform(self, pixel):
        """
        Convert an ``(N, pixel_n_dim)`` array of pixel values to world values.

        Raises `ValueError` if the number of columns does not match the number
        of pixel dimensions of the WCS.
        """
        # Convert to a list of arrays
        pixel = list(pixel.T)

        if len(pixel) != self.wcs.pixel_n_dim:
            raise ValueError(f"Expected {self.wcs.pixel_n_dim} pixel coordinates, got {len(pixel)} ")

        if self.invert_xy:
            pixel = pixel[::-1]

        if len(pixel[0]) == 0:
            world = np.zeros((0, self.wcs.world_n_dim))
        else:
            world = self.wcs.pixel_to_world_values(*pixel)

        # With a single world dimension the WCS result is wrapped in a list so
        # that the output always has one column per world axis.
        if self.wcs.world_n_dim == 1:
            world = [world]

        world = np.array(world).T

        return world

    transform_non_affine = transform

    def inverted(self):
        """
        Return the inverse of the transform
        """
        return WCSWorld2PixelTransform(self.wcs, invert_xy=self.invert_xy)