"""
Utilities for manipulating the Abstract Syntax Tree of Python constructs
"""
import re
import copy
import inspect
import ast
import textwrap


class NameVisitor(ast.NodeVisitor):
    """
    NodeVisitor that builds a set of all of the named identifiers in an AST
    """

    # Identifiers of the form _N (N a non-negative integer); group 1 is N
    _generated_name_re = re.compile(r"^_(\d+)$")

    def __init__(self, *args, **kwargs):
        super(NameVisitor, self).__init__(*args, **kwargs)
        self.names = set()

    def visit_Name(self, node):
        self.names.add(node.id)

    def visit_arg(self, node):
        # Python 3 argument nodes carry the identifier in ``arg``; fall back
        # to ``id`` for nodes that expose the name that way instead.
        if hasattr(node, 'arg'):
            self.names.add(node.arg)
        elif hasattr(node, 'id'):
            self.names.add(node.id)

    def get_new_names(self, num_names):
        """
        Returns a list of new names that are not already present in the AST.

        New names will have the form _N, for N a non-negative integer. If the
        AST has no existing identifiers of this form, then the returned names
        will start at 0 ('_0', '_1', '_2'). If the AST already has identifiers
        of this form, then the names returned will not include the existing
        identifiers.

        Parameters
        ----------
        num_names: int
            The number of new names to return

        Returns
        -------
        list of str
        """
        matches = (self._generated_name_re.match(n) for n in self.names)
        used_numbers = [int(m.group(1)) for m in matches if m]

        # Start just past the largest number already in use so that none of
        # the generated names can collide with an existing identifier.
        start_number = max(used_numbers) + 1 if used_numbers else 0
        return ["_" + str(n)
                for n in range(start_number, start_number + num_names)]


class ExpandVarargTransformer(ast.NodeTransformer):
    """
    Node transformer that replaces the starred use of a variable in an AST
    with a collection of unstarred named variables.
    """
    def __init__(self, starred_name, expand_names, *args, **kwargs):
        """
        Parameters
        ----------
        starred_name: str
            The name of the starred variable to replace
        expand_names: list of str
            List of the new names that should be used to replace the starred
            variable
        """
        super(ExpandVarargTransformer, self).__init__(*args, **kwargs)
        self.starred_name = starred_name
        self.expand_names = expand_names


class ExpandVarargTransformerStarred(ExpandVarargTransformer):
    # Python 3
    def visit_Starred(self, node):
        value = node.value
        if isinstance(value, ast.Name) and value.id == self.starred_name:
            # Returning a list splices the new names into the parent's list
            # field in place of the single starred node.
            return [ast.Name(id=name, ctx=node.ctx)
                    for name in self.expand_names]

        # Some other starred expression (e.g. ``*obj.attr`` or ``*f(*args)``).
        # Keep it, but still expand uses of the starred variable nested
        # inside of it.
        return self.generic_visit(node)


class ExpandVarargTransformerCallArg(ExpandVarargTransformer):
    # Python 2
    def visit_Call(self, node):
        # Transform nested calls (e.g. calls used as arguments) first
        self.generic_visit(node)

        starargs = getattr(node, 'starargs', None)
        if isinstance(starargs, ast.Name) and starargs.id == self.starred_name:
            node.starargs = None
            node.args.extend([ast.Name(id=name, ctx=ast.Load())
                              for name in self.expand_names])
        return node


def function_to_ast(fn):
    """
    Get the AST representation of a function

    Parameters
    ----------
    fn: function
        Function whose source code can be retrieved with inspect.getsource

    Returns
    -------
    ast.Module
        Module AST produced by parsing the function's source code. The
        function definition is the first element of the module body.
    """
    # Get source code for function
    # Dedent is needed if this is a nested function
    fn_source = textwrap.dedent(inspect.getsource(fn))

    # Parse function source code into an AST
    return ast.parse(fn_source)


def ast_to_source(ast):
    """Convert AST to source code string using the astor package"""
    # NOTE: the parameter name shadows the ``ast`` module inside this
    # function; it is kept for backward compatibility with keyword callers.
    import astor
    return astor.to_source(ast)


def compile_function_ast(fn_ast):
    """
    Compile function AST into a code object suitable for use in eval/exec

    Parameters
    ----------
    fn_ast: ast.Module
        Module AST whose first body element is a function definition

    Returns
    -------
    code
        Code object compiled in 'exec' mode
    """
    assert isinstance(fn_ast, ast.Module)
    fndef_ast = fn_ast.body[0]
    assert isinstance(fndef_ast, ast.FunctionDef)
    return compile(fn_ast, "<%s>" % fndef_ast.name, mode='exec')


def function_ast_to_function(fn_ast, stacklevel=1):
    """
    Evaluate a function AST and return the resulting function object

    Parameters
    ----------
    fn_ast: ast.Module
        Module AST whose first body element is a function definition
    stacklevel: int
        Number of frames above this function whose globals and locals are
        used as the scope in which the function is evaluated. The default
        of 1 is the caller of this function.

    Returns
    -------
    function
    """
    # Validate
    assert isinstance(fn_ast, ast.Module)
    fndef_ast = fn_ast.body[0]
    assert isinstance(fndef_ast, ast.FunctionDef)

    # Compile AST to code object
    code = compile_function_ast(fn_ast)

    # Evaluate the function in a scope that includes the globals and
    # locals of desired frame.
    eval_frame = inspect.currentframe()
    try:
        for _ in range(stacklevel):
            eval_frame = eval_frame.f_back

        # Work on a copy so the frame's real globals are never modified
        scope = copy.copy(eval_frame.f_globals)
        scope.update(eval_frame.f_locals)
    finally:
        # Drop the frame reference promptly to avoid reference cycles
        del eval_frame

    # Evaluate function in scope. The code object was compiled in 'exec'
    # mode, so this defines the function inside ``scope``.
    eval(code, scope)

    # Return the newly evaluated function from the scope
    return scope[fndef_ast.name]


def _build_arg(name):
    """Build the AST node for a function-signature argument called name"""
    try:
        # Python 3
        return ast.arg(arg=name)
    except AttributeError:
        # Python 2
        return ast.Name(id=name, ctx=ast.Param())


def expand_function_ast_varargs(fn_ast, expand_number):
    """
    Given a function AST that use a variable length positional argument
    (e.g. *args), return a function that replaces the use of this argument
    with one or more fixed arguments.

    To be supported, a function must have a starred argument in the function
    signature, and it may only use this argument in starred form as the
    input to other functions.

    For example, suppose expand_number is 3 and fn_ast is an AST
    representing this function...

    def my_fn1(a, b, *args):
        print(a, b)
        other_fn(a, b, *args)

    Then this function will return the AST of a function equivalent to...

    def my_fn1(a, b, _0, _1, _2):
        print(a, b)
        other_fn(a, b, _0, _1, _2)

    If the input function uses `args` for anything other than passing it to
    other functions in starred form, an error will be raised.

    Parameters
    ----------
    fn_ast: ast.Module
        Module AST whose first body element is the function definition
    expand_number: int
        Number of fixed arguments that replace the starred argument

    Returns
    -------
    ast.Module
        Transformed copy of the input; the input AST is not modified

    Raises
    ------
    ValueError
        If the function has no variable length positional argument, if that
        argument is used in a non-starred context, or if arguments would be
        added after positional arguments that have default values.
    """
    assert isinstance(fn_ast, ast.Module)

    # Copy ast so we don't modify the input
    fn_ast = copy.deepcopy(fn_ast)

    # Extract function definition
    fndef_ast = fn_ast.body[0]
    assert isinstance(fndef_ast, ast.FunctionDef)

    # Get function args
    fn_args = fndef_ast.args

    # Function variable arity argument
    fn_vararg = fn_args.vararg

    # Require vararg
    if not fn_vararg:
        raise ValueError("""\
Input function AST does not have a variable length positional argument
(e.g. *args) in the function signature""")

    # Get vararg name
    if isinstance(fn_vararg, str):
        # Python 2
        vararg_name = fn_vararg
    else:
        # Python 3
        vararg_name = fn_vararg.arg

    # Compute new unique names to use in place of the variable argument
    before_name_visitor = NameVisitor()
    before_name_visitor.visit(fn_ast)
    expand_names = before_name_visitor.get_new_names(expand_number)

    # Defaults are aligned with the *last* positional arguments, so appending
    # new arguments after defaulted ones would silently move those defaults
    # onto the generated arguments.
    if expand_names and fn_args.defaults:
        raise ValueError(
            "Cannot expand the variable length positional argument {n}: "
            "positional arguments with default values are not "
            "supported".format(n=vararg_name))

    # Replace use of *args in function body
    if hasattr(ast, "Starred"):
        # Python 3
        expand_transformer = ExpandVarargTransformerStarred
    else:
        # Python 2
        expand_transformer = ExpandVarargTransformerCallArg

    new_fn_ast = expand_transformer(
        vararg_name, expand_names
    ).visit(fn_ast)

    new_fndef_ast = new_fn_ast.body[0]

    # Replace vararg with additional args in function signature
    new_fndef_ast.args.args.extend(
        [_build_arg(name=name) for name in expand_names]
    )
    new_fndef_ast.args.vararg = None

    # Run a new NameVisitor and see if there were any other non-starred uses
    # of the variable length argument. If so, raise an exception
    after_name_visitor = NameVisitor()
    after_name_visitor.visit(new_fn_ast)
    if vararg_name in after_name_visitor.names:
        raise ValueError("""\
The variable length positional argument {n} is used in an unsupported context
""".format(n=vararg_name))

    # Remove decorators if present to avoid recursion
    new_fndef_ast.decorator_list = []

    # Add missing source code locations
    ast.fix_missing_locations(new_fn_ast)

    # Return result
    return new_fn_ast


def expand_varargs(expand_number):
    """
    Decorator to expand the variable length (starred) argument in a function
    signature with a fixed number of arguments.

    Parameters
    ----------
    expand_number: int
        The number of fixed arguments that should replace the variable length
        argument

    Returns
    -------
    function
        Decorator Function
    """
    if not isinstance(expand_number, int) or expand_number < 0:
        raise ValueError("expand_number must be a non-negative integer")

    def _expand_varargs(fn):
        fn_ast = function_to_ast(fn)
        fn_expanded_ast = expand_function_ast_varargs(fn_ast, expand_number)
        # stacklevel=2 evaluates the new function in the scope of the code
        # that applied the decorator (one frame above this inner function).
        return function_ast_to_function(fn_expanded_ast, stacklevel=2)
    return _expand_varargs