import timeit
from abc import abstractmethod, ABCMeta
from collections import namedtuple, OrderedDict
import inspect
from numba.core.compiler_lock import global_compiler_lock
from numba.core import errors, config, transforms, utils
from numba.core.tracing import event
from numba.core.postproc import PostProcessor
from numba.core.ir_utils import enforce_no_dels, legalize_single_scope

# terminal color markup
_termcolor = errors.termcolor()


class SimpleTimer(object):
    """
    A simple context managed timer.

    On exit, ``elapsed`` holds the wall-clock seconds (as measured by
    ``timeit.default_timer``) spent inside the ``with`` block.
    """

    def __enter__(self):
        self.ts = timeit.default_timer()
        return self

    def __exit__(self, *exc):
        self.elapsed = timeit.default_timer() - self.ts


class CompilerPass(metaclass=ABCMeta):
    """ The base class for all compiler passes.
    """

    @abstractmethod
    def __init__(self, *args, **kwargs):
        self._analysis = None
        self._pass_id = None

    @classmethod
    def name(cls):
        """
        Returns the name of the pass (the class attribute ``_name``).
        """
        return cls._name

    @property
    def pass_id(self):
        """
        The ID of the pass
        """
        return self._pass_id

    @pass_id.setter
    def pass_id(self, val):
        """
        Sets the ID of the pass
        """
        self._pass_id = val

    @property
    def analysis(self):
        """
        Analysis data for the pass
        """
        return self._analysis

    @analysis.setter
    def analysis(self, val):
        """
        Set the analysis data for the pass
        """
        self._analysis = val

    def run_initialization(self, *args, **kwargs):
        """
        Runs the initialization sequence for the pass, will run before
        `run_pass`.
        """
        return False

    @abstractmethod
    def run_pass(self, *args, **kwargs):
        """
        Runs the pass itself. Must return True/False depending on whether
        statement level modification took place.
        """
        pass

    def run_finalizer(self, *args, **kwargs):
        """
        Runs the finalization sequence for the pass, will run after
        `run_pass`.
        """
        return False

    def get_analysis_usage(self, AU):
        """ Override to set analysis usage
        """
        pass

    def get_analysis(self, pass_name):
        """
        Gets the analysis from a given pass
        """
        return self._analysis[pass_name]


class SSACompliantMixin(object):
    """ Mixin to indicate a pass is SSA form compliant. Nothing is asserted
    about this condition at present.
    """
    pass


class FunctionPass(CompilerPass):
    """ Base class for function passes
    """
    pass


class AnalysisPass(CompilerPass):
    """ Base class for analysis passes (no modification made to state)
    """
    pass


class LoweringPass(CompilerPass):
    """ Base class for lowering passes
    """
    pass


class AnalysisUsage(object):
    """This looks and behaves like LLVM's AnalysisUsage because it's like that.
    """

    def __init__(self):
        self._required = set()
        self._preserved = set()

    def get_required_set(self):
        return self._required

    def get_preserved_set(self):
        return self._preserved

    def add_required(self, pss):
        self._required.add(pss)

    def add_preserved(self, pss):
        self._preserved.add(pss)

    def __str__(self):
        return "required: %s\n" % self._required


_DEBUG = False


def debug_print(*args, **kwargs):
    if _DEBUG:
        print(*args, **kwargs)


pass_timings = namedtuple('pass_timings', 'init run finalize')


class PassManager(object):
    """
    The PassManager is a named instance of a particular compilation pipeline
    """
    # TODO: Eventually enable this, it enforces self consistency after each pass
    _ENFORCING = False

    def __init__(self, pipeline_name):
        """
        Create a new pipeline with name "pipeline_name"
        """
        self.passes = []
        self.exec_times = OrderedDict()
        self._finalized = False
        self._analysis = None
        # The three IR-dump lists are (re)computed by finalize(); initialise
        # them all here so the instance always has a consistent attribute set.
        self._print_after = None
        self._print_before = None
        self._print_wrap = None
        self.pipeline_name = pipeline_name

    def _validate_pass(self, pass_cls):
        """
        Check that ``pass_cls`` is a registered CompilerPass subclass or the
        registered name of one, and return the resolved pass class.

        Raises TypeError if ``pass_cls`` is neither a str nor a CompilerPass
        subclass, and ValueError if the pass is not registered.
        """
        if (not (isinstance(pass_cls, str) or
                 (inspect.isclass(pass_cls) and
                  issubclass(pass_cls, CompilerPass)))):
            msg = ("Pass must be referenced by name or be a subclass of a "
                   "CompilerPass. Have %s" % pass_cls)
            raise TypeError(msg)
        if isinstance(pass_cls, str):
            # Resolve the name to the registered class so the rest of the
            # machinery (finalize, run, dependency_analysis) only sees classes.
            info = _pass_registry.find_by_name(pass_cls)
            pass_cls = type(info.pass_inst)
        else:
            if not _pass_registry.is_registered(pass_cls):
                raise ValueError("Pass %s is not registered" % pass_cls)
        return pass_cls

    def add_pass(self, pss, description=""):
        """
        Append a pass to the PassManager's compilation pipeline.

        ``pss`` may be a registered pass class or its registered name.
        """
        pss = self._validate_pass(pss)
        func_desc_tuple = (pss, description)
        self.passes.append(func_desc_tuple)
        self._finalized = False

    def add_pass_after(self, pass_cls, location):
        """
        Add a pass `pass_cls` to the PassManager's compilation pipeline after
        the pass `location`. Either may be a registered pass class or its
        registered name.
        """
        assert self.passes
        pass_cls = self._validate_pass(pass_cls)
        location = self._validate_pass(location)
        for idx, (x, _) in enumerate(self.passes):
            if x == location:
                break
        else:
            raise ValueError("Could not find pass %s" % location)
        self.passes.insert(idx + 1, (pass_cls, str(pass_cls)))
        # if a pass has been added, it's not finalized
        self._finalized = False

    def _debug_init(self):
        """
        Return the (after, before, wrap) lists of pass names for which the IR
        should be dumped, parsed from the DEBUG_PRINT_* config settings.
        """
        # determine after which passes IR dumps should take place
        def parse(conf_item):
            print_passes = []
            if conf_item != "none":
                if conf_item == "all":
                    print_passes = [x.name() for (x, _) in self.passes]
                else:
                    # we don't validate whether the named passes exist in this
                    # pipeline the compiler may be used reentrantly and
                    # different pipelines may contain different passes
                    splitted = conf_item.split(',')
                    print_passes = [x.strip() for x in splitted]
            return print_passes
        ret = (parse(config.DEBUG_PRINT_AFTER),
               parse(config.DEBUG_PRINT_BEFORE),
               parse(config.DEBUG_PRINT_WRAP),)
        return ret

    def finalize(self):
        """
        Finalize the PassManager, after which no more passes may be added
        without re-finalization.
        """
        self._analysis = self.dependency_analysis()
        self._print_after, self._print_before, self._print_wrap = \
            self._debug_init()
        self._finalized = True

    @property
    def finalized(self):
        return self._finalized

    def _patch_error(self, desc, exc):
        """
        Patches the error to show the stage that it arose in.
        """
        newmsg = "{desc}\n{exc}".format(desc=desc, exc=exc)
        exc.args = (newmsg,)
        return exc

    @global_compiler_lock  # this need a lock, likely calls LLVM
    def _runPass(self, index, pss, internal_state):
        """
        Run one pass instance ``pss`` (at pipeline position ``index``) on
        ``internal_state``, dumping IR if requested and recording the
        init/run/finalize timings in ``self.exec_times``.
        """
        mutated = False

        def check(func, compiler_state):
            mangled = func(compiler_state)
            if mangled not in (True, False):
                msg = ("CompilerPass implementations should return True/False. "
                       "CompilerPass with name '%s' did not.")
                raise ValueError(msg % pss.name())
            return mangled

        def dump_ir(pass_name, print_condition, printable_condition):
            # Local helper; named to avoid shadowing module-level debug_print.
            if pass_name in print_condition:
                fid = internal_state.func_id
                args = (fid.modname, fid.func_qualname, self.pipeline_name,
                        printable_condition, pass_name)
                print(("%s.%s: %s: %s %s" % args).center(120, '-'))
                if internal_state.func_ir is not None:
                    internal_state.func_ir.dump()
                else:
                    print("func_ir is None")

        # debug print before this pass?
        dump_ir(pss.name(), self._print_before + self._print_wrap, "BEFORE")

        # wire in the analysis info so it's accessible
        pss.analysis = self._analysis

        with SimpleTimer() as init_time:
            mutated |= check(pss.run_initialization, internal_state)
        with SimpleTimer() as pass_time:
            mutated |= check(pss.run_pass, internal_state)
        with SimpleTimer() as finalize_time:
            mutated |= check(pss.run_finalizer, internal_state)

        # Check that if the pass is an instance of a FunctionPass that it hasn't
        # emitted ir.Dels.
        if isinstance(pss, FunctionPass):
            enforce_no_dels(internal_state.func_ir)

        if self._ENFORCING:
            # TODO: Add in self consistency enforcement for
            # `func_ir._definitions` etc
            if _pass_registry.get(pss.__class__).mutates_CFG:
                if mutated:  # block level changes, rebuild all
                    PostProcessor(internal_state.func_ir).run()
                else:  # CFG level changes rebuild CFG
                    internal_state.func_ir.blocks = transforms.canonicalize_cfg(
                        internal_state.func_ir.blocks)
            # Check the func_ir has exactly one Scope instance
            if not legalize_single_scope(internal_state.func_ir.blocks):
                raise errors.CompilerError(
                    f"multiple scope in func_ir detected in {pss}",
                )
        # inject runtimes
        pt = pass_timings(init_time.elapsed, pass_time.elapsed,
                          finalize_time.elapsed)
        self.exec_times["%s_%s" % (index, pss.name())] = pt

        # debug print after this pass?
        dump_ir(pss.name(), self._print_after + self._print_wrap, "AFTER")

    def run(self, state):
        """
        Run the defined pipelines on the state.

        Raises RuntimeError if the pipeline has not been finalized. Other
        exceptions raised by a pass have their message prefixed with the
        pipeline name and step description, except _EarlyPipelineCompletion,
        and (with new-style errors enabled) non-NumbaError exceptions, which
        are re-raised untouched.
        """
        from numba.core.compiler import _EarlyPipelineCompletion
        if not self.finalized:
            raise RuntimeError("Cannot run non-finalised pipeline")

        # walk the passes and run them
        for idx, (pss, pass_desc) in enumerate(self.passes):
            try:
                event("-- %s" % pass_desc)
                pass_inst = _pass_registry.get(pss).pass_inst
                if isinstance(pass_inst, CompilerPass):
                    self._runPass(idx, pass_inst, state)
                else:
                    raise BaseException("Legacy pass in use")
            except _EarlyPipelineCompletion:
                raise
            except Exception as e:
                if (utils.use_new_style_errors() and
                        not isinstance(e, errors.NumbaError)):
                    raise
                msg = "Failed in %s mode pipeline (step: %s)" % \
                    (self.pipeline_name, pass_desc)
                patched_exception = self._patch_error(msg, e)
                raise patched_exception

    def dependency_analysis(self):
        """
        Computes dependency analysis.

        Returns a dict mapping each pass class in the pipeline to the set of
        pass classes it requires, directly or transitively, as declared via
        ``get_analysis_usage``.
        """
        requires_map = dict()
        for (pss, _) in self.passes:
            pass_inst = _pass_registry.get(pss).pass_inst
            au = AnalysisUsage()
            pass_inst.get_analysis_usage(au)
            requires_map[type(pass_inst)] = au.get_required_set()

        def transitive_requires(direct):
            # Iterative DFS over the "requires" graph. Nothing is mutated while
            # being iterated, and `seen` stops cycles from recursing forever.
            # Requirements not present in the pipeline contribute no further
            # dependencies.
            seen = set()
            stack = list(direct)
            while stack:
                node = stack.pop()
                if node in seen:
                    continue
                seen.add(node)
                stack.extend(requires_map.get(node, ()))
            return seen

        return {k: transitive_requires(v) for k, v in requires_map.items()}


pass_info = namedtuple('pass_info', 'pass_inst mutates_CFG analysis_only')


class PassRegistry(object):
    """
    Pass registry singleton class.
    """

    _id = 0
    _registry = dict()

    def register(self, mutates_CFG, analysis_only):
        """
        Return a class decorator that registers a pass class, storing one
        instance of it together with the ``mutates_CFG`` and ``analysis_only``
        flags.
        """
        def make_festive(pass_class):
            assert not self.is_registered(pass_class)
            assert not self._does_pass_name_alias(pass_class.name())
            pass_class.pass_id = self._id
            self._id += 1
            self._registry[pass_class] = pass_info(pass_class(), mutates_CFG,
                                                   analysis_only)
            return pass_class
        return make_festive

    def is_registered(self, clazz):
        return clazz in self._registry

    def get(self, clazz):
        assert self.is_registered(clazz)
        return self._registry[clazz]

    def _does_pass_name_alias(self, check):
        # NOTE(review): `v.pass_inst.name` is the bound classmethod, not the
        # name string, so this comparison is never true and the alias assert
        # in `register` never fires. Calling `name()` here would make
        # duplicate pass names fail registration at import time; left as-is
        # pending confirmation that no registered passes share a name.
        for k, v in self._registry.items():
            if v.pass_inst.name == check:
                return True
        return False

    def find_by_name(self, class_name):
        """
        Return the ``pass_info`` of the registered pass whose ``name()`` is
        ``class_name``. Raises ValueError if there is none.
        """
        assert isinstance(class_name, str)
        for v in self._registry.values():
            # `name` is a classmethod; it must be called to get the string.
            if v.pass_inst.name() == class_name:
                return v
        raise ValueError("No pass with name %s is registered" % class_name)

    def dump(self):
        for k, v in self._registry.items():
            print("%s: %s" % (k, v))


_pass_registry = PassRegistry()
del PassRegistry


# register_pass is used to register a compiler pass class for use with
# PassManager instances.
register_pass = _pass_registry.register