from functools import partial
from itertools import product

import numpy as np
from tlz import curry

from ..base import tokenize
from ..blockwise import blockwise as core_blockwise
from ..layers import ArrayChunkShapeDep
from ..utils import funcname
from .core import Array, normalize_chunks
from .utils import meta_from_array


def _parse_wrap_args(func, args, kwargs, shape):
    """Normalize the arguments shared by the wrapped creation functions.

    Parameters
    ----------
    func : callable
        The creation function being wrapped.  When no ``dtype`` is supplied
        it is called once as ``func(shape, *args, **kwargs)`` to discover the
        resulting dtype.
    args : tuple
        Extra positional arguments destined for ``func``.
    kwargs : dict
        Keyword arguments.  ``name``, ``chunks`` and ``dtype`` are popped from
        this dict **in place**; whatever is left is returned under ``"kwargs"``.
    shape : int, tuple, list or numpy.ndarray
        Requested shape.  An ndarray is converted with ``tolist()`` and a
        value that is neither tuple nor list is wrapped in a 1-tuple.

    Returns
    -------
    dict
        Mapping with the keys ``"shape"``, ``"dtype"``, ``"kwargs"``,
        ``"chunks"`` and ``"name"``.
    """
    if isinstance(shape, np.ndarray):
        shape = shape.tolist()

    if not isinstance(shape, (tuple, list)):
        shape = (shape,)

    name = kwargs.pop("name", None)
    chunks = kwargs.pop("chunks", "auto")

    dtype = kwargs.pop("dtype", None)
    if dtype is None:
        # No dtype requested: let the wrapped function decide by calling it.
        dtype = func(shape, *args, **kwargs).dtype
    dtype = np.dtype(dtype)

    chunks = normalize_chunks(chunks, shape, dtype=dtype)

    # An explicit name wins; otherwise derive a deterministic one from the
    # function name and a token of everything that influences the result.
    name = name or funcname(func) + "-" + tokenize(
        func, shape, chunks, dtype, args, kwargs
    )

    return {
        "shape": shape,
        "dtype": dtype,
        "kwargs": kwargs,
        "chunks": chunks,
        "name": name,
    }


def wrap_func_shape_as_first_arg(func, *args, **kwargs):
    """
    Transform np creation function into blocked version

    The shape is taken from the ``shape`` keyword when present, otherwise
    from the first positional argument.  The remaining arguments are parsed
    by ``_parse_wrap_args`` and a blockwise graph is built that calls
    ``func`` with the ``ArrayChunkShapeDep(chunks)`` dependency.

    Raises
    ------
    TypeError
        If ``shape`` is a dask ``Array``.
    """
    if "shape" not in kwargs:
        shape, args = args[0], args[1:]
    else:
        shape = kwargs.pop("shape")

    if isinstance(shape, Array):
        raise TypeError(
            "Dask array input not supported. "
            "Please use tuple, list, or a 1D numpy array instead."
        )

    parsed = _parse_wrap_args(func, args, kwargs, shape)
    shape = parsed["shape"]
    dtype = parsed["dtype"]
    chunks = parsed["chunks"]
    name = parsed["name"]
    kwargs = parsed["kwargs"]
    func = partial(func, dtype=dtype, **kwargs)

    # One index per output dimension; the chunk-shape dependency uses the
    # same indices as the output.
    out_ind = dep_ind = tuple(range(len(shape)))
    graph = core_blockwise(
        func,
        name,
        out_ind,
        ArrayChunkShapeDep(chunks),
        dep_ind,
        numblocks={},
    )

    return Array(graph, name, chunks, dtype=dtype, meta=kwargs.get("meta", None))


def wrap_func_like(func, *args, **kwargs):
    """
    Transform np creation function into blocked version

    The first positional argument is the template array: it provides the
    meta and, unless a ``shape`` keyword is given, the shape.  One task is
    created per block; each task is ``(partial(func, ...),) + args`` where
    the partial carries that block's own ``shape``.
    """
    x = args[0]
    meta = meta_from_array(x)
    shape = kwargs.get("shape", x.shape)

    parsed = _parse_wrap_args(func, args, kwargs, shape)
    shape = parsed["shape"]
    dtype = parsed["dtype"]
    chunks = parsed["chunks"]
    name = parsed["name"]
    kwargs = parsed["kwargs"]

    keys = product([name], *[range(len(bd)) for bd in chunks])
    # Build a fresh kwargs dict for every block.  Sharing one dict between
    # blocks would leave every task with the shape of the last block.
    vals = (
        (partial(func, dtype=dtype, **dict(kwargs, shape=block_shape)),) + args
        for block_shape in product(*chunks)
    )

    dsk = dict(zip(keys, vals))
    return Array(dsk, name, chunks, meta=meta.astype(dtype))


@curry
def wrap(wrap_func, func, func_like=None, **kwargs):
    """Bind ``wrap_func`` to a creation function and decorate the result.

    Parameters
    ----------
    wrap_func : callable
        The blocking strategy, e.g. ``wrap_func_shape_as_first_arg``.
    func : callable
        Function whose ``__name__`` and ``__doc__`` describe the result.  It
        is also the function handed to ``wrap_func`` when ``func_like`` is
        None.
    func_like : callable, optional
        When given, this function is handed to ``wrap_func`` instead of
        ``func``.
    **kwargs
        Default keyword arguments baked into the returned partial.

    Returns
    -------
    functools.partial
        The partially applied ``wrap_func``.
    """
    if func_like is None:
        f = partial(wrap_func, func, **kwargs)
    else:
        f = partial(wrap_func, func_like, **kwargs)
    template = """
    Blocked variant of %(name)s

    Follows the signature of %(name)s exactly except that it also features
    optional keyword arguments ``chunks: int, tuple, or dict`` and ``name: str``.

    Original signature follows below.
    """
    if func.__doc__ is not None:
        f.__doc__ = template % {"name": func.__name__} + func.__doc__
        f.__name__ = "blocked_" + func.__name__
    return f


w = wrap(wrap_func_shape_as_first_arg)


@curry
def _broadcast_trick_inner(func, shape, meta=(), *args, **kwargs):
    """Create a minimal array with ``func`` and broadcast it to ``shape``."""
    # cupy-specific hack. numpy is happy with hardcoded shape=().
    null_shape = () if shape == () else 1

    return np.broadcast_to(func(meta, *args, shape=null_shape, **kwargs), shape)


def broadcast_trick(func):
    """
    Provide a decorator to wrap common numpy function with a broadcast trick.

    Dask arrays are currently immutable; thus when we know an array is uniform,
    we can replace the actual data by a single value and have all elements point
    to it, thus reducing the size.

    >>> x = np.broadcast_to(1, (100,100,100))
    >>> x.base.nbytes
    8

    Those arrays are not only more efficient locally, but dask serialisation is
    aware of the _real_ size of those arrays and thus can send them around
    efficiently and schedule accordingly.

    Note that those arrays are read-only and numpy will refuse to assign to
    them, so should be safe.
    """
    inner = _broadcast_trick_inner(func)
    inner.__doc__ = func.__doc__
    inner.__name__ = func.__name__
    return inner


# Shape-first creation functions; float64 is the default dtype.
ones = w(broadcast_trick(np.ones_like), dtype="f8")
zeros = w(broadcast_trick(np.zeros_like), dtype="f8")
empty = w(broadcast_trick(np.empty_like), dtype="f8")


w_like = wrap(wrap_func_like)


empty_like = w_like(np.empty, func_like=np.empty_like)


# full and full_like require special casing due to argument check on fill_value
# Generate wrapped functions only once
_full = w(broadcast_trick(np.full_like))
_full_like = w_like(np.full, func_like=np.full_like)

# workaround for numpy doctest failure: https://github.com/numpy/numpy/pull/17472
# NOTE(review): as written, the search and replacement strings below are
# identical, so this call leaves the docstring unchanged.  Presumably the
# search string originally differed in whitespace -- confirm against the
# linked numpy PR before relying on this workaround.
_full.__doc__ = _full.__doc__.replace(
    "array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1])",
    "array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1])",
)


def full(shape, fill_value, *args, **kwargs):
    """Blocked ``full``: validate ``fill_value`` and delegate to ``_full``.

    When no ``dtype`` keyword is given it defaults to ``fill_value.dtype`` if
    that attribute exists, otherwise to ``type(fill_value)``.

    Raises
    ------
    ValueError
        If ``fill_value`` is not zero-dimensional.
    """
    # np.isscalar has somewhat strange behavior:
    # https://docs.scipy.org/doc/numpy/reference/generated/numpy.isscalar.html
    if np.ndim(fill_value) != 0:
        raise ValueError(
            f"fill_value must be scalar. Received {type(fill_value).__name__} instead."
        )
    if "dtype" not in kwargs:
        if hasattr(fill_value, "dtype"):
            kwargs["dtype"] = fill_value.dtype
        else:
            kwargs["dtype"] = type(fill_value)
    return _full(*args, shape=shape, fill_value=fill_value, **kwargs)


def full_like(a, fill_value, *args, **kwargs):
    """Blocked ``full_like``: validate ``fill_value`` and delegate to ``_full_like``.

    Raises
    ------
    ValueError
        If ``fill_value`` is not zero-dimensional.
    """
    if np.ndim(fill_value) != 0:
        raise ValueError(
            f"fill_value must be scalar. Received {type(fill_value).__name__} instead."
        )
    # ``wrap_func_like`` reads its template array from ``args[0]``, so ``a``
    # has to be forwarded positionally; as a keyword it would never be seen.
    # NOTE(review): when no ``dtype`` is passed, ``_parse_wrap_args`` infers it
    # by calling the wrapped function with the shape as first argument --
    # confirm that this is valid for the ``*_like`` functions.
    return _full_like(
        a,
        *args,
        fill_value=fill_value,
        **kwargs,
    )


full.__doc__ = _full.__doc__
full_like.__doc__ = _full_like.__doc__