"""Utility functions to use Python Array API compatible libraries. For the context about the Array API see: https://data-apis.org/array-api/latest/purpose_and_scope.html The SciPy use case of the Array API is described on the following page: https://data-apis.org/array-api/latest/use_cases.html#use-case-scipy """ from __future__ import annotations import os import warnings import numpy as np from scipy._lib import array_api_compat from scipy._lib.array_api_compat import ( is_array_api_obj, size, numpy as np_compat, ) __all__ = ['array_namespace', '_asarray', 'size'] # To enable array API and strict array-like input validation SCIPY_ARRAY_API: str | bool = os.environ.get("SCIPY_ARRAY_API", False) # To control the default device - for use in the test suite only SCIPY_DEVICE = os.environ.get("SCIPY_DEVICE", "cpu") _GLOBAL_CONFIG = { "SCIPY_ARRAY_API": SCIPY_ARRAY_API, "SCIPY_DEVICE": SCIPY_DEVICE, } def compliance_scipy(arrays): """Raise exceptions on known-bad subclasses. The following subclasses are not supported and raise and error: - `numpy.ma.MaskedArray` - `numpy.matrix` - NumPy arrays which do not have a boolean or numerical dtype - Any array-like which is neither array API compatible nor coercible by NumPy - Any array-like which is coerced by NumPy to an unsupported dtype """ for i in range(len(arrays)): array = arrays[i] if isinstance(array, np.ma.MaskedArray): raise TypeError("Inputs of type `numpy.ma.MaskedArray` are not supported.") elif isinstance(array, np.matrix): raise TypeError("Inputs of type `numpy.matrix` are not supported.") if isinstance(array, (np.ndarray, np.generic)): dtype = array.dtype if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)): raise TypeError(f"An argument has dtype `{dtype!r}`; " f"only boolean and numerical dtypes are supported.") elif not is_array_api_obj(array): try: array = np.asanyarray(array) except TypeError: raise TypeError("An argument is neither array API compatible nor " "coercible by NumPy.") dtype = array.dtype if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)): message = ( f"An argument was coerced to an unsupported dtype `{dtype!r}`; " f"only boolean and numerical dtypes are supported." ) raise TypeError(message) arrays[i] = array return arrays def _check_finite(array, xp): """Check for NaNs or Infs.""" msg = "array must not contain infs or NaNs" try: if not xp.all(xp.isfinite(array)): raise ValueError(msg) except TypeError: raise ValueError(msg) def array_namespace(*arrays): """Get the array API compatible namespace for the arrays xs. Parameters ---------- *arrays : sequence of array_like Arrays used to infer the common namespace. Returns ------- namespace : module Common namespace. Notes ----- Thin wrapper around `array_api_compat.array_namespace`. 1. Check for the global switch: SCIPY_ARRAY_API. This can also be accessed dynamically through ``_GLOBAL_CONFIG['SCIPY_ARRAY_API']``. 2. `compliance_scipy` raise exceptions on known-bad subclasses. See its definition for more details. When the global switch is False, it defaults to the `numpy` namespace. In that case, there is no compliance check. This is a convenience to ease the adoption. Otherwise, arrays must comply with the new rules. """ if not _GLOBAL_CONFIG["SCIPY_ARRAY_API"]: # here we could wrap the namespace if needed return np_compat arrays = [array for array in arrays if array is not None] arrays = compliance_scipy(arrays) return array_api_compat.array_namespace(*arrays) def _asarray( array, dtype=None, order=None, copy=None, *, xp=None, check_finite=False ): """SciPy-specific replacement for `np.asarray` with `order` and `check_finite`. Memory layout parameter `order` is not exposed in the Array API standard. `order` is only enforced if the input array implementation is NumPy based, otherwise `order` is just silently ignored. `check_finite` is also not a keyword in the array API standard; included here for convenience rather than that having to be a separate function call inside SciPy functions. """ if xp is None: xp = array_namespace(array) if xp.__name__ in {"numpy", "scipy._lib.array_api_compat.numpy"}: # Use NumPy API to support order if copy is True: array = np.array(array, order=order, dtype=dtype) else: array = np.asarray(array, order=order, dtype=dtype) # At this point array is a NumPy ndarray. We convert it to an array # container that is consistent with the input's namespace. array = xp.asarray(array) else: try: array = xp.asarray(array, dtype=dtype, copy=copy) except TypeError: coerced_xp = array_namespace(xp.asarray(3)) array = coerced_xp.asarray(array, dtype=dtype, copy=copy) if check_finite: _check_finite(array, xp) return array def atleast_nd(x, *, ndim, xp=None): """Recursively expand the dimension to have at least `ndim`.""" if xp is None: xp = array_namespace(x) x = xp.asarray(x) if x.ndim < ndim: x = xp.expand_dims(x, axis=0) x = atleast_nd(x, ndim=ndim, xp=xp) return x def copy(x, *, xp=None): """ Copies an array. Parameters ---------- x : array xp : array_namespace Returns ------- copy : array Copied array Notes ----- This copy function does not offer all the semantics of `np.copy`, i.e. the `subok` and `order` keywords are not used. """ # Note: xp.asarray fails if xp is numpy. if xp is None: xp = array_namespace(x) return _asarray(x, copy=True, xp=xp) def is_numpy(xp): return xp.__name__ in ('numpy', 'scipy._lib.array_api_compat.numpy') def is_cupy(xp): return xp.__name__ in ('cupy', 'scipy._lib.array_api_compat.cupy') def is_torch(xp): return xp.__name__ in ('torch', 'scipy._lib.array_api_compat.torch') def _strict_check(actual, desired, xp, check_namespace=True, check_dtype=True, check_shape=True): __tracebackhide__ = True # Hide traceback for py.test if check_namespace: _assert_matching_namespace(actual, desired) desired = xp.asarray(desired) if check_dtype: _msg = "dtypes do not match.\nActual: {actual.dtype}\nDesired: {desired.dtype}" assert actual.dtype == desired.dtype, _msg if check_shape: _msg = "Shapes do not match.\nActual: {actual.shape}\nDesired: {desired.shape}" assert actual.shape == desired.shape, _msg _check_scalar(actual, desired, xp) desired = xp.broadcast_to(desired, actual.shape) return desired def _assert_matching_namespace(actual, desired): __tracebackhide__ = True # Hide traceback for py.test actual = actual if isinstance(actual, tuple) else (actual,) desired_space = array_namespace(desired) for arr in actual: arr_space = array_namespace(arr) _msg = (f"Namespaces do not match.\n" f"Actual: {arr_space.__name__}\n" f"Desired: {desired_space.__name__}") assert arr_space == desired_space, _msg def _check_scalar(actual, desired, xp): __tracebackhide__ = True # Hide traceback for py.test # Shape check alone is sufficient unless desired.shape == (). Also, # only NumPy distinguishes between scalars and arrays. if desired.shape != () or not is_numpy(xp): return # We want to follow the conventions of the `xp` library. Libraries like # NumPy, for which `np.asarray(0)[()]` returns a scalar, tend to return # a scalar even when a 0D array might be more appropriate: # import numpy as np # np.mean([1, 2, 3]) # scalar, not 0d array # np.asarray(0)*2 # scalar, not 0d array # np.sin(np.asarray(0)) # scalar, not 0d array # Libraries like CuPy, for which `cp.asarray(0)[()]` returns a 0D array, # tend to return a 0D array in scenarios like those above. # Therefore, regardless of whether the developer provides a scalar or 0D # array for `desired`, we would typically want the type of `actual` to be # the type of `desired[()]`. If the developer wants to override this # behavior, they can set `check_shape=False`. desired = desired[()] _msg = f"Types do not match:\n Actual: {type(actual)}\n Desired: {type(desired)}" assert (xp.isscalar(actual) and xp.isscalar(desired) or (not xp.isscalar(actual) and not xp.isscalar(desired))), _msg def xp_assert_equal(actual, desired, check_namespace=True, check_dtype=True, check_shape=True, err_msg='', xp=None): __tracebackhide__ = True # Hide traceback for py.test if xp is None: xp = array_namespace(actual) desired = _strict_check(actual, desired, xp, check_namespace=check_namespace, check_dtype=check_dtype, check_shape=check_shape) if is_cupy(xp): return xp.testing.assert_array_equal(actual, desired, err_msg=err_msg) elif is_torch(xp): # PyTorch recommends using `rtol=0, atol=0` like this # to test for exact equality err_msg = None if err_msg == '' else err_msg return xp.testing.assert_close(actual, desired, rtol=0, atol=0, equal_nan=True, check_dtype=False, msg=err_msg) return np.testing.assert_array_equal(actual, desired, err_msg=err_msg) def xp_assert_close(actual, desired, rtol=1e-07, atol=0, check_namespace=True, check_dtype=True, check_shape=True, err_msg='', xp=None): __tracebackhide__ = True # Hide traceback for py.test if xp is None: xp = array_namespace(actual) desired = _strict_check(actual, desired, xp, check_namespace=check_namespace, check_dtype=check_dtype, check_shape=check_shape) if is_cupy(xp): return xp.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, err_msg=err_msg) elif is_torch(xp): err_msg = None if err_msg == '' else err_msg return xp.testing.assert_close(actual, desired, rtol=rtol, atol=atol, equal_nan=True, check_dtype=False, msg=err_msg) return np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, err_msg=err_msg) def xp_assert_less(actual, desired, check_namespace=True, check_dtype=True, check_shape=True, err_msg='', verbose=True, xp=None): __tracebackhide__ = True # Hide traceback for py.test if xp is None: xp = array_namespace(actual) desired = _strict_check(actual, desired, xp, check_namespace=check_namespace, check_dtype=check_dtype, check_shape=check_shape) if is_cupy(xp): return xp.testing.assert_array_less(actual, desired, err_msg=err_msg, verbose=verbose) elif is_torch(xp): if actual.device.type != 'cpu': actual = actual.cpu() if desired.device.type != 'cpu': desired = desired.cpu() return np.testing.assert_array_less(actual, desired, err_msg=err_msg, verbose=verbose) def cov(x, *, xp=None): if xp is None: xp = array_namespace(x) X = copy(x, xp=xp) dtype = xp.result_type(X, xp.float64) X = atleast_nd(X, ndim=2, xp=xp) X = xp.asarray(X, dtype=dtype) avg = xp.mean(X, axis=1) fact = X.shape[1] - 1 if fact <= 0: warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2) fact = 0.0 X -= avg[:, None] X_T = X.T if xp.isdtype(X_T.dtype, 'complex floating'): X_T = xp.conj(X_T) c = X @ X_T c /= fact axes = tuple(axis for axis, length in enumerate(c.shape) if length == 1) return xp.squeeze(c, axis=axes) def xp_unsupported_param_msg(param): return f'Providing {param!r} is only supported for numpy arrays.' def is_complex(x, xp): return xp.isdtype(x.dtype, 'complex floating')