# -*- coding: utf-8 -*- import collections import functools import sys import types as pytypes import uuid import weakref from contextlib import ExitStack from numba import _dispatcher from numba.core import ( utils, types, errors, typing, serialize, config, compiler, sigutils ) from numba.core.compiler_lock import global_compiler_lock from numba.core.typeconv.rules import default_type_manager from numba.core.typing.templates import fold_arguments from numba.core.typing.typeof import Purpose, typeof from numba.core.bytecode import get_code_object from numba.core.caching import NullCache, FunctionCache from numba.core import entrypoints from numba.core.retarget import BaseRetarget import numba.core.event as ev class _RetargetStack(utils.ThreadLocalStack, stack_name="retarget"): def push(self, state): super().push(state) _dispatcher.set_use_tls_target_stack(len(self) > 0) def pop(self): super().pop() _dispatcher.set_use_tls_target_stack(len(self) > 0) class TargetConfigurationStack: """The target configuration stack. Uses the BORG pattern and stores states in threadlocal storage. WARNING: features associated with this class are experimental. The API may change without notice. """ def __init__(self): self._stack = _RetargetStack() def get(self): """Get the current target from the top of the stack. May raise IndexError if the stack is empty. Users should check the size of the stack beforehand. """ return self._stack.top() def __len__(self): """Size of the stack """ return len(self._stack) @classmethod def switch_target(cls, retarget: BaseRetarget): """Returns a contextmanager that pushes a new retarget handler, an instance of `numba.core.retarget.BaseRetarget`, onto the target-config stack for the duration of the context-manager. """ return cls()._stack.enter(retarget) class OmittedArg(object): """ A placeholder for omitted arguments with a default value. """ def __init__(self, value): self.value = value def __repr__(self): return "omitted arg(%r)" % (self.value,) @property def _numba_type_(self): return types.Omitted(self.value) class _FunctionCompiler(object): def __init__(self, py_func, targetdescr, targetoptions, locals, pipeline_class): self.py_func = py_func self.targetdescr = targetdescr self.targetoptions = targetoptions self.locals = locals self.pysig = utils.pysignature(self.py_func) self.pipeline_class = pipeline_class # Remember key=(args, return_type) combinations that will fail # compilation to avoid compilation attempt on them. The values are # the exceptions. self._failed_cache = {} def fold_argument_types(self, args, kws): """ Given positional and named argument types, fold keyword arguments and resolve defaults by inserting types.Omitted() instances. A (pysig, argument types) tuple is returned. """ def normal_handler(index, param, value): return value def default_handler(index, param, default): return types.Omitted(default) def stararg_handler(index, param, values): return types.StarArgTuple(values) # For now, we take argument values from the @jit function, even # in the case of generated jit. args = fold_arguments(self.pysig, args, kws, normal_handler, default_handler, stararg_handler) return self.pysig, args def compile(self, args, return_type): status, retval = self._compile_cached(args, return_type) if status: return retval else: raise retval def _compile_cached(self, args, return_type): key = tuple(args), return_type try: return False, self._failed_cache[key] except KeyError: pass try: retval = self._compile_core(args, return_type) except errors.TypingError as e: self._failed_cache[key] = e return False, e else: return True, retval def _compile_core(self, args, return_type): flags = compiler.Flags() self.targetdescr.options.parse_as_flags(flags, self.targetoptions) flags = self._customize_flags(flags) impl = self._get_implementation(args, {}) cres = compiler.compile_extra(self.targetdescr.typing_context, self.targetdescr.target_context, impl, args=args, return_type=return_type, flags=flags, locals=self.locals, pipeline_class=self.pipeline_class) # Check typing error if object mode is used if cres.typing_error is not None and not flags.enable_pyobject: raise cres.typing_error return cres def get_globals_for_reduction(self): return serialize._get_function_globals_for_reduction(self.py_func) def _get_implementation(self, args, kws): return self.py_func def _customize_flags(self, flags): return flags class _GeneratedFunctionCompiler(_FunctionCompiler): def __init__(self, py_func, targetdescr, targetoptions, locals, pipeline_class): super(_GeneratedFunctionCompiler, self).__init__( py_func, targetdescr, targetoptions, locals, pipeline_class) self.impls = set() def get_globals_for_reduction(self): # This will recursively get the globals used by any nested # implementation function. return serialize._get_function_globals_for_reduction(self.py_func) def _get_implementation(self, args, kws): impl = self.py_func(*args, **kws) # Check the generating function and implementation signatures are # compatible, otherwise compiling would fail later. pysig = utils.pysignature(self.py_func) implsig = utils.pysignature(impl) ok = len(pysig.parameters) == len(implsig.parameters) if ok: for pyparam, implparam in zip(pysig.parameters.values(), implsig.parameters.values()): # We allow the implementation to omit default values, but # if it mentions them, they should have the same value... if (pyparam.name != implparam.name or pyparam.kind != implparam.kind or (implparam.default is not implparam.empty and implparam.default != pyparam.default)): ok = False if not ok: raise TypeError("generated implementation %s should be compatible " "with signature '%s', but has signature '%s'" % (impl, pysig, implsig)) self.impls.add(impl) return impl _CompileStats = collections.namedtuple( '_CompileStats', ('cache_path', 'cache_hits', 'cache_misses')) class CompilingCounter(object): """ A simple counter that increment in __enter__ and decrement in __exit__. """ def __init__(self): self.counter = 0 def __enter__(self): assert self.counter >= 0 self.counter += 1 def __exit__(self, *args, **kwargs): self.counter -= 1 assert self.counter >= 0 def __bool__(self): return self.counter > 0 __nonzero__ = __bool__ class _DispatcherBase(_dispatcher.Dispatcher): """ Common base class for dispatcher Implementations. """ __numba__ = "py_func" def __init__(self, arg_count, py_func, pysig, can_fallback, exact_match_required): self._tm = default_type_manager # A mapping of signatures to compile results self.overloads = collections.OrderedDict() self.py_func = py_func # other parts of Numba assume the old Python 2 name for code object self.func_code = get_code_object(py_func) # but newer python uses a different name self.__code__ = self.func_code # a place to keep an active reference to the types of the active call self._types_active_call = [] # Default argument values match the py_func self.__defaults__ = py_func.__defaults__ argnames = tuple(pysig.parameters) default_values = self.py_func.__defaults__ or () defargs = tuple(OmittedArg(val) for val in default_values) try: lastarg = list(pysig.parameters.values())[-1] except IndexError: has_stararg = False else: has_stararg = lastarg.kind == lastarg.VAR_POSITIONAL _dispatcher.Dispatcher.__init__(self, self._tm.get_pointer(), arg_count, self._fold_args, argnames, defargs, can_fallback, has_stararg, exact_match_required) self.doc = py_func.__doc__ self._compiling_counter = CompilingCounter() weakref.finalize(self, self._make_finalizer()) def _compilation_chain_init_hook(self): """ This will be called ahead of any part of compilation taking place (this even includes being ahead of working out the types of the arguments). This permits activities such as initialising extension entry points so that the compiler knows about additional externally defined types etc before it does anything. """ entrypoints.init_all() def _reset_overloads(self): self._clear() self.overloads.clear() def _make_finalizer(self): """ Return a finalizer function that will release references to related compiled functions. """ overloads = self.overloads targetctx = self.targetctx # Early-bind utils.shutting_down() into the function's local namespace # (see issue #689) def finalizer(shutting_down=utils.shutting_down): # The finalizer may crash at shutdown, skip it (resources # will be cleared by the process exiting, anyway). if shutting_down(): return # This function must *not* hold any reference to self: # we take care to bind the necessary objects in the closure. for cres in overloads.values(): try: targetctx.remove_user_function(cres.entry_point) except KeyError: pass return finalizer @property def signatures(self): """ Returns a list of compiled function signatures. """ return list(self.overloads) @property def nopython_signatures(self): return [cres.signature for cres in self.overloads.values() if not cres.objectmode] def disable_compile(self, val=True): """Disable the compilation of new signatures at call time. """ # If disabling compilation then there must be at least one signature assert (not val) or len(self.signatures) > 0 self._can_compile = not val def add_overload(self, cres): args = tuple(cres.signature.args) sig = [a._code for a in args] self._insert(sig, cres.entry_point, cres.objectmode) self.overloads[args] = cres def fold_argument_types(self, args, kws): return self._compiler.fold_argument_types(args, kws) def get_call_template(self, args, kws): """ Get a typing.ConcreteTemplate for this dispatcher and the given *args* and *kws* types. This allows to resolve the return type. A (template, pysig, args, kws) tuple is returned. """ # XXX how about a dispatcher template class automating the # following? # Fold keyword arguments and resolve default values pysig, args = self._compiler.fold_argument_types(args, kws) kws = {} # Ensure an overload is available if self._can_compile: self.compile(tuple(args)) # Create function type for typing func_name = self.py_func.__name__ name = "CallTemplate({0})".format(func_name) # The `key` isn't really used except for diagnosis here, # so avoid keeping a reference to `cfunc`. call_template = typing.make_concrete_template( name, key=func_name, signatures=self.nopython_signatures) return call_template, pysig, args, kws def get_overload(self, sig): """ Return the compiled function for the given signature. """ args, return_type = sigutils.normalize_signature(sig) return self.overloads[tuple(args)].entry_point @property def is_compiling(self): """ Whether a specialization is currently being compiled. """ return self._compiling_counter def _compile_for_args(self, *args, **kws): """ For internal use. Compile a specialized version of the function for the given *args* and *kws*, and return the resulting callable. """ assert not kws # call any initialisation required for the compilation chain (e.g. # extension point registration). self._compilation_chain_init_hook() def error_rewrite(e, issue_type): """ Rewrite and raise Exception `e` with help supplied based on the specified issue_type. """ if config.SHOW_HELP: help_msg = errors.error_extras[issue_type] e.patch_message('\n'.join((str(e).rstrip(), help_msg))) if config.FULL_TRACEBACKS: raise e else: raise e.with_traceback(None) argtypes = [] for a in args: if isinstance(a, OmittedArg): argtypes.append(types.Omitted(a.value)) else: argtypes.append(self.typeof_pyval(a)) return_val = None try: return_val = self.compile(tuple(argtypes)) except errors.ForceLiteralArg as e: # Received request for compiler re-entry with the list of arguments # indicated by e.requested_args. # First, check if any of these args are already Literal-ized already_lit_pos = [i for i in e.requested_args if isinstance(args[i], types.Literal)] if already_lit_pos: # Abort compilation if any argument is already a Literal. # Letting this continue will cause infinite compilation loop. m = ("Repeated literal typing request.\n" "{}.\n" "This is likely caused by an error in typing. " "Please see nested and suppressed exceptions.") info = ', '.join('Arg #{} is {}'.format(i, args[i]) for i in sorted(already_lit_pos)) raise errors.CompilerError(m.format(info)) # Convert requested arguments into a Literal. args = [(types.literal if i in e.requested_args else lambda x: x)(args[i]) for i, v in enumerate(args)] # Re-enter compilation with the Literal-ized arguments return_val = self._compile_for_args(*args) except errors.TypingError as e: # Intercept typing error that may be due to an argument # that failed inferencing as a Numba type failed_args = [] for i, arg in enumerate(args): val = arg.value if isinstance(arg, OmittedArg) else arg try: tp = typeof(val, Purpose.argument) except ValueError as typeof_exc: failed_args.append((i, str(typeof_exc))) else: if tp is None: failed_args.append( (i, f"cannot determine Numba type of value {val}")) if failed_args: # Patch error message to ease debugging args_str = "\n".join( f"- argument {i}: {err}" for i, err in failed_args ) msg = (f"{str(e).rstrip()} \n\nThis error may have been caused " f"by the following argument(s):\n{args_str}\n") e.patch_message(msg) error_rewrite(e, 'typing') except errors.UnsupportedError as e: # Something unsupported is present in the user code, add help info error_rewrite(e, 'unsupported_error') except (errors.NotDefinedError, errors.RedefinedError, errors.VerificationError) as e: # These errors are probably from an issue with either the code # supplied being syntactically or otherwise invalid error_rewrite(e, 'interpreter') except errors.ConstantInferenceError as e: # this is from trying to infer something as constant when it isn't # or isn't supported as a constant error_rewrite(e, 'constant_inference') except Exception as e: if config.SHOW_HELP: if hasattr(e, 'patch_message'): help_msg = errors.error_extras['reportable'] e.patch_message('\n'.join((str(e).rstrip(), help_msg))) # ignore the FULL_TRACEBACKS config, this needs reporting! raise e finally: self._types_active_call = [] return return_val def inspect_llvm(self, signature=None): """Get the LLVM intermediate representation generated by compilation. Parameters ---------- signature : tuple of numba types, optional Specify a signature for which to obtain the LLVM IR. If None, the IR is returned for all available signatures. Returns ------- llvm : dict[signature, str] or str Either the LLVM IR string for the specified signature, or, if no signature was given, a dictionary mapping signatures to LLVM IR strings. """ if signature is not None: lib = self.overloads[signature].library return lib.get_llvm_str() return dict((sig, self.inspect_llvm(sig)) for sig in self.signatures) def inspect_asm(self, signature=None): """Get the generated assembly code. Parameters ---------- signature : tuple of numba types, optional Specify a signature for which to obtain the assembly code. If None, the assembly code is returned for all available signatures. Returns ------- asm : dict[signature, str] or str Either the assembly code for the specified signature, or, if no signature was given, a dictionary mapping signatures to assembly code. """ if signature is not None: lib = self.overloads[signature].library return lib.get_asm_str() return dict((sig, self.inspect_asm(sig)) for sig in self.signatures) def inspect_types(self, file=None, signature=None, pretty=False, style='default', **kwargs): """Print/return Numba intermediate representation (IR)-annotated code. Parameters ---------- file : file-like object, optional File to which to print. Defaults to sys.stdout if None. Must be None if ``pretty=True``. signature : tuple of numba types, optional Print/return the intermediate representation for only the given signature. If None, the IR is printed for all available signatures. pretty : bool, optional If True, an Annotate object will be returned that can render the IR with color highlighting in Jupyter and IPython. ``file`` must be None if ``pretty`` is True. Additionally, the ``pygments`` library must be installed for ``pretty=True``. style : str, optional Choose a style for rendering. Ignored if ``pretty`` is ``False``. This is directly consumed by ``pygments`` formatters. To see a list of available styles, import ``pygments`` and run ``list(pygments.styles.get_all_styles())``. Returns ------- annotated : Annotate object, optional Only returned if ``pretty=True``, otherwise this function is only used for its printing side effect. If ``pretty=True``, an Annotate object is returned that can render itself in Jupyter and IPython. """ overloads = self.overloads if signature is not None: overloads = {signature: self.overloads[signature]} if not pretty: if file is None: file = sys.stdout for ver, res in overloads.items(): print("%s %s" % (self.py_func.__name__, ver), file=file) print('-' * 80, file=file) print(res.type_annotation, file=file) print('=' * 80, file=file) else: if file is not None: raise ValueError("`file` must be None if `pretty=True`") from numba.core.annotations.pretty_annotate import Annotate return Annotate(self, signature=signature, style=style) def inspect_cfg(self, signature=None, show_wrapper=None, **kwargs): """ For inspecting the CFG of the function. By default the CFG of the user function is shown. The *show_wrapper* option can be set to "python" or "cfunc" to show the python wrapper function or the *cfunc* wrapper function, respectively. Parameters accepted in kwargs ----------------------------- filename : string, optional the name of the output file, if given this will write the output to filename view : bool, optional whether to immediately view the optional output file highlight : bool, set, dict, optional what, if anything, to highlight, options are: { incref : bool, # highlight NRT_incref calls decref : bool, # highlight NRT_decref calls returns : bool, # highlight exits which are normal returns raises : bool, # highlight exits which are from raise meminfo : bool, # highlight calls to NRT*meminfo branches : bool, # highlight true/false branches } Default is True which sets all of the above to True. Supplying a set of strings is also accepted, these are interpreted as key:True with respect to the above dictionary. e.g. {'incref', 'decref'} would switch on highlighting on increfs and decrefs. interleave: bool, set, dict, optional what, if anything, to interleave in the LLVM IR, options are: { python: bool # interleave python source code with the LLVM IR lineinfo: bool # interleave line information markers with the LLVM # IR } Default is True which sets all of the above to True. Supplying a set of strings is also accepted, these are interpreted as key:True with respect to the above dictionary. e.g. {'python',} would switch on interleaving of python source code in the LLVM IR. strip_ir : bool, optional Default is False. If set to True all LLVM IR that is superfluous to that requested in kwarg `highlight` will be removed. show_key : bool, optional Default is True. Create a "key" for the highlighting in the rendered CFG. fontsize : int, optional Default is 8. Set the fontsize in the output to this value. """ if signature is not None: cres = self.overloads[signature] lib = cres.library if show_wrapper == 'python': fname = cres.fndesc.llvm_cpython_wrapper_name elif show_wrapper == 'cfunc': fname = cres.fndesc.llvm_cfunc_wrapper_name else: fname = cres.fndesc.mangled_name return lib.get_function_cfg(fname, py_func=self.py_func, **kwargs) return dict((sig, self.inspect_cfg(sig, show_wrapper=show_wrapper)) for sig in self.signatures) def inspect_disasm_cfg(self, signature=None): """ For inspecting the CFG of the disassembly of the function. Requires python package: r2pipe Requires radare2 binary on $PATH. Notebook rendering requires python package: graphviz signature : tuple of Numba types, optional Print/return the disassembly CFG for only the given signatures. If None, the IR is printed for all available signatures. """ if signature is not None: cres = self.overloads[signature] lib = cres.library return lib.get_disasm_cfg(cres.fndesc.mangled_name) return dict((sig, self.inspect_disasm_cfg(sig)) for sig in self.signatures) def get_annotation_info(self, signature=None): """ Gets the annotation information for the function specified by signature. If no signature is supplied a dictionary of signature to annotation information is returned. """ signatures = self.signatures if signature is None else [signature] out = collections.OrderedDict() for sig in signatures: cres = self.overloads[sig] ta = cres.type_annotation key = (ta.func_id.filename + ':' + str(ta.func_id.firstlineno + 1), ta.signature) out[key] = ta.annotate_raw()[key] return out def _explain_ambiguous(self, *args, **kws): """ Callback for the C _Dispatcher object. """ assert not kws, "kwargs not handled" args = tuple([self.typeof_pyval(a) for a in args]) # The order here must be deterministic for testing purposes, which # is ensured by the OrderedDict. sigs = self.nopython_signatures # This will raise self.typingctx.resolve_overload(self.py_func, sigs, args, kws, allow_ambiguous=False) def _explain_matching_error(self, *args, **kws): """ Callback for the C _Dispatcher object. """ assert not kws, "kwargs not handled" args = [self.typeof_pyval(a) for a in args] msg = ("No matching definition for argument type(s) %s" % ', '.join(map(str, args))) raise TypeError(msg) def _search_new_conversions(self, *args, **kws): """ Callback for the C _Dispatcher object. Search for approximately matching signatures for the given arguments, and ensure the corresponding conversions are registered in the C++ type manager. """ assert not kws, "kwargs not handled" args = [self.typeof_pyval(a) for a in args] found = False for sig in self.nopython_signatures: conv = self.typingctx.install_possible_conversions(args, sig.args) if conv: found = True return found def __repr__(self): return "%s(%s)" % (type(self).__name__, self.py_func) def typeof_pyval(self, val): """ Resolve the Numba type of Python value *val*. This is called from numba._dispatcher as a fallback if the native code cannot decide the type. """ # Not going through the resolve_argument_type() indirection # can save a couple µs. try: tp = typeof(val, Purpose.argument) except ValueError: tp = types.pyobject else: if tp is None: tp = types.pyobject self._types_active_call.append(tp) return tp def _callback_add_timer(self, duration, cres, lock_name): md = cres.metadata # md can be None when code is loaded from cache if md is not None: timers = md.setdefault("timers", {}) if lock_name not in timers: # Only write if the metadata does not exist timers[lock_name] = duration else: msg = f"'{lock_name} metadata is already defined." raise AssertionError(msg) def _callback_add_compiler_timer(self, duration, cres): return self._callback_add_timer(duration, cres, lock_name="compiler_lock") def _callback_add_llvm_timer(self, duration, cres): return self._callback_add_timer(duration, cres, lock_name="llvm_lock") class _MemoMixin: __uuid = None # A {uuid -> instance} mapping, for deserialization _memo = weakref.WeakValueDictionary() # hold refs to last N functions deserialized, retaining them in _memo # regardless of whether there is another reference _recent = collections.deque(maxlen=config.FUNCTION_CACHE_SIZE) @property def _uuid(self): """ An instance-specific UUID, to avoid multiple deserializations of a given instance. Note: this is lazily-generated, for performance reasons. """ u = self.__uuid if u is None: u = str(uuid.uuid1()) self._set_uuid(u) return u def _set_uuid(self, u): assert self.__uuid is None self.__uuid = u self._memo[u] = self self._recent.append(self) class Dispatcher(serialize.ReduceMixin, _MemoMixin, _DispatcherBase): """ Implementation of user-facing dispatcher objects (i.e. created using the @jit decorator). This is an abstract base class. Subclasses should define the targetdescr class attribute. """ _fold_args = True _impl_kinds = { 'direct': _FunctionCompiler, 'generated': _GeneratedFunctionCompiler, } __numba__ = 'py_func' def __init__(self, py_func, locals={}, targetoptions={}, impl_kind='direct', pipeline_class=compiler.Compiler): """ Parameters ---------- py_func: function object to be compiled locals: dict, optional Mapping of local variable names to Numba types. Used to override the types deduced by the type inference engine. targetoptions: dict, optional Target-specific config options. impl_kind: str Select the compiler mode for `@jit` and `@generated_jit` pipeline_class: type numba.compiler.CompilerBase The compiler pipeline type. """ self.typingctx = self.targetdescr.typing_context self.targetctx = self.targetdescr.target_context pysig = utils.pysignature(py_func) arg_count = len(pysig.parameters) can_fallback = not targetoptions.get('nopython', False) _DispatcherBase.__init__(self, arg_count, py_func, pysig, can_fallback, exact_match_required=False) functools.update_wrapper(self, py_func) self.targetoptions = targetoptions self.locals = locals self._cache = NullCache() compiler_class = self._impl_kinds[impl_kind] self._impl_kind = impl_kind self._compiler = compiler_class(py_func, self.targetdescr, targetoptions, locals, pipeline_class) self._cache_hits = collections.Counter() self._cache_misses = collections.Counter() self._type = types.Dispatcher(self) self.typingctx.insert_global(self, self._type) # Remember target restriction self._required_target_backend = targetoptions.get('target_backend') def dump(self, tab=''): print(f'{tab}DUMP {type(self).__name__}[{self.py_func.__name__}' f', type code={self._type._code}]') for cres in self.overloads.values(): cres.dump(tab=tab + ' ') print(f'{tab}END DUMP {type(self).__name__}[{self.py_func.__name__}]') @property def _numba_type_(self): return types.Dispatcher(self) def enable_caching(self): self._cache = FunctionCache(self.py_func) def __get__(self, obj, objtype=None): '''Allow a JIT function to be bound as a method to an object''' if obj is None: # Unbound method return self else: # Bound method return pytypes.MethodType(self, obj) def _reduce_states(self): """ Reduce the instance for pickling. This will serialize the original function as well the compilation options and compiled signatures, but not the compiled code itself. NOTE: part of ReduceMixin protocol """ if self._can_compile: sigs = [] else: sigs = [cr.signature for cr in self.overloads.values()] return dict( uuid=str(self._uuid), py_func=self.py_func, locals=self.locals, targetoptions=self.targetoptions, impl_kind=self._impl_kind, can_compile=self._can_compile, sigs=sigs, ) @classmethod def _rebuild(cls, uuid, py_func, locals, targetoptions, impl_kind, can_compile, sigs): """ Rebuild an Dispatcher instance after it was __reduce__'d. NOTE: part of ReduceMixin protocol """ try: return cls._memo[uuid] except KeyError: pass self = cls(py_func, locals, targetoptions, impl_kind) # Make sure this deserialization will be merged with subsequent ones self._set_uuid(uuid) for sig in sigs: self.compile(sig) self._can_compile = can_compile return self def compile(self, sig): disp = self._get_dispatcher_for_current_target() if disp is not self: return disp.compile(sig) with ExitStack() as scope: cres = None def cb_compiler(dur): if cres is not None: self._callback_add_compiler_timer(dur, cres) def cb_llvm(dur): if cres is not None: self._callback_add_llvm_timer(dur, cres) scope.enter_context(ev.install_timer("numba:compiler_lock", cb_compiler)) scope.enter_context(ev.install_timer("numba:llvm_lock", cb_llvm)) scope.enter_context(global_compiler_lock) if not self._can_compile: raise RuntimeError("compilation disabled") # Use counter to track recursion compilation depth with self._compiling_counter: args, return_type = sigutils.normalize_signature(sig) # Don't recompile if signature already exists existing = self.overloads.get(tuple(args)) if existing is not None: return existing.entry_point # Try to load from disk cache cres = self._cache.load_overload(sig, self.targetctx) if cres is not None: self._cache_hits[sig] += 1 # XXX fold this in add_overload()? (also see compiler.py) if not cres.objectmode: self.targetctx.insert_user_function(cres.entry_point, cres.fndesc, [cres.library]) self.add_overload(cres) return cres.entry_point self._cache_misses[sig] += 1 ev_details = dict( dispatcher=self, args=args, return_type=return_type, ) with ev.trigger_event("numba:compile", data=ev_details): try: cres = self._compiler.compile(args, return_type) except errors.ForceLiteralArg as e: def folded(args, kws): return self._compiler.fold_argument_types(args, kws)[1] raise e.bind_fold_arguments(folded) self.add_overload(cres) self._cache.save_overload(sig, cres) return cres.entry_point def get_compile_result(self, sig): """Compile (if needed) and return the compilation result with the given signature. """ atypes = tuple(sig.args) if atypes not in self.overloads: self.compile(atypes) return self.overloads[atypes] def recompile(self): """ Recompile all signatures afresh. """ sigs = list(self.overloads) old_can_compile = self._can_compile # Ensure the old overloads are disposed of, # including compiled functions. self._make_finalizer()() self._reset_overloads() self._cache.flush() self._can_compile = True try: for sig in sigs: self.compile(sig) finally: self._can_compile = old_can_compile @property def stats(self): return _CompileStats( cache_path=self._cache.cache_path, cache_hits=self._cache_hits, cache_misses=self._cache_misses, ) def parallel_diagnostics(self, signature=None, level=1): """ Print parallel diagnostic information for the given signature. If no signature is present it is printed for all known signatures. level is used to adjust the verbosity, level=1 (default) is minimal verbosity, and 2, 3, and 4 provide increasing levels of verbosity. """ def dump(sig): ol = self.overloads[sig] pfdiag = ol.metadata.get('parfor_diagnostics', None) if pfdiag is None: msg = "No parfors diagnostic available, is 'parallel=True' set?" raise ValueError(msg) pfdiag.dump(level) if signature is not None: dump(signature) else: [dump(sig) for sig in self.signatures] def get_metadata(self, signature=None): """ Obtain the compilation metadata for a given signature. """ if signature is not None: return self.overloads[signature].metadata else: return dict( (sig,self.overloads[sig].metadata) for sig in self.signatures ) def get_function_type(self): """Return unique function type of dispatcher when possible, otherwise return None. A Dispatcher instance has unique function type when it contains exactly one compilation result and its compilation has been disabled (via its disable_compile method). """ if not self._can_compile and len(self.overloads) == 1: cres = tuple(self.overloads.values())[0] return types.FunctionType(cres.signature) def _get_retarget_dispatcher(self): """Returns a dispatcher for the retarget request. """ # Check TLS target configuration tc = TargetConfigurationStack() retarget = tc.get() retarget.check_compatible(self) disp = retarget.retarget(self) return disp def _get_dispatcher_for_current_target(self): """Returns a dispatcher for the current target registered in `TargetConfigurationStack`. `self` is returned if no target is specified. """ tc = TargetConfigurationStack() if tc: return self._get_retarget_dispatcher() else: return self def _call_tls_target(self, *args, **kwargs): """This is called when the C dispatcher logic sees a retarget request. """ disp = self._get_retarget_dispatcher() # Call the new dispatcher return disp(*args, **kwargs) class LiftedCode(serialize.ReduceMixin, _MemoMixin, _DispatcherBase): """ Implementation of the hidden dispatcher objects used for lifted code (a lifted loop is really compiled as a separate function). """ _fold_args = False can_cache = False def __init__(self, func_ir, typingctx, targetctx, flags, locals): self.func_ir = func_ir self.lifted_from = None self.typingctx = typingctx self.targetctx = targetctx self.flags = flags self.locals = locals _DispatcherBase.__init__(self, self.func_ir.arg_count, self.func_ir.func_id.func, self.func_ir.func_id.pysig, can_fallback=True, exact_match_required=False) def _reduce_states(self): """ Reduce the instance for pickling. This will serialize the original function as well the compilation options and compiled signatures, but not the compiled code itself. NOTE: part of ReduceMixin protocol """ return dict( uuid=self._uuid, func_ir=self.func_ir, flags=self.flags, locals=self.locals, extras=self._reduce_extras(), ) def _reduce_extras(self): """ NOTE: sub-class can override to add extra states """ return {} @classmethod def _rebuild(cls, uuid, func_ir, flags, locals, extras): """ Rebuild an Dispatcher instance after it was __reduce__'d. NOTE: part of ReduceMixin protocol """ try: return cls._memo[uuid] except KeyError: pass # NOTE: We are assuming that this is must be cpu_target, which is true # for now. # TODO: refactor this to not assume on `cpu_target` from numba.core import registry typingctx = registry.cpu_target.typing_context targetctx = registry.cpu_target.target_context self = cls(func_ir, typingctx, targetctx, flags, locals, **extras) self._set_uuid(uuid) return self def get_source_location(self): """Return the starting line number of the loop. """ return self.func_ir.loc.line def _pre_compile(self, args, return_type, flags): """Pre-compile actions """ pass def compile(self, sig): with ExitStack() as scope: cres = None def cb_compiler(dur): if cres is not None: self._callback_add_compiler_timer(dur, cres) def cb_llvm(dur): if cres is not None: self._callback_add_llvm_timer(dur, cres) scope.enter_context(ev.install_timer("numba:compiler_lock", cb_compiler)) scope.enter_context(ev.install_timer("numba:llvm_lock", cb_llvm)) scope.enter_context(global_compiler_lock) # Use counter to track recursion compilation depth with self._compiling_counter: # XXX this is mostly duplicated from Dispatcher. flags = self.flags args, return_type = sigutils.normalize_signature(sig) # Don't recompile if signature already exists # (e.g. if another thread compiled it before we got the lock) existing = self.overloads.get(tuple(args)) if existing is not None: return existing.entry_point self._pre_compile(args, return_type, flags) # Clone IR to avoid (some of the) mutation in the rewrite pass cloned_func_ir = self.func_ir.copy() ev_details = dict( dispatcher=self, args=args, return_type=return_type, ) with ev.trigger_event("numba:compile", data=ev_details): cres = compiler.compile_ir(typingctx=self.typingctx, targetctx=self.targetctx, func_ir=cloned_func_ir, args=args, return_type=return_type, flags=flags, locals=self.locals, lifted=(), lifted_from=self.lifted_from, is_lifted_loop=True,) # Check typing error if object mode is used if (cres.typing_error is not None and not flags.enable_pyobject): raise cres.typing_error self.add_overload(cres) return cres.entry_point def _get_dispatcher_for_current_target(self): # Lifted code does not honor the target switch currently. # No work has been done to check if this can be allowed. return self class LiftedLoop(LiftedCode): def _pre_compile(self, args, return_type, flags): assert not flags.enable_looplift, "Enable looplift flags is on" class LiftedWith(LiftedCode): can_cache = True def _reduce_extras(self): return dict(output_types=self.output_types) @property def _numba_type_(self): return types.Dispatcher(self) def get_call_template(self, args, kws): """ Get a typing.ConcreteTemplate for this dispatcher and the given *args* and *kws* types. This enables the resolving of the return type. A (template, pysig, args, kws) tuple is returned. """ # Ensure an overload is available if self._can_compile: self.compile(tuple(args)) pysig = None # Create function type for typing func_name = self.py_func.__name__ name = "CallTemplate({0})".format(func_name) # The `key` isn't really used except for diagnosis here, # so avoid keeping a reference to `cfunc`. call_template = typing.make_concrete_template( name, key=func_name, signatures=self.nopython_signatures) return call_template, pysig, args, kws class ObjModeLiftedWith(LiftedWith): def __init__(self, *args, **kwargs): self.output_types = kwargs.pop('output_types', None) super(LiftedWith, self).__init__(*args, **kwargs) if not self.flags.force_pyobject: raise ValueError("expecting `flags.force_pyobject`") if self.output_types is None: raise TypeError('`output_types` must be provided') # switch off rewrites, they have no effect self.flags.no_rewrites = True @property def _numba_type_(self): return types.ObjModeDispatcher(self) def get_call_template(self, args, kws): """ Get a typing.ConcreteTemplate for this dispatcher and the given *args* and *kws* types. This enables the resolving of the return type. A (template, pysig, args, kws) tuple is returned. """ assert not kws self._legalize_arg_types(args) # Coerce to object mode args = [types.ffi_forced_object] * len(args) if self._can_compile: self.compile(tuple(args)) signatures = [typing.signature(self.output_types, *args)] pysig = None func_name = self.py_func.__name__ name = "CallTemplate({0})".format(func_name) call_template = typing.make_concrete_template( name, key=func_name, signatures=signatures) return call_template, pysig, args, kws def _legalize_arg_types(self, args): for i, a in enumerate(args, start=1): if isinstance(a, types.List): msg = ( 'Does not support list type inputs into ' 'with-context for arg {}' ) raise errors.TypingError(msg.format(i)) elif isinstance(a, types.Dispatcher): msg = ( 'Does not support function type inputs into ' 'with-context for arg {}' ) raise errors.TypingError(msg.format(i)) @global_compiler_lock def compile(self, sig): args, _ = sigutils.normalize_signature(sig) sig = (types.ffi_forced_object,) * len(args) return super().compile(sig) # Initialize typeof machinery _dispatcher.typeof_init( OmittedArg, dict((str(t), t._code) for t in types.number_domain))