import typing
import types
import inspect
import functools
from . import _uarray
import copyreg
import pickle
import contextlib

from ._uarray import (  # type: ignore
    BackendNotImplementedError,
    _Function,
    _SkipBackendContext,
    _SetBackendContext,
    _BackendState,
)

__all__ = [
    "set_backend",
    "set_global_backend",
    "skip_backend",
    "register_backend",
    "determine_backend",
    "determine_backend_multi",
    "clear_backends",
    "create_multimethod",
    "generate_multimethod",
    "_Function",
    "BackendNotImplementedError",
    "Dispatchable",
    "wrap_single_convertor",
    "wrap_single_convertor_instance",
    "all_of_type",
    "mark_as",
    "set_state",
    "get_state",
    "reset_state",
    "_BackendState",
    "_SkipBackendContext",
    "_SetBackendContext",
]

ArgumentExtractorType = typing.Callable[..., tuple["Dispatchable", ...]]
ArgumentReplacerType = typing.Callable[
    [tuple, dict, tuple], tuple[tuple, dict]
]


def unpickle_function(mod_name, qname, self_):
    # Re-resolve a function by module name and qualified name; if it was a
    # bound method, rebind it to the pickled ``self_``.
    import importlib

    try:
        module = importlib.import_module(mod_name)
        qname = qname.split(".")
        func = module
        for q in qname:
            func = getattr(func, q)

        if self_ is not None:
            func = types.MethodType(func, self_)

        return func
    except (ImportError, AttributeError) as e:
        raise pickle.UnpicklingError from e


def pickle_function(func):
    # Multimethods are pickled by reference: store where to find them and
    # make sure that lookup really returns this exact object.
    mod_name = getattr(func, "__module__", None)
    qname = getattr(func, "__qualname__", None)
    self_ = getattr(func, "__self__", None)

    try:
        test = unpickle_function(mod_name, qname, self_)
    except pickle.UnpicklingError:
        test = None

    if test is not func:
        raise pickle.PicklingError(
            f"Can't pickle {func}: it's not the same object as {test}"
        )

    return unpickle_function, (mod_name, qname, self_)


def pickle_state(state):
    return _uarray._BackendState._unpickle, state._pickle()


def pickle_set_backend_context(ctx):
    return _SetBackendContext, ctx._pickle()


def pickle_skip_backend_context(ctx):
    return _SkipBackendContext, ctx._pickle()


# Register reducers so that multimethods, backend state and backend
# contexts can be pickled.
copyreg.pickle(_Function, pickle_function)
copyreg.pickle(_uarray._BackendState, pickle_state)
copyreg.pickle(_SetBackendContext, pickle_set_backend_context)
copyreg.pickle(_SkipBackendContext, pickle_skip_backend_context)


def get_state():
    """
    Returns an opaque object containing the current state of all the backends.

    Can be used for synchronization between threads/processes.

    See Also
    --------
    set_state
        Sets the state returned by this function.
    """
    return _uarray.get_state()


@contextlib.contextmanager
def reset_state():
    """
    Returns a context manager that resets all state once exited.

    See Also
    --------
    set_state
        Context manager that sets the backend state.
    get_state
        Gets a state to be set by this context manager.
    """
    with set_state(get_state()):
        yield


@contextlib.contextmanager
def set_state(state):
    """
    A context manager that sets the state of the backends to one returned by :obj:`get_state`.

    See Also
    --------
    get_state
        Gets a state to be set by this context manager.
    """  # noqa: E501
    old_state = get_state()
    _uarray.set_state(state)
    try:
        yield
    finally:
        _uarray.set_state(old_state, True)


def create_multimethod(*args, **kwargs):
    """
    Creates a decorator for generating multimethods.

    This function creates a decorator that can be used with an argument
    extractor in order to generate a multimethod. Other than for the
    argument extractor, all arguments are passed on to
    :obj:`generate_multimethod`.

    See Also
    --------
    generate_multimethod
        Generates a multimethod.
    """

    def wrapper(a):
        return generate_multimethod(a, *args, **kwargs)

    return wrapper


def generate_multimethod(
    argument_extractor: ArgumentExtractorType,
    argument_replacer: ArgumentReplacerType,
    domain: str,
    default: typing.Optional[typing.Callable] = None,
):
    """
    Generates a multimethod.

    Parameters
    ----------
    argument_extractor : ArgumentExtractorType
        A callable which extracts the dispatchable arguments. Extracted arguments
        should be marked by the :obj:`Dispatchable` class. It has the same signature
        as the desired multimethod.
    argument_replacer : ArgumentReplacerType
        A callable with the signature (args, kwargs, dispatchables), which should also
        return an (args, kwargs) pair with the dispatchables replaced inside the
        args/kwargs.
    domain : str
        A string value indicating the domain of this multimethod.
    default : Optional[Callable], optional
        The default implementation of this multimethod, where ``None`` (the default)
        specifies there is no default implementation.

    Examples
    --------
    In this example, ``a`` is to be dispatched over, so we return it, while marking it
    as an ``int``. The trailing comma is needed because the args have to be returned
    as an iterable.

    >>> def override_me(a, b):
    ...     return Dispatchable(a, int),

    Next, we define the argument replacer that replaces the dispatchables inside
    args/kwargs with the supplied ones.

    >>> def override_replacer(args, kwargs, dispatchables):
    ...     return (dispatchables[0], args[1]), {}

    Next, we define the multimethod.

    >>> overridden_me = generate_multimethod(
    ...     override_me, override_replacer, "ua_examples"
    ... )

    Notice that there's no default implementation, unless you supply one.

    >>> overridden_me(1, "a")
    Traceback (most recent call last):
        ...
    uarray.BackendNotImplementedError: ...

    >>> overridden_me2 = generate_multimethod(
    ...     override_me, override_replacer, "ua_examples", default=lambda x, y: (x, y)
    ... )
    >>> overridden_me2(1, "a")
    (1, 'a')

    See Also
    --------
    uarray
        See the module documentation for how to override the method by creating
        backends.
    """
    kw_defaults, arg_defaults, opts = get_defaults(argument_extractor)
    ua_func = _Function(
        argument_extractor,
        argument_replacer,
        domain,
        arg_defaults,
        kw_defaults,
        default,
    )

    return functools.update_wrapper(ua_func, argument_extractor)


def set_backend(backend, coerce=False, only=False):
    """
    A context manager that sets the preferred backend.

    Parameters
    ----------
    backend
        The backend to set.
    coerce
        Whether or not to coerce to a specific backend's types. Implies ``only``.
    only
        Whether or not this should be the last backend to try.

    See Also
    --------
    skip_backend: A context manager that allows skipping of backends.
    set_global_backend: Set a single, global backend for a domain.
    """
    # Context objects are cached on the backend, keyed by their options.
    try:
        return backend.__ua_cache__["set", coerce, only]
    except AttributeError:
        backend.__ua_cache__ = {}
    except KeyError:
        pass

    ctx = _SetBackendContext(backend, coerce, only)
    backend.__ua_cache__["set", coerce, only] = ctx
    return ctx


def skip_backend(backend):
    """
    A context manager that allows one to skip a given backend from processing
    entirely. This allows one to use another backend's code in a library that
    is also a consumer of the same backend.

    Parameters
    ----------
    backend
        The backend to skip.

    See Also
    --------
    set_backend: A context manager that allows setting of backends.
    set_global_backend: Set a single, global backend for a domain.
    """
    try:
        return backend.__ua_cache__["skip"]
    except AttributeError:
        backend.__ua_cache__ = {}
    except KeyError:
        pass

    ctx = _SkipBackendContext(backend)
    backend.__ua_cache__["skip"] = ctx
    return ctx


def get_defaults(f):
    # Inspect the signature of ``f`` and return:
    # - ``kw_defaults``: name -> default for every parameter that has one,
    # - ``arg_defaults``: the default (or ``inspect.Parameter.empty``) of each
    #   positional parameter, in order,
    # - ``opts``: the set of all parameter names.
    sig = inspect.signature(f)
    kw_defaults = {}
    arg_defaults = []
    opts = set()
    for k, v in sig.parameters.items():
        if v.default is not inspect.Parameter.empty:
            kw_defaults[k] = v.default
        if v.kind in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        ):
            arg_defaults.append(v.default)
        opts.add(k)

    return kw_defaults, tuple(arg_defaults), opts


def set_global_backend(backend, coerce=False, only=False, *, try_last=False):
    """
    This utility method replaces the default backend for permanent use. It
    will be tried in the list of backends automatically, unless the
    ``only`` flag is set on a backend. This will be the first tried
    backend outside the :obj:`set_backend` context manager.

    Note that this method is not thread-safe.

    .. warning::
        We caution library authors against using this function in
        their code. We do *not* support this use-case.
        This function is meant to be used only by users themselves,
        or by a reference implementation, if one exists.

    Parameters
    ----------
    backend
        The backend to set as the global backend.
    coerce : bool
        Whether to coerce input types when trying this backend.
    only : bool
        If ``True``, no more backends will be tried if this fails.
        Implied by ``coerce=True``.
    try_last : bool
        If ``True``, the global backend is tried after registered backends.

    See Also
    --------
    set_backend: A context manager that allows setting of backends.
    skip_backend: A context manager that allows skipping of backends.
    """
    _uarray.set_global_backend(backend, coerce, only, try_last)


def register_backend(backend):
    """
    This utility method registers a backend for permanent use. It
    will be tried in the list of backends automatically, unless the
    ``only`` flag is set on a backend.

    Note that this method is not thread-safe.

    Parameters
    ----------
    backend
        The backend to register.
    """
    _uarray.register_backend(backend)


def clear_backends(domain, registered=True, globals=False):
    """
    This utility method clears registered backends.

    .. warning::
        We caution library authors against using this function in
        their code. We do *not* support this use-case.
        This function is meant to be used only by users themselves.

    .. warning::
        Do NOT use this method inside a multimethod call, or the
        program is likely to crash.

    Parameters
    ----------
    domain : Optional[str]
        The domain for which to de-register backends. ``None`` means
        de-register for all domains.
    registered : bool
        Whether or not to clear registered backends. See :obj:`register_backend`.
    globals : bool
        Whether or not to clear global backends. See :obj:`set_global_backend`.

    See Also
    --------
    register_backend : Register a backend globally.
    set_global_backend : Set a global backend.
    """
    _uarray.clear_backends(domain, registered, globals)


class Dispatchable:
    """
    A utility class which marks an argument with a specific dispatch type.

    Attributes
    ----------
    value
        The value of the Dispatchable.
    type
        The type of the Dispatchable.

    Examples
    --------
    >>> x = Dispatchable(1, str)
    >>> x
    <Dispatchable: type=<class 'str'>, value=1>

    See Also
    --------
    all_of_type
        Marks all unmarked parameters of a function.

    mark_as
        Allows one to create a utility function to mark as a given type.
    """

    def __init__(self, value, dispatch_type, coercible=True):
        self.value = value
        self.type = dispatch_type
        self.coercible = coercible

    def __getitem__(self, index):
        return (self.type, self.value)[index]

    def __str__(self):
        return f"<{type(self).__name__}: type={self.type!r}, value={self.value!r}>"

    __repr__ = __str__


def mark_as(dispatch_type):
    """
    Creates a utility function to mark something as a specific type.

    Examples
    --------
    >>> mark_int = mark_as(int)
    >>> mark_int(1)
    <Dispatchable: type=<class 'int'>, value=1>
    """
    return functools.partial(Dispatchable, dispatch_type=dispatch_type)


def all_of_type(arg_type):
    """
    Marks all unmarked arguments as a given type.

    Examples
    --------
    >>> @all_of_type(str)
    ... def f(a, b):
    ...     return a, Dispatchable(b, int)
    >>> f('a', 1)
    (<Dispatchable: type=<class 'str'>, value='a'>, <Dispatchable: type=<class 'int'>, value=1>)
    """

    def outer(func):
        @functools.wraps(func)
        def inner(*args, **kwargs):
            extracted_args = func(*args, **kwargs)
            return tuple(
                Dispatchable(arg, arg_type)
                if not isinstance(arg, Dispatchable)
                else arg
                for arg in extracted_args
            )

        return inner

    return outer


def wrap_single_convertor(convert_single):
    """
    Wraps a ``__ua_convert__`` defined for a single element to all elements.
    If any of them return ``NotImplemented``, the operation is assumed to be
    undefined.

    Accepts a signature of (value, type, coerce).
    """

    @functools.wraps(convert_single)
    def __ua_convert__(dispatchables, coerce):
        converted = []
        for d in dispatchables:
            c = convert_single(d.value, d.type, coerce and d.coercible)

            if c is NotImplemented:
                return NotImplemented

            converted.append(c)

        return converted

    return __ua_convert__


def wrap_single_convertor_instance(convert_single):
    """
    Wraps a ``__ua_convert__`` defined for a single element to all elements.
    If any of them return ``NotImplemented``, the operation is assumed to be
    undefined.

    Accepts a signature of (self, value, type, coerce).
    """

    @functools.wraps(convert_single)
    def __ua_convert__(self, dispatchables, coerce):
        converted = []
        for d in dispatchables:
            c = convert_single(self, d.value, d.type, coerce and d.coercible)

            if c is NotImplemented:
                return NotImplemented

            converted.append(c)

        return converted

    return __ua_convert__


def determine_backend(value, dispatch_type, *, domain, only=True, coerce=False):
    """Set the backend to the first active backend that supports ``value``

    This is useful for functions that call multimethods without any dispatchable
    arguments. You can use :func:`determine_backend` to ensure the same backend
    is used everywhere in a block of multimethod calls.

    Parameters
    ----------
    value
        The value being tested.
    dispatch_type
        The dispatch type associated with ``value``, aka
        ":ref:`marking `".
    domain: string
        The domain to query for backends and set.
    coerce: bool
        Whether or not to allow coercion to the backend's types. Implies ``only``.
    only: bool
        Whether or not this should be the last backend to try.

    See Also
    --------
    set_backend: For when you know which backend to set

    Notes
    -----

    Support is determined by the ``__ua_convert__`` protocol. Backends not
    supporting the type must return ``NotImplemented`` from their
    ``__ua_convert__`` if they don't support input of that type.

    Examples
    --------

    Suppose we have two backends ``BackendA`` and ``BackendB`` each supporting
    different types, ``TypeA`` and ``TypeB``, and neither supports the other type:

    >>> with ua.set_backend(ex.BackendA):
    ...     ex.call_multimethod(ex.TypeB(), ex.TypeB())
    Traceback (most recent call last):
        ...
    uarray.BackendNotImplementedError: ...

    Now consider a multimethod that creates a new object of ``TypeA``, or
    ``TypeB`` depending on the active backend.

    >>> with ua.set_backend(ex.BackendA), ua.set_backend(ex.BackendB):
    ...     res = ex.creation_multimethod()
    ...     ex.call_multimethod(res, ex.TypeA())
    Traceback (most recent call last):
        ...
    uarray.BackendNotImplementedError: ...

    ``res`` is an object of ``TypeB`` because ``BackendB`` is set in the
    innermost with statement. So, ``call_multimethod`` fails since the types
    don't match.

    Instead, we need to first find a backend suitable for all of our objects.

    >>> with ua.set_backend(ex.BackendA), ua.set_backend(ex.BackendB):
    ...     x = ex.TypeA()
    ...     with ua.determine_backend(x, "mark", domain="ua_examples"):
    ...         res = ex.creation_multimethod()
    ...         ex.call_multimethod(res, x)
    TypeA

    """
    dispatchables = (Dispatchable(value, dispatch_type, coerce),)
    backend = _uarray.determine_backend(domain, dispatchables, coerce)

    return set_backend(backend, coerce=coerce, only=only)


def determine_backend_multi(
    dispatchables, *, domain, only=True, coerce=False, **kwargs
):
    """Set a backend supporting all ``dispatchables``

    This is useful for functions that call multimethods without any dispatchable
    arguments. You can use :func:`determine_backend_multi` to ensure the same
    backend is used everywhere in a block of multimethod calls involving
    multiple arrays.

    Parameters
    ----------
    dispatchables: Sequence[Union[uarray.Dispatchable, Any]]
        The dispatchables that must be supported.
    domain: string
        The domain to query for backends and set.
    coerce: bool
        Whether or not to allow coercion to the backend's types. Implies ``only``.
    only: bool
        Whether or not this should be the last backend to try.
    dispatch_type: Optional[Any]
        The default dispatch type associated with ``dispatchables``, aka
        ":ref:`marking `".

    See Also
    --------
    determine_backend: For a single dispatch value
    set_backend: For when you know which backend to set

    Notes
    -----

    Support is determined by the ``__ua_convert__`` protocol. Backends not
    supporting the type must return ``NotImplemented`` from their
    ``__ua_convert__`` if they don't support input of that type.

    Examples
    --------

    :func:`determine_backend` allows the backend to be set from a single
    object.
    :func:`determine_backend_multi` allows multiple objects to be checked
    simultaneously for support in the backend. Suppose we have a ``BackendAB``
    which supports ``TypeA`` and ``TypeB`` in the same call, and a ``BackendBC``
    that doesn't support ``TypeA``.

    >>> with ua.set_backend(ex.BackendAB), ua.set_backend(ex.BackendBC):
    ...     a, b = ex.TypeA(), ex.TypeB()
    ...     with ua.determine_backend_multi(
    ...         [ua.Dispatchable(a, "mark"), ua.Dispatchable(b, "mark")],
    ...         domain="ua_examples"
    ...     ):
    ...         res = ex.creation_multimethod()
    ...         ex.call_multimethod(res, a, b)
    TypeA

    This won't call ``BackendBC`` because it doesn't support ``TypeA``.

    We can also leave out the ``ua.Dispatchable`` if we specify the
    default ``dispatch_type`` for the ``dispatchables`` argument.

    >>> with ua.set_backend(ex.BackendAB), ua.set_backend(ex.BackendBC):
    ...     a, b = ex.TypeA(), ex.TypeB()
    ...     with ua.determine_backend_multi(
    ...         [a, b], dispatch_type="mark", domain="ua_examples"
    ...     ):
    ...         res = ex.creation_multimethod()
    ...         ex.call_multimethod(res, a, b)
    TypeA

    """
    if "dispatch_type" in kwargs:
        # Wrap any unmarked values with the supplied default dispatch type.
        disp_type = kwargs.pop("dispatch_type")
        dispatchables = tuple(
            d if isinstance(d, Dispatchable) else Dispatchable(d, disp_type)
            for d in dispatchables
        )
    else:
        dispatchables = tuple(dispatchables)
        if not all(isinstance(d, Dispatchable) for d in dispatchables):
            raise TypeError("dispatchables must be instances of uarray.Dispatchable")

    if len(kwargs) != 0:
        raise TypeError(f"Received unexpected keyword arguments: {kwargs}")

    backend = _uarray.determine_backend(domain, dispatchables, coerce)

    return set_backend(backend, coerce=coerce, only=only)