import numbers
from collections import defaultdict

import numpy as np

from astropy.utils import isiterable
from astropy.utils.decorators import lazyproperty

from ..low_level_api import BaseLowLevelWCS
from .base import BaseWCSWrapper

__all__ = ['sanitize_slices', 'SlicedLowLevelWCS']


def sanitize_slices(slices, ndim):
    """
    Given a slice as input sanitise it to an easier to parse format.

    This function returns a list ``ndim`` long containing slice objects (or ints).

    Parameters
    ----------
    slices : int, slice, Ellipsis, or tuple/list of these
        The array slice to sanitise.
    ndim : int
        The number of array dimensions the slice applies to.

    Returns
    -------
    list
        A list of length ``ndim``. Each element is an integer or a `slice`
        whose step is ``None`` or ``1``. A single ``Ellipsis`` and any
        unspecified trailing dimensions are expanded to ``slice(None)``.

    Raises
    ------
    ValueError
        If more than ``ndim`` elements are given.
    IndexError
        If an element is iterable, is neither an integer nor a slice, has a
        step other than 1, or if more than one Ellipsis is present.
    """
    if not isinstance(slices, (tuple, list)):  # We just have a single int
        slices = (slices,)

    if len(slices) > ndim:
        raise ValueError(
            f"The dimensionality of the specified slice {slices} can not be greater "
            f"than the dimensionality ({ndim}) of the wcs.")

    if any(isiterable(s) for s in slices):
        raise IndexError("This slice is invalid, only integer or range slices are supported.")

    slices = list(slices)

    if Ellipsis in slices:
        if slices.count(Ellipsis) > 1:
            raise IndexError("an index can only have a single ellipsis ('...')")

        # Replace the Ellipsis with the correct number of slice(None)s
        e_ind = slices.index(Ellipsis)
        slices.remove(Ellipsis)
        n_e = ndim - len(slices)
        slices[e_ind:e_ind] = [slice(None)] * n_e

    for slc in slices:
        if isinstance(slc, slice):
            # ``step=0`` is invalid and any step other than 1 is unsupported;
            # the previous truthiness test let ``step=0`` through.
            if slc.step not in (None, 1):
                raise IndexError("Slicing WCS with a step is not supported.")
        elif not isinstance(slc, numbers.Integral):
            raise IndexError("Only integer or range slices are accepted.")

    # Pad unspecified trailing dimensions with full slices.
    slices.extend(slice(None) for _ in range(ndim - len(slices)))

    return slices


def combine_slices(slice1, slice2):
    """
    Given two slices that can be applied to a 1-d array, find the resulting
    slice that corresponds to the combination of both slices. We assume that
    slice2 can be an integer, but slice1 cannot.

    The result is equivalent to indexing ``arr[slice1][slice2]`` for
    non-negative bounds; negative starts/stops are not resolved against an
    array length.

    Raises
    ------
    ValueError
        If either slice has a step other than ``None`` or ``1``.
    """
    # ``sanitize_slices`` accepts an explicit step of 1, so it must be accepted
    # here too, otherwise nesting a SlicedLowLevelWCS built with ``a:b:1`` fails.
    for slc in (slice1, slice2):
        if isinstance(slc, slice) and slc.step not in (None, 1):
            raise ValueError('Only slices with steps of 1 are supported')

    if isinstance(slice2, numbers.Integral):
        if slice1.start is None:
            return slice2
        return slice2 + slice1.start

    if slice1.start is None:
        if slice1.stop is None:
            return slice2
        if slice2.stop is None:
            return slice(slice2.start, slice1.stop)
        return slice(slice2.start, min(slice1.stop, slice2.stop))

    # slice1 has an explicit start, so slice2's bounds are offsets from it.
    if slice2.start is None:
        start = slice1.start
    else:
        start = slice1.start + slice2.start

    if slice2.stop is None:
        stop = slice1.stop
    else:
        stop = slice2.stop + slice1.start
        if slice1.stop is not None:
            stop = min(slice1.stop, stop)

    return slice(start, stop)


class SlicedLowLevelWCS(BaseWCSWrapper):
    """
    A Low Level WCS wrapper which applies an array slice to a WCS.

    This class does not modify the underlying WCS object and can therefore drop
    coupled dimensions as it stores which pixel and world dimensions have been
    sliced out (or modified) in the underlying WCS and returns the modified
    results on all the Low Level WCS methods.

    Parameters
    ----------
    wcs : `~astropy.wcs.wcsapi.BaseLowLevelWCS`
        The WCS to slice.
    slices : `slice` or `tuple` or `int`
        A valid array slice to apply to the WCS.
    """
    def __init__(self, wcs, slices):

        slices = sanitize_slices(slices, wcs.pixel_n_dim)

        if isinstance(wcs, SlicedLowLevelWCS):
            # Here we combine the current slices with the previous slices
            # to avoid ending up with many nested WCSes
            self._wcs = wcs._wcs
            slices_original = wcs._slices_array.copy()
            for ipixel in range(wcs.pixel_n_dim):
                # _pixel_keep indexes pixel order; slices are stored in the
                # reversed (array) order, hence the ``n - 1 - i`` conversions.
                ipixel_orig = wcs._wcs.pixel_n_dim - 1 - wcs._pixel_keep[ipixel]
                ipixel_new = wcs.pixel_n_dim - 1 - ipixel
                slices_original[ipixel_orig] = combine_slices(slices_original[ipixel_orig],
                                                              slices[ipixel_new])
            self._slices_array = slices_original
        else:
            self._wcs = wcs
            self._slices_array = slices

        # Same slices, in pixel order (reverse of array order).
        self._slices_pixel = self._slices_array[::-1]

        # figure out which pixel dimensions have been kept, then use axis correlation
        # matrix to figure out which world dims are kept
        self._pixel_keep = np.nonzero([not isinstance(self._slices_pixel[ip], numbers.Integral)
                                       for ip in range(self._wcs.pixel_n_dim)])[0]

        # axis_correlation_matrix[world, pixel]
        self._world_keep = np.nonzero(
            self._wcs.axis_correlation_matrix[:, self._pixel_keep].any(axis=1))[0]

        if len(self._pixel_keep) == 0 or len(self._world_keep) == 0:
            raise ValueError("Cannot slice WCS: the resulting WCS should have "
                             "at least one pixel and one world dimension.")

    @lazyproperty
    def dropped_world_dimensions(self):
        """
        Information describing the dropped world dimensions.

        Returns a dict. When at least one world axis was dropped it contains
        the list-valued keys ``value``, ``world_axis_names``,
        ``world_axis_physical_types``, ``world_axis_units`` and
        ``world_axis_object_components`` (one entry per dropped axis), plus a
        ``world_axis_object_classes`` dict. ``serialized_classes`` is always
        present. ``value`` is evaluated at pixel 0 of every kept pixel axis.
        """
        world_coords = self._pixel_to_world_values_all(*[0]*len(self._pixel_keep))
        dropped_info = defaultdict(list)

        for i in range(self._wcs.world_n_dim):

            if i in self._world_keep:
                continue

            if "world_axis_object_classes" not in dropped_info:
                dropped_info["world_axis_object_classes"] = dict()

            wao_classes = self._wcs.world_axis_object_classes
            wao_components = self._wcs.world_axis_object_components
            class_key = wao_components[i][0]

            dropped_info["value"].append(world_coords[i])
            dropped_info["world_axis_names"].append(self._wcs.world_axis_names[i])
            dropped_info["world_axis_physical_types"].append(self._wcs.world_axis_physical_types[i])
            dropped_info["world_axis_units"].append(self._wcs.world_axis_units[i])
            dropped_info["world_axis_object_components"].append(wao_components[i])
            dropped_info["world_axis_object_classes"].update(
                {key: value for key, value in wao_classes.items() if key == class_key})
            dropped_info["serialized_classes"] = self.serialized_classes
        return dict(dropped_info)

    @property
    def pixel_n_dim(self):
        return len(self._pixel_keep)

    @property
    def world_n_dim(self):
        return len(self._world_keep)

    @property
    def world_axis_physical_types(self):
        return [self._wcs.world_axis_physical_types[i] for i in self._world_keep]

    @property
    def world_axis_units(self):
        return [self._wcs.world_axis_units[i] for i in self._world_keep]

    @property
    def pixel_axis_names(self):
        return [self._wcs.pixel_axis_names[i] for i in self._pixel_keep]

    @property
    def world_axis_names(self):
        return [self._wcs.world_axis_names[i] for i in self._world_keep]

    def _pixel_to_world_values_all(self, *pixel_arrays):
        """
        Evaluate the underlying WCS for pixel values given on the kept axes.

        Integer-sliced axes are filled with their fixed index and kept axes are
        shifted by their slice start. Returns whatever the wrapped WCS's
        ``pixel_to_world_values`` returns (all world axes, including dropped).
        """
        pixel_arrays = tuple(map(np.asanyarray, pixel_arrays))
        pixel_arrays_new = []
        ipix_curr = -1
        for ipix in range(self._wcs.pixel_n_dim):
            slc = self._slices_pixel[ipix]
            if isinstance(slc, numbers.Integral):
                pixel_arrays_new.append(slc)
            else:
                ipix_curr += 1
                # NOTE(review): a negative ``start`` is added as-is and is not
                # resolved against the array shape.
                if slc.start is not None:
                    pixel_arrays_new.append(pixel_arrays[ipix_curr] + slc.start)
                else:
                    pixel_arrays_new.append(pixel_arrays[ipix_curr])

        pixel_arrays_new = np.broadcast_arrays(*pixel_arrays_new)
        return self._wcs.pixel_to_world_values(*pixel_arrays_new)

    def pixel_to_world_values(self, *pixel_arrays):
        world_arrays = self._pixel_to_world_values_all(*pixel_arrays)

        # Detect the case of a length 0 array
        if isinstance(world_arrays, np.ndarray) and not world_arrays.shape:
            return world_arrays

        if self._wcs.world_n_dim > 1:
            # Select the dimensions of the original WCS we are keeping.
            world_arrays = [world_arrays[iw] for iw in self._world_keep]
            # If there is only one world dimension (after slicing) we shouldn't
            # return a tuple. This must stay inside this branch: for a wrapped
            # WCS with a single world axis ``world_arrays`` is already a single
            # array and ``[0]`` would take its first element.
            if self.world_n_dim == 1:
                world_arrays = world_arrays[0]

        return world_arrays

    def world_to_pixel_values(self, *world_arrays):
        world_arrays = tuple(map(np.asanyarray, world_arrays))

        dropped_world = None
        if self.world_n_dim < self._wcs.world_n_dim:
            # Dropped world axes are uncorrelated with every kept pixel axis,
            # so their value is fixed by the integer slices alone. Use that
            # value instead of an arbitrary constant, which gives wrong pixels
            # when world axes are coupled.
            dropped_world = self._pixel_to_world_values_all(*[0] * len(self._pixel_keep))

        world_arrays_new = []
        iworld_curr = -1
        for iworld in range(self._wcs.world_n_dim):
            if iworld in self._world_keep:
                iworld_curr += 1
                world_arrays_new.append(world_arrays[iworld_curr])
            else:
                world_arrays_new.append(dropped_world[iworld])

        world_arrays_new = np.broadcast_arrays(*world_arrays_new)
        pixel_arrays = self._wcs.world_to_pixel_values(*world_arrays_new)
        if self._wcs.pixel_n_dim == 1:
            # A 1-D WCS returns a single array, not a tuple; ``list()`` on it
            # would split it into elements (or fail for 0-d input).
            pixel_arrays = (pixel_arrays,)
        pixel_arrays = list(pixel_arrays)

        for ipixel in range(self._wcs.pixel_n_dim):
            slc = self._slices_pixel[ipixel]
            if isinstance(slc, slice) and slc.start is not None:
                # Not in-place, so arrays owned by the wrapped WCS are untouched.
                pixel_arrays[ipixel] = pixel_arrays[ipixel] - slc.start

        pixel = tuple(pixel_arrays[ip] for ip in self._pixel_keep)
        # Mirror pixel_to_world_values: a single axis yields a bare array.
        if self.pixel_n_dim == 1:
            return pixel[0]
        return pixel

    @property
    def world_axis_object_components(self):
        return [self._wcs.world_axis_object_components[idx] for idx in self._world_keep]

    @property
    def world_axis_object_classes(self):
        keys_keep = {item[0] for item in self.world_axis_object_components}
        return {key: value for key, value in self._wcs.world_axis_object_classes.items()
                if key in keys_keep}

    @property
    def array_shape(self):
        if self._wcs.array_shape:
            return np.broadcast_to(0, self._wcs.array_shape)[tuple(self._slices_array)].shape
        return None

    @property
    def pixel_shape(self):
        array_shape = self.array_shape
        if array_shape:
            return tuple(array_shape[::-1])
        return None

    @property
    def pixel_bounds(self):
        if self._wcs.pixel_bounds is None:
            return None

        bounds = []
        for idx in self._pixel_keep:
            start = self._slices_pixel[idx].start
            if start is None:
                bounds.append(self._wcs.pixel_bounds[idx])
            else:
                imin, imax = self._wcs.pixel_bounds[idx]
                bounds.append((imin - start, imax - start))

        return tuple(bounds)

    @property
    def axis_correlation_matrix(self):
        return self._wcs.axis_correlation_matrix[self._world_keep][:, self._pixel_keep]