import collections import numbers import re import warnings from rope.base import ast, codeanalyze, exceptions from rope.base.utils import pycompat try: basestring except NameError: basestring = (str, bytes) COMMA_IN_WITH_PATTERN = re.compile(r"\(.*?\)|(,)") def get_patched_ast(source, sorted_children=False): """Adds ``region`` and ``sorted_children`` fields to nodes Adds ``sorted_children`` field only if `sorted_children` is True. """ return patch_ast(ast.parse(source), source, sorted_children) def patch_ast(node, source, sorted_children=False): """Patches the given node After calling, each node in `node` will have a new field named `region` that is a tuple containing the start and end offsets of the code that generated it. If `sorted_children` is true, a `sorted_children` field will be created for each node, too. It is a list containing child nodes as well as whitespaces and comments that occur between them. """ if hasattr(node, "region"): return node walker = _PatchingASTWalker(source, children=sorted_children) ast.call_for_nodes(node, walker) return node def node_region(patched_ast_node): """Get the region of a patched ast node""" return patched_ast_node.region def write_ast(patched_ast_node): """Extract source form a patched AST node with `sorted_children` field If the node is patched with sorted_children turned off you can use `node_region` function for obtaining code using module source code. """ result = [] for child in patched_ast_node.sorted_children: if isinstance(child, ast.AST): result.append(write_ast(child)) else: result.append(child) return "".join(result) class MismatchedTokenError(exceptions.RopeError): pass class _PatchingASTWalker(object): def __init__(self, source, children=False): self.source = _Source(source) self.children = children self.lines = codeanalyze.SourceLinesAdapter(source) self.children_stack = [] Number = object() String = object() semicolon_or_as_in_except = object() exec_open_paren_or_space = object() exec_close_paren_or_space = object() exec_in_or_comma = object() with_or_comma_context_manager = object() empty_tuple = object() def __call__(self, node): method = getattr(self, "_" + node.__class__.__name__, None) if method is not None: return method(node) # ???: Unknown node; what should we do here? warnings.warn( "Unknown node type <%s>; please report!" % node.__class__.__name__, RuntimeWarning, ) node.region = (self.source.offset, self.source.offset) if self.children: node.sorted_children = ast.get_children(node) def _handle(self, node, base_children, eat_parens=False, eat_spaces=False): if hasattr(node, "region"): # ???: The same node was seen twice; what should we do? warnings.warn( "Node <%s> has been already patched; please report!" % node.__class__.__name__, RuntimeWarning, ) return base_children = collections.deque(base_children) self.children_stack.append(base_children) children = collections.deque() formats = [] suspected_start = self.source.offset start = suspected_start first_token = True while base_children: child = base_children.popleft() if child is None: continue offset = self.source.offset if isinstance(child, ast.AST): ast.call_for_nodes(child, self) token_start = child.region[0] else: if child is self.String: region = self.source.consume_string( end=self._find_next_statement_start() ) elif child is self.Number: region = self.source.consume_number() elif child == self.empty_tuple: region = self.source.consume_empty_tuple() elif child == "!=": # INFO: This has been added to handle deprecated ``<>`` region = self.source.consume_not_equal() elif child == self.semicolon_or_as_in_except: # INFO: This has been added to handle deprecated # semicolon in except region = self.source.consume_except_as_or_semicolon() elif child == self.exec_open_paren_or_space: # These three cases handle the differences between # the deprecated exec statement and the exec # function. region = self.source.consume_exec_open_paren_or_space() elif child == self.exec_in_or_comma: region = self.source.consume_exec_in_or_comma() elif child == self.exec_close_paren_or_space: region = self.source.consume_exec_close_paren_or_space() elif child == self.with_or_comma_context_manager: region = self.source.consume_with_or_comma_context_manager() else: if hasattr(ast, "JoinedStr") and isinstance( node, (ast.JoinedStr, ast.FormattedValue) ): region = self.source.consume_joined_string(child) else: region = self.source.consume(child) child = self.source[region[0] : region[1]] token_start = region[0] if not first_token: formats.append(self.source[offset:token_start]) if self.children: children.append(self.source[offset:token_start]) else: first_token = False start = token_start if self.children: children.append(child) start = self._handle_parens(children, start, formats) if eat_parens: start = self._eat_surrounding_parens(children, suspected_start, start) if eat_spaces: if self.children: children.appendleft(self.source[0:start]) end_spaces = self.source[self.source.offset :] self.source.consume(end_spaces) if self.children: children.append(end_spaces) start = 0 if self.children: node.sorted_children = children node.region = (start, self.source.offset) self.children_stack.pop() def _handle_parens(self, children, start, formats): """Changes `children` and returns new start""" opens, closes = self._count_needed_parens(formats) old_end = self.source.offset new_end = None for i in range(closes): new_end = self.source.consume(")")[1] if new_end is not None: if self.children: children.append(self.source[old_end:new_end]) new_start = start for i in range(opens): new_start = self.source.rfind_token("(", 0, new_start) if new_start != start: if self.children: children.appendleft(self.source[new_start:start]) start = new_start return start def _eat_surrounding_parens(self, children, suspected_start, start): index = self.source.rfind_token("(", suspected_start, start) if index is not None: old_start = start old_offset = self.source.offset start = index if self.children: children.appendleft(self.source[start + 1 : old_start]) children.appendleft("(") token_start, token_end = self.source.consume(")") if self.children: children.append(self.source[old_offset:token_start]) children.append(")") return start def _count_needed_parens(self, children): start = 0 opens = 0 for child in children: if not isinstance(child, basestring): continue if child == "" or child[0] in "'\"": continue index = 0 while index < len(child): if child[index] == ")": if opens > 0: opens -= 1 else: start += 1 if child[index] == "(": opens += 1 if child[index] == "#": try: index = child.index("\n", index) except ValueError: break index += 1 return start, opens def _find_next_statement_start(self): for children in reversed(self.children_stack): for child in children: if isinstance(child, ast.stmt): return child.col_offset + self.lines.get_line_start(child.lineno) return len(self.source.source) _operators = { "And": "and", "Or": "or", "Add": "+", "Sub": "-", "Mult": "*", "Div": "/", "Mod": "%", "Pow": "**", "LShift": "<<", "RShift": ">>", "BitOr": "|", "BitAnd": "&", "BitXor": "^", "FloorDiv": "//", "Invert": "~", "Not": "not", "UAdd": "+", "USub": "-", "Eq": "==", "NotEq": "!=", "Lt": "<", "LtE": "<=", "Gt": ">", "GtE": ">=", "Is": "is", "IsNot": "is not", "In": "in", "NotIn": "not in", } def _get_op(self, node): return self._operators[node.__class__.__name__].split(" ") def _Attribute(self, node): self._handle(node, [node.value, ".", node.attr]) def _Assert(self, node): children = ["assert", node.test] if node.msg: children.append(",") children.append(node.msg) self._handle(node, children) def _Assign(self, node): children = self._child_nodes(node.targets, "=") children.append("=") children.append(node.value) self._handle(node, children) def _AugAssign(self, node): children = [node.target] children.extend(self._get_op(node.op)) children.extend(["=", node.value]) self._handle(node, children) def _AnnAssign(self, node): children = [node.target, ":", node.annotation] if node.value is not None: children.append("=") children.append(node.value) self._handle(node, children) def _Repr(self, node): self._handle(node, ["`", node.value, "`"]) def _BinOp(self, node): children = [node.left] + self._get_op(node.op) + [node.right] self._handle(node, children) def _BoolOp(self, node): self._handle(node, self._child_nodes(node.values, self._get_op(node.op)[0])) def _Break(self, node): self._handle(node, ["break"]) def _Call(self, node): def _arg_sort_key(node): if isinstance(node, ast.keyword): return (node.value.lineno, node.value.col_offset) return (node.lineno, node.col_offset) children = [node.func, "("] unstarred_args = [] starred_and_keywords = list(node.keywords) for i, arg in enumerate(node.args): if hasattr(ast, "Starred") and isinstance(arg, ast.Starred): starred_and_keywords.append(arg) else: unstarred_args.append(arg) if getattr(node, "starargs", None): starred_and_keywords.append(node.starargs) starred_and_keywords.sort(key=_arg_sort_key) children.extend(self._child_nodes(unstarred_args, ",")) # positional args come before keywords, *args comes after all # positional args, and **kwargs comes last if starred_and_keywords: if len(children) > 2: children.append(",") for i, arg in enumerate(starred_and_keywords): if arg == getattr(node, "starargs", None): children.append("*") children.append(arg) if i + 1 < len(starred_and_keywords): children.append(",") if getattr(node, "kwargs", None): if len(children) > 2: children.append(",") children.extend(["**", node.kwargs]) children.append(")") self._handle(node, children) def _ClassDef(self, node): children = [] if getattr(node, "decorator_list", None): for decorator in node.decorator_list: children.append("@") children.append(decorator) children.extend(["class", node.name]) if node.bases: children.append("(") children.extend(self._child_nodes(node.bases, ",")) children.append(")") children.append(":") children.extend(node.body) self._handle(node, children) def _Compare(self, node): children = [] children.append(node.left) for op, expr in zip(node.ops, node.comparators): children.extend(self._get_op(op)) children.append(expr) self._handle(node, children) def _Delete(self, node): self._handle(node, ["del"] + self._child_nodes(node.targets, ",")) def _Constant(self, node): if isinstance(node.value, basestring): self._handle(node, [self.String]) return if any(node.value is v for v in [True, False, None]): self._handle(node, [str(node.value)]) return if isinstance(node.value, numbers.Number): self._handle(node, [self.Number]) return if node.value is Ellipsis: self._handle(node, ["..."]) return assert False def _Num(self, node): self._handle(node, [self.Number]) def _Str(self, node): self._handle(node, [self.String]) def _Bytes(self, node): self._handle(node, [self.String]) def _JoinedStr(self, node): def start_quote_char(): possible_quotes = [ (self.source.source.find(q, start, end), q) for q in QUOTE_CHARS ] quote_pos, quote_char = min( (pos, q) for pos, q in possible_quotes if pos != -1 ) return self.source[start : quote_pos + len(quote_char)] def end_quote_char(): possible_quotes = [ (self.source.source.rfind(q, start, end), q) for q in reversed(QUOTE_CHARS) ] _, quote_pos, quote_char = max( (len(q), pos, q) for pos, q in possible_quotes if pos != -1 ) return self.source[end - len(quote_char) : end] QUOTE_CHARS = ['"""', "'''", '"', "'"] offset = self.source.offset start, end = self.source.consume_string( end=self._find_next_statement_start(), ) self.source.offset = offset children = [] children.append(start_quote_char()) for part in node.values: if isinstance(part, ast.FormattedValue): children.append(part) children.append(end_quote_char()) self._handle(node, children) def _FormattedValue(self, node): children = [] children.append("{") children.append(node.value) if node.format_spec: children.append(":") for val in node.format_spec.values: if isinstance(val, ast.FormattedValue): children.append(val.value) else: children.append(val.s) children.append("}") self._handle(node, children) def _Continue(self, node): self._handle(node, ["continue"]) def _Dict(self, node): children = [] children.append("{") if node.keys: for index, (key, value) in enumerate(zip(node.keys, node.values)): if key is None: # PEP-448 dict unpacking: {a: b, **unpack} children.extend(["**", value]) else: children.extend([key, ":", value]) if index < len(node.keys) - 1: children.append(",") children.append("}") self._handle(node, children) def _Ellipsis(self, node): self._handle(node, ["..."]) def _Expr(self, node): self._handle(node, [node.value]) def _NamedExpr(self, node): children = [node.target, ":=", node.value] self._handle(node, children) def _Exec(self, node): children = ["exec", self.exec_open_paren_or_space, node.body] if node.globals: children.extend([self.exec_in_or_comma, node.globals]) if node.locals: children.extend([",", node.locals]) children.append(self.exec_close_paren_or_space) self._handle(node, children) def _ExtSlice(self, node): children = [] for index, dim in enumerate(node.dims): if index > 0: children.append(",") children.append(dim) self._handle(node, children) def _handle_for_loop_node(self, node, is_async): if is_async: children = ["async", "for"] else: children = ["for"] children.extend([node.target, "in", node.iter, ":"]) children.extend(node.body) if node.orelse: children.extend(["else", ":"]) children.extend(node.orelse) self._handle(node, children) def _For(self, node): self._handle_for_loop_node(node, is_async=False) def _AsyncFor(self, node): self._handle_for_loop_node(node, is_async=True) def _ImportFrom(self, node): children = ["from"] if node.level: children.append("." * node.level) # see comment at rope.base.ast.walk children.extend([node.module or "", "import"]) children.extend(self._child_nodes(node.names, ",")) self._handle(node, children) def _alias(self, node): children = [node.name] if node.asname: children.extend(["as", node.asname]) self._handle(node, children) def _handle_function_def_node(self, node, is_async): children = [] try: decorators = getattr(node, "decorator_list") except AttributeError: decorators = getattr(node, "decorators", None) if decorators: for decorator in decorators: children.append("@") children.append(decorator) if is_async: children.extend(["async", "def"]) else: children.extend(["def"]) children.extend([node.name, "(", node.args]) children.extend([")", ":"]) children.extend(node.body) self._handle(node, children) def _FunctionDef(self, node): self._handle_function_def_node(node, is_async=False) def _AsyncFunctionDef(self, node): self._handle_function_def_node(node, is_async=True) def _arguments(self, node): children = [] args = list(node.args) defaults = [None] * (len(args) - len(node.defaults)) + list(node.defaults) for index, (arg, default) in enumerate(zip(args, defaults)): if index > 0: children.append(",") self._add_args_to_children(children, arg, default) if node.vararg is not None: if args: children.append(",") children.extend(["*", pycompat.get_ast_arg_arg(node.vararg)]) if node.kwarg is not None: if args or node.vararg is not None: children.append(",") children.extend(["**", pycompat.get_ast_arg_arg(node.kwarg)]) self._handle(node, children) def _add_args_to_children(self, children, arg, default): if isinstance(arg, (list, tuple)): self._add_tuple_parameter(children, arg) else: children.append(arg) if default is not None: children.append("=") children.append(default) def _add_tuple_parameter(self, children, arg): children.append("(") for index, token in enumerate(arg): if index > 0: children.append(",") if isinstance(token, (list, tuple)): self._add_tuple_parameter(children, token) else: children.append(token) children.append(")") def _GeneratorExp(self, node): children = [node.elt] children.extend(node.generators) self._handle(node, children, eat_parens=True) def _comprehension(self, node): children = ["for", node.target, "in", node.iter] if node.ifs: for if_ in node.ifs: children.append("if") children.append(if_) self._handle(node, children) def _Global(self, node): children = self._child_nodes(node.names, ",") children.insert(0, "global") self._handle(node, children) def _If(self, node): if self._is_elif(node): children = ["elif"] else: children = ["if"] children.extend([node.test, ":"]) children.extend(node.body) if node.orelse: if len(node.orelse) == 1 and self._is_elif(node.orelse[0]): pass else: children.extend(["else", ":"]) children.extend(node.orelse) self._handle(node, children) def _is_elif(self, node): if not isinstance(node, ast.If): return False offset = self.lines.get_line_start(node.lineno) + node.col_offset word = self.source[offset : offset + 4] # XXX: This is a bug; the offset does not point to the first alt_word = self.source[offset - 5 : offset - 1] return "elif" in (word, alt_word) def _IfExp(self, node): return self._handle(node, [node.body, "if", node.test, "else", node.orelse]) def _Import(self, node): children = ["import"] children.extend(self._child_nodes(node.names, ",")) self._handle(node, children) def _keyword(self, node): children = [] if node.arg is None: children.append(node.value) else: children.extend([node.arg, "=", node.value]) self._handle(node, children) def _Lambda(self, node): self._handle(node, ["lambda", node.args, ":", node.body]) def _List(self, node): self._handle(node, ["["] + self._child_nodes(node.elts, ",") + ["]"]) def _ListComp(self, node): children = ["[", node.elt] children.extend(node.generators) children.append("]") self._handle(node, children) def _Set(self, node): if node.elts: self._handle(node, ["{"] + self._child_nodes(node.elts, ",") + ["}"]) return # Python doesn't have empty set literals warnings.warn( "Tried to handle empty literal; please report!", RuntimeWarning ) self._handle(node, ["set(", ")"]) def _SetComp(self, node): children = ["{", node.elt] children.extend(node.generators) children.append("}") self._handle(node, children) def _DictComp(self, node): children = ["{"] children.extend([node.key, ":", node.value]) children.extend(node.generators) children.append("}") self._handle(node, children) def _Module(self, node): self._handle(node, list(node.body), eat_spaces=True) def _Name(self, node): self._handle(node, [node.id]) def _NameConstant(self, node): self._handle(node, [str(node.value)]) def _arg(self, node): self._handle(node, [node.arg]) def _Pass(self, node): self._handle(node, ["pass"]) def _Print(self, node): children = ["print"] if node.dest: children.extend([">>", node.dest]) if node.values: children.append(",") children.extend(self._child_nodes(node.values, ",")) if not node.nl: children.append(",") self._handle(node, children) def _Raise(self, node): def get_python3_raise_children(node): children = ["raise"] if node.exc: children.append(node.exc) if node.cause: children.append(node.cause) return children def get_python2_raise_children(node): children = ["raise"] if node.type: children.append(node.type) if node.inst: children.append(",") children.append(node.inst) if node.tback: children.append(",") children.append(node.tback) return children if pycompat.PY2: children = get_python2_raise_children(node) else: children = get_python3_raise_children(node) self._handle(node, children) def _Return(self, node): children = ["return"] if node.value: children.append(node.value) self._handle(node, children) def _Sliceobj(self, node): children = [] for index, slice in enumerate(node.nodes): if index > 0: children.append(":") if slice: children.append(slice) self._handle(node, children) def _Index(self, node): self._handle(node, [node.value]) def _Subscript(self, node): self._handle(node, [node.value, "[", node.slice, "]"]) def _Slice(self, node): children = [] if node.lower: children.append(node.lower) children.append(":") if node.upper: children.append(node.upper) if node.step: children.append(":") children.append(node.step) self._handle(node, children) def _TryFinally(self, node): # @todo fixme is_there_except_handler = False not_empty_body = True if len(node.finalbody) == 1: if pycompat.PY2: is_there_except_handler = isinstance(node.body[0], ast.TryExcept) not_empty_body = not bool(len(node.body)) elif pycompat.PY3: try: is_there_except_handler = isinstance( node.handlers[0], ast.ExceptHandler ) not_empty_body = True except IndexError: pass children = [] if not_empty_body or not is_there_except_handler: children.extend(["try", ":"]) children.extend(node.body) if pycompat.PY3: children.extend(node.handlers) children.extend(["finally", ":"]) children.extend(node.finalbody) self._handle(node, children) def _TryExcept(self, node): children = ["try", ":"] children.extend(node.body) children.extend(node.handlers) if node.orelse: children.extend(["else", ":"]) children.extend(node.orelse) self._handle(node, children) def _Try(self, node): if len(node.finalbody): self._TryFinally(node) else: self._TryExcept(node) def _ExceptHandler(self, node): self._excepthandler(node) def _excepthandler(self, node): # self._handle(node, [self.semicolon_or_as_in_except]) children = ["except"] if node.type: children.append(node.type) if node.name: children.append(self.semicolon_or_as_in_except) children.append(node.name) children.append(":") children.extend(node.body) self._handle(node, children) def _Tuple(self, node): if node.elts: self._handle(node, self._child_nodes(node.elts, ","), eat_parens=True) else: self._handle(node, [self.empty_tuple]) def _UnaryOp(self, node): children = self._get_op(node.op) children.append(node.operand) self._handle(node, children) def _Await(self, node): children = ["await"] if node.value: children.append(node.value) self._handle(node, children) def _Yield(self, node): children = ["yield"] if node.value: children.append(node.value) self._handle(node, children) def _YieldFrom(self, node): children = ["yield", "from", node.value] self._handle(node, children) def _While(self, node): children = ["while", node.test, ":"] children.extend(node.body) if node.orelse: children.extend(["else", ":"]) children.extend(node.orelse) self._handle(node, children) def _handle_with_node(self, node, is_async): children = [] if is_async: children.extend(["async"]) for item in pycompat.get_ast_with_items(node): children.extend([self.with_or_comma_context_manager, item.context_expr]) if item.optional_vars: children.extend(["as", item.optional_vars]) if pycompat.PY2 and COMMA_IN_WITH_PATTERN.search(self.source.source): children.append(node.body[0]) else: children.append(":") children.extend(node.body) self._handle(node, children) def _With(self, node): self._handle_with_node(node, is_async=False) def _AsyncWith(self, node): self._handle_with_node(node, is_async=True) def _child_nodes(self, nodes, separator): children = [] for index, child in enumerate(nodes): children.append(child) if index < len(nodes) - 1: children.append(separator) return children def _Starred(self, node): self._handle(node, [node.value]) class _Source(object): def __init__(self, source): self.source = source self.offset = 0 def consume(self, token, skip_comment=True): try: while True: new_offset = self.source.index(token, self.offset) if self._good_token(token, new_offset) or not skip_comment: break else: self._skip_comment() except (ValueError, TypeError) as e: raise MismatchedTokenError( "Token <%s> at %s cannot be matched" % (token, self._get_location()) ) self.offset = new_offset + len(token) return (new_offset, self.offset) def consume_joined_string(self, token): new_offset = self.source.index(token, self.offset) self.offset = new_offset + len(token) return (new_offset, self.offset) def consume_string(self, end=None): if _Source._string_pattern is None: string_pattern = codeanalyze.get_string_pattern() formatted_string_pattern = codeanalyze.get_formatted_string_pattern() original = r"(?:%s)|(?:%s)" % (string_pattern, formatted_string_pattern) pattern = r"(%s)((\s|\\\n|#[^\n]*\n)*(%s))*" % (original, original) _Source._string_pattern = re.compile(pattern) repattern = _Source._string_pattern return self._consume_pattern(repattern, end) def consume_number(self): if _Source._number_pattern is None: _Source._number_pattern = re.compile(self._get_number_pattern()) repattern = _Source._number_pattern return self._consume_pattern(repattern) def consume_empty_tuple(self): return self._consume_pattern(re.compile(r"\(\s*\)")) def consume_not_equal(self): if _Source._not_equals_pattern is None: _Source._not_equals_pattern = re.compile(r"<>|!=") repattern = _Source._not_equals_pattern return self._consume_pattern(repattern) def consume_except_as_or_semicolon(self): repattern = re.compile(r"as|,") return self._consume_pattern(repattern) def consume_exec_open_paren_or_space(self): repattern = re.compile(r"\(|") return self._consume_pattern(repattern) def consume_exec_in_or_comma(self): repattern = re.compile(r"in|,") return self._consume_pattern(repattern) def consume_exec_close_paren_or_space(self): repattern = re.compile(r"\)|") return self._consume_pattern(repattern) def consume_with_or_comma_context_manager(self): repattern = re.compile(r"with|,") return self._consume_pattern(repattern) def _good_token(self, token, offset, start=None): """Checks whether consumed token is in comments""" if start is None: start = self.offset try: comment_index = self.source.rindex("#", start, offset) except ValueError: return True try: new_line_index = self.source.rindex("\n", start, offset) except ValueError: return False return comment_index < new_line_index def _skip_comment(self): self.offset = self.source.index("\n", self.offset + 1) def _get_location(self): lines = self.source[: self.offset].split("\n") return (len(lines), len(lines[-1])) def _consume_pattern(self, repattern, end=None): while True: if end is None: end = len(self.source) match = repattern.search(self.source, self.offset, end) if self._good_token(match.group(), match.start()): break else: self._skip_comment() self.offset = match.end() return match.start(), match.end() def till_token(self, token): new_offset = self.source.index(token, self.offset) return self[self.offset : new_offset] def rfind_token(self, token, start, end): index = start while True: try: index = self.source.rindex(token, start, end) if self._good_token(token, index, start=start): return index else: end = index except ValueError: return None def from_offset(self, offset): return self[offset : self.offset] def find_backwards(self, pattern, offset): return self.source.rindex(pattern, 0, offset) def __getitem__(self, index): return self.source[index] def __getslice__(self, i, j): return self.source[i:j] def _get_number_pattern(self): # HACK: It is merely an approaximation and does the job integer = r"\-?(0x[\da-fA-F]+|\d+)[lL]?" return r"(%s(\.\d*)?|(\.\d+))([eE][-+]?\d+)?[jJ]?" % integer _string_pattern = None _number_pattern = None _not_equals_pattern = None