"""Coders for strings.""" from functools import partial import numpy as np from ..core import indexing from ..core.pycompat import is_duck_dask_array from ..core.variable import Variable from .variables import ( VariableCoder, lazy_elemwise_func, pop_to, safe_setitem, unpack_for_decoding, unpack_for_encoding, ) def create_vlen_dtype(element_type): # based on h5py.special_dtype return np.dtype("O", metadata={"element_type": element_type}) def check_vlen_dtype(dtype): if dtype.kind != "O" or dtype.metadata is None: return None else: return dtype.metadata.get("element_type") def is_unicode_dtype(dtype): return dtype.kind == "U" or check_vlen_dtype(dtype) == str def is_bytes_dtype(dtype): return dtype.kind == "S" or check_vlen_dtype(dtype) == bytes class EncodedStringCoder(VariableCoder): """Transforms between unicode strings and fixed-width UTF-8 bytes.""" def __init__(self, allows_unicode=True): self.allows_unicode = allows_unicode def encode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_encoding(variable) contains_unicode = is_unicode_dtype(data.dtype) encode_as_char = encoding.get("dtype") == "S1" if encode_as_char: del encoding["dtype"] # no longer relevant if contains_unicode and (encode_as_char or not self.allows_unicode): if "_FillValue" in attrs: raise NotImplementedError( "variable {!r} has a _FillValue specified, but " "_FillValue is not yet supported on unicode strings: " "https://github.com/pydata/xarray/issues/1647".format(name) ) string_encoding = encoding.pop("_Encoding", "utf-8") safe_setitem(attrs, "_Encoding", string_encoding, name=name) # TODO: figure out how to handle this in a lazy way with dask data = encode_string_array(data, string_encoding) return Variable(dims, data, attrs, encoding) def decode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_decoding(variable) if "_Encoding" in attrs: string_encoding = pop_to(attrs, encoding, "_Encoding") func = partial(decode_bytes_array, encoding=string_encoding) data = lazy_elemwise_func(data, func, np.dtype(object)) return Variable(dims, data, attrs, encoding) def decode_bytes_array(bytes_array, encoding="utf-8"): # This is faster than using np.char.decode() or np.vectorize() bytes_array = np.asarray(bytes_array) decoded = [x.decode(encoding) for x in bytes_array.ravel()] return np.array(decoded, dtype=object).reshape(bytes_array.shape) def encode_string_array(string_array, encoding="utf-8"): string_array = np.asarray(string_array) encoded = [x.encode(encoding) for x in string_array.ravel()] return np.array(encoded, dtype=bytes).reshape(string_array.shape) def ensure_fixed_length_bytes(var): """Ensure that a variable with vlen bytes is converted to fixed width.""" dims, data, attrs, encoding = unpack_for_encoding(var) if check_vlen_dtype(data.dtype) == bytes: # TODO: figure out how to handle this with dask data = np.asarray(data, dtype=np.string_) return Variable(dims, data, attrs, encoding) class CharacterArrayCoder(VariableCoder): """Transforms between arrays containing bytes and character arrays.""" def encode(self, variable, name=None): variable = ensure_fixed_length_bytes(variable) dims, data, attrs, encoding = unpack_for_encoding(variable) if data.dtype.kind == "S" and encoding.get("dtype") is not str: data = bytes_to_char(data) if "char_dim_name" in encoding.keys(): char_dim_name = encoding.pop("char_dim_name") else: char_dim_name = f"string{data.shape[-1]}" dims = dims + (char_dim_name,) return Variable(dims, data, attrs, encoding) def decode(self, variable, name=None): dims, data, attrs, encoding = unpack_for_decoding(variable) if data.dtype == "S1" and dims: encoding["char_dim_name"] = dims[-1] dims = dims[:-1] data = char_to_bytes(data) return Variable(dims, data, attrs, encoding) def bytes_to_char(arr): """Convert numpy/dask arrays from fixed width bytes to characters.""" if arr.dtype.kind != "S": raise ValueError("argument must have a fixed-width bytes dtype") if is_duck_dask_array(arr): import dask.array as da return da.map_blocks( _numpy_bytes_to_char, arr, dtype="S1", chunks=arr.chunks + ((arr.dtype.itemsize,)), new_axis=[arr.ndim], ) return _numpy_bytes_to_char(arr) def _numpy_bytes_to_char(arr): """Like netCDF4.stringtochar, but faster and more flexible.""" # ensure the array is contiguous arr = np.array(arr, copy=False, order="C", dtype=np.string_) return arr.reshape(arr.shape + (1,)).view("S1") def char_to_bytes(arr): """Convert numpy/dask arrays from characters to fixed width bytes.""" if arr.dtype != "S1": raise ValueError("argument must have dtype='S1'") if not arr.ndim: # no dimension to concatenate along return arr size = arr.shape[-1] if not size: # can't make an S0 dtype return np.zeros(arr.shape[:-1], dtype=np.string_) if is_duck_dask_array(arr): import dask.array as da if len(arr.chunks[-1]) > 1: raise ValueError( "cannot stacked dask character array with " "multiple chunks in the last dimension: {}".format(arr) ) dtype = np.dtype("S" + str(arr.shape[-1])) return da.map_blocks( _numpy_char_to_bytes, arr, dtype=dtype, chunks=arr.chunks[:-1], drop_axis=[arr.ndim - 1], ) else: return StackedBytesArray(arr) def _numpy_char_to_bytes(arr): """Like netCDF4.chartostring, but faster and more flexible.""" # based on: http://stackoverflow.com/a/10984878/809705 arr = np.array(arr, copy=False, order="C") dtype = "S" + str(arr.shape[-1]) return arr.view(dtype).reshape(arr.shape[:-1]) class StackedBytesArray(indexing.ExplicitlyIndexedNDArrayMixin): """Wrapper around array-like objects to create a new indexable object where values, when accessed, are automatically stacked along the last dimension. >>> indexer = indexing.BasicIndexer((slice(None),)) >>> StackedBytesArray(np.array(["a", "b", "c"], dtype="S1"))[indexer] array(b'abc', dtype='|S3') """ def __init__(self, array): """ Parameters ---------- array : array-like Original array of values to wrap. """ if array.dtype != "S1": raise ValueError( "can only use StackedBytesArray if argument has dtype='S1'" ) self.array = indexing.as_indexable(array) @property def dtype(self): return np.dtype("S" + str(self.array.shape[-1])) @property def shape(self): return self.array.shape[:-1] def __repr__(self): return "{}({!r})".format(type(self).__name__, self.array) def __getitem__(self, key): # require slicing the last dimension completely key = type(key)(indexing.expanded_indexer(key.tuple, self.array.ndim)) if key.tuple[-1] != slice(None): raise IndexError("too many indices") return _numpy_char_to_bytes(self.array[key])