<html><head><meta name="color-scheme" content="light dark"></head><body><pre style="word-wrap: break-word; white-space: pre-wrap;">import ast
import dataclasses
import inspect
import re
import string
import sys
from collections import namedtuple
from textwrap import dedent
from typing import List, Tuple  # noqa: F401

import torch
import torch.jit.annotations
from torch import _jit_internal
from torch._C._jit_tree_views import (
    Apply,
    Assert,
    Assign,
    Attribute,
    AugAssign,
    BinOp,
    Break,
    ClassDef,
    Const,
    Continue,
    Decl,
    Def,
    Delete,
    DictComp,
    DictLiteral,
    Dots,
    EmptyTypeAnnotation,
    ExprStmt,
    FalseLiteral,
    For,
    Ident,
    If,
    ListComp,
    ListLiteral,
    NoneLiteral,
    Param,
    Pass,
    Property,
    Raise,
    Return,
    Select,
    SliceExpr,
    Starred,
    Stmt,
    StringLiteral,
    Subscript,
    TernaryIf,
    TrueLiteral,
    TupleLiteral,
    UnaryOp,
    Var,
    While,
    With,
    WithItem,
)
from torch._jit_internal import (  # noqa: F401
    _is_drop_fn,
    FunctionModifiers,
    is_static_fn,
    should_drop,
)
from torch._sources import (
    get_source_lines_and_file,
    make_source_context,
    parse_def,
    ParsedDef as _ParsedDef,
)
from torch.jit._dataclass_impls import DATACLASS_MAGIC_METHODS
from torch.jit._monkeytype_config import get_qualified_name, monkeytype_trace

_IS_ASTUNPARSE_INSTALLED = False
try:
    import astunparse  # type: ignore[import]

    _IS_ASTUNPARSE_INSTALLED = True
except ImportError:
    pass

# Borrowed from cPython implementation
# https://github.com/python/cpython/blob/561612d8456cfab5672c9b445521113b847bd6b3/Lib/textwrap.py#L411#

# Names the frontend treats as reserved: anything starting with this prefix,
# plus the explicit names in `_reserved_names`.
_reserved_prefix = "__jit"
_reserved_names = {"print"}
# ASCII letters and digits.
_identifier_chars = set(string.ascii_letters + string.digits)


def is_reserved_name(name):
    """Return whether ``name`` is reserved.

    A name is reserved when it starts with ``_reserved_prefix`` or is a member
    of ``_reserved_names``.
    """
    if name.startswith(_reserved_prefix):
        return True
    return name in _reserved_names


# Human-readable (plural) feature names used in "... aren't supported" messages.
pretty_node_names = {
    ast.FunctionDef: "function definitions",
    ast.For: "for loops",
    ast.Delete: "del statements",
    ast.ClassDef: "class definitions",
    ast.With: "with statements",
    ast.Raise: "raise statements",
    ast.Assert: "assertions",
    ast.Import: "import statements",
    ast.ImportFrom: "import statements",
    ast.Global: "global variables",
    ast.Break: "break statements",
    ast.Continue: "continue statements",
    ast.AsyncFunctionDef: "async function definitions",
    ast.AsyncFor: "async for loops",
    ast.AsyncWith: "async with statements",
    ast.Try: "try blocks",
    ast.Nonlocal: "nonlocal variables",
    ast.AnnAssign: "annotated assignments",
}

# Leading source token of each statement kind; its length is used to size the
# highlighted source range when reporting an unsupported node.
# NB: no specific token for AnnAssign
node_start_tokens = {
    ast.FunctionDef: "def",
    ast.For: "for",
    ast.Delete: "del",
    ast.ClassDef: "class",
    ast.With: "with",
    ast.Raise: "raise",
    ast.Assert: "assert",
    ast.Import: "import",
    ast.ImportFrom: "from",
    ast.Global: "global",
    ast.Break: "break",
    ast.Continue: "continue",
    ast.AsyncFunctionDef: "async def",
    ast.AsyncFor: "async for",
    ast.AsyncWith: "async with",
    ast.Try: "try",
    ast.Nonlocal: "nonlocal",
}


class FrontendError(Exception):
    """Exception carrying a source range, a message and an error report.

    Attributes are set directly on the instance: ``source_range``, ``msg`` and
    ``error_report`` (a ``torch._C.ErrorReport`` built from ``source_range``).
    """

    def __init__(self, source_range, msg):
        self.source_range = source_range
        self.msg = msg

        # Build the report eagerly, at construction time, so that it reflects
        # the call stack at the point where the FrontendError was raised.
        self.error_report = torch._C.ErrorReport(self.source_range)

    def __str__(self):
        # The message is followed by the report text with its leading
        # whitespace removed.
        report_text = self.error_report.what()
        return self.msg + report_text.lstrip()


class NotSupportedError(FrontendError):
    """FrontendError raised by the builders in this file for unsupported constructs."""


class UnsupportedNodeError(NotSupportedError):
    """NotSupportedError for a Python AST node type that cannot be converted.

    Args:
        ctx: Source context; its ``make_range`` builds the reported range.
        offending_node: The AST node; must expose ``lineno`` and ``col_offset``.
        reason: Optional qualifier inserted before "aren't supported".
    """

    def __init__(self, ctx, offending_node, reason=""):
        kind = type(offending_node)
        # Highlight the node's start token; node types without a known token
        # fall back to a one-character range.
        start_token = node_start_tokens.get(kind, " ")
        line = offending_node.lineno
        first_col = offending_node.col_offset
        source_range = ctx.make_range(line, first_col, first_col + len(start_token))

        feature_name = pretty_node_names.get(kind, kind.__name__)
        qualifier = reason + " " if reason else ""
        msg = f"{feature_name} {qualifier}aren't supported"
        super().__init__(source_range, msg)


class FrontendTypeError(FrontendError):
    """FrontendError subclass; adds no behavior of its own."""


def build_withitems(ctx, items):
    """Convert every entry of ``items`` with ``build_withitem`` and return a new list."""
    return [build_withitem(ctx, entry) for entry in items]


def build_stmts(ctx, stmts):
    """Convert each statement with ``build_stmt`` and drop falsy results.

    Falsy results include the ``None`` that a statement builder returns for a
    statement it chooses to omit.
    """
    converted = [build_stmt(ctx, node) for node in stmts]
    return [tree for tree in converted if tree]


def get_class_properties(cls, self_name):
    """
    Build Property TreeView objects for the properties defined on a class.

    Args:
        cls:  The class whose ``property`` members are inspected.
        self_name: The name of the class that the properties should belong to;
            forwarded to ``get_jit_def`` for the getter and setter.
    Returns:
        A list of Property objects (the TreeView subclass), one per property
        of ``cls`` that is neither listed in ``__jit_unused_properties__`` nor
        marked to be dropped.
    """
    members = inspect.getmembers(cls, predicate=lambda m: isinstance(m, property))
    # Any property that should not be compiled must be in this list on the Module.
    skipped_names = getattr(cls, "__jit_unused_properties__", [])

    properties = []
    for prop_name, prop in members:
        if prop_name in skipped_names or should_drop(prop.fget):
            continue

        getter = get_jit_def(prop.fget, f"__{prop_name}_getter", self_name=self_name)
        setter = None
        if prop.fset:
            setter = get_jit_def(
                prop.fset, f"__{prop_name}_setter", self_name=self_name
            )

        # The property and its identifier both reuse the getter's range.
        properties.append(
            Property(getter.range(), Ident(getter.range(), prop_name), getter, setter)
        )

    return properties


def get_class_assigns(ctx, cls_ast):
    """Collect TreeView assignments for the class-body assignments in ``cls_ast``.

    Only direct ``ast.Assign`` / ``ast.AnnAssign`` entries of the class body
    are considered. Entries whose conversion raises ``NotSupportedError`` are
    skipped rather than reported.
    """
    assigns = []
    for entry in cls_ast.body:
        if isinstance(entry, ast.Assign):
            builder = StmtBuilder.build_Assign
        elif isinstance(entry, ast.AnnAssign):
            builder = StmtBuilder.build_AnnAssign
        else:
            continue

        try:
            assigns.append(builder(ctx, entry))
        except NotSupportedError:
            # Best effort: unsupported class-level assignments are left out.
            continue
    return assigns


def get_jit_class_def(cls, self_name):
    """
    Build a ClassDef TreeView for ``cls``.

    Args:
        cls: The Python class to convert. Its source must be retrievable via
            ``get_source_lines_and_file``.
        self_name: Name used for the resulting class definition and forwarded
            to ``get_jit_def`` / ``get_class_properties``.
    Returns:
        The result of ``build_class_def`` for the class's methods, properties
        and class-level assignments.
    """
    # Get defs for each method within the current class independently
    # TODO: proper overriding analysis when implementing class inheritance
    # Only members defined directly on `cls` (present in cls.__dict__) that are
    # not static and not drop functions are selected.
    methods = inspect.getmembers(
        cls,
        predicate=lambda m: (inspect.ismethod(m) or inspect.isfunction(m))
        and not is_static_fn(cls, m.__name__)
        and m.__name__ in cls.__dict__
        and not _is_drop_fn(m),
    )

    def is_classmethod(fn):
        # A bound method whose __self__ is the class itself.
        return inspect.ismethod(fn) and getattr(fn, "__self__", None) == cls

    # Get and parse the source code for this class
    sourcelines, file_lineno, filename = get_source_lines_and_file(
        cls, torch._C.ErrorReport.call_stack()
    )
    source = "".join(sourcelines)

    dedent_src = dedent(source)
    py_ast = ast.parse(dedent_src)

    class_ast = py_ast.body[0]
    assert isinstance(class_ast, ast.ClassDef)

    # Special case for dataclasses. In general we need access to the source code for
    # an object in order to JIT compile it. But the dataclasses module dynamically synthesizes
    # magic methods for classes, and we can't get the source code for these methods. As a
    # workaround, we synthesize TorchScript-friendly implementations ourselves.
    if dataclasses.is_dataclass(cls):
        # Detect whether the user manually implemented any of the magic methods. If they did,
        # we don't want to synthesize/override them.
        overrides = {
            method.name
            for method in class_ast.body
            if isinstance(method, ast.FunctionDef)
            and method.name in DATACLASS_MAGIC_METHODS
        }
        for i, (name, _) in enumerate(methods):
            # Is this a magic method we can synthesize?
            synthesizer_fn = DATACLASS_MAGIC_METHODS.get(name)
            if synthesizer_fn and name not in overrides:
                parsed_def = synthesizer_fn(cls)
                # Replace the function object in `methods` with the synthesized
                # definition; get_jit_def below accepts it in place of a function.
                methods[i] = name, parsed_def
                func = getattr(cls, name)
                _jit_internal.loader.cache(func, parsed_def.source)

    method_defs = [
        get_jit_def(obj, name, self_name=self_name, is_classmethod=is_classmethod(obj))
        for (name, obj) in methods
    ]
    properties = get_class_properties(cls, self_name)

    # Difference in length between the first source line before and after
    # dedent, i.e. the indentation dedent removed from that line.
    leading_whitespace_len = len(source.split("\n", 1)[0]) - len(
        dedent_src.split("\n", 1)[0]
    )
    # NOTE(review): the meaning of the trailing `False` argument is not visible
    # in this file; see make_source_context in torch._sources.
    ctx = make_source_context(
        source, filename, file_lineno, leading_whitespace_len, False
    )
    assigns = get_class_assigns(ctx, class_ast)

    return build_class_def(ctx, class_ast, method_defs, properties, self_name, assigns)


def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
    """
    Build a JIT AST (TreeView) from the given function.

    Args:
        fn: A function object to compile or a pre-parsed ParsedDef object
        def_name: The name to give to the resulting AST object. This is not
            always the same as `fn.__name__`, for example:
                def _forward(self):
                    ...
                forward = _forward
            In this case, the `__name__` attribute of the function object is "_forward",
            but we want the result AST to have the name "forward".
        self_name: If this function is a method, what the type name of `self` is.
        is_classmethod: If True, a statement assigning `self_name` to the
            function's first parameter name is inserted at the top of the body.

    Returns:
        The result of `build_def` for the (possibly modified) function AST.
    """
    parsed_def = parse_def(fn) if not isinstance(fn, _ParsedDef) else fn
    type_line = torch.jit.annotations.get_type_line(parsed_def.source)
    fn_def = parsed_def.ast.body[0]

    # NOTE: the edits below modify `fn_def` (part of `parsed_def.ast`) in place.
    if is_classmethod:
        arg_name = fn_def.args.args[0].arg
        # Insert a statement that assigns the first argument to the class
        assign_stmt = ast.parse(f"{arg_name} = {self_name}").body[0]
        fn_def.body.insert(0, assign_stmt)

    # Swap out the function signature and body if it is unused
    if should_drop(fn):
        unused_fn_def = ast.parse(
            'def unused_fn(self: Any):\n\traise RuntimeError("Cannot call @unused methods")'
        )
        # Sanity check on the template parsed just above.
        if len(unused_fn_def.body) != 1 or not isinstance(
            unused_fn_def.body[0], ast.FunctionDef
        ):
            raise RuntimeError(
                f"Expected a single top-level function: {parsed_def.filename}:{parsed_def.file_lineno}"
            )
        unused_def = unused_fn_def.body[0]
        # The original body is replaced by the template's single `raise`.
        fn_def.body = unused_def.body
        # kwarg/vararg not supported by `build_def`
        fn_def.args.kwarg = fn_def.args.vararg = None
        for arg in fn_def.args.args + fn_def.args.kwonlyargs:
            # Replace potentially unsupported type annotations by "Any"
            arg.annotation = unused_def.args.args[0].annotation
        if _is_drop_fn(fn):
            # Dropping potentially unsupported return type annotation for jit._drop
            fn_def.returns = None
            fn_def.type_comment = None

    # If MonkeyType is installed, get all the consolidated type traces
    # for the arguments from type_trace_db
    type_trace_db = torch.jit._script._get_type_trace_db()
    pdt_arg_types = None
    # Traces are looked up by qualified name, so only for real function objects.
    if monkeytype_trace and not isinstance(fn, _ParsedDef):
        qualname = get_qualified_name(fn)
        pdt_arg_types = type_trace_db.get_args_types(qualname)

    return build_def(
        parsed_def.ctx,
        fn_def,
        type_line,
        def_name,
        self_name=self_name,
        pdt_arg_types=pdt_arg_types,
    )


# TODO: more robust handling of recognizing ignore context manager
def is_torch_jit_ignore_context_manager(stmt):
    """Return True if the first item of a ``with`` statement is a call spelled
    exactly ``torch.jit._IgnoreContextManager(...)``.

    The check is purely syntactic: aliases of ``torch`` or of the context
    manager are not recognized, and only the first ``with`` item is examined.
    """
    call = stmt.items[0].context_expr
    if not isinstance(call, ast.Call):
        return False

    callee = call.func
    if not isinstance(callee, ast.Attribute):
        return False
    if callee.attr != "_IgnoreContextManager":
        return False

    # The callee's base must itself be the attribute access `torch.jit`,
    # i.e. exactly two nested attributes on a plain name.
    owner = callee.value
    if not isinstance(owner, ast.Attribute) or owner.attr != "jit":
        return False

    root = owner.value
    return isinstance(root, ast.Name) and root.id == "torch"


class Builder:
    """Callable dispatcher that routes a node to ``build_<NodeClassName>``.

    Subclasses provide the ``build_*`` handlers; a node whose class has no
    matching handler raises ``UnsupportedNodeError``.
    """

    def __call__(self, ctx, node):
        handler_name = "build_" + node.__class__.__name__
        handler = getattr(self, handler_name, None)
        if handler is not None:
            return handler(ctx, node)
        raise UnsupportedNodeError(ctx, node)


def build_class_def(ctx, py_def, methods, properties, self_name, assigns):
    """Assemble a ClassDef TreeView named ``self_name``.

    The identifier's range covers the ``class`` keyword of ``py_def``; each
    entry of ``methods`` is wrapped in ``Stmt``.
    """
    line = py_def.lineno
    first_col = py_def.col_offset
    keyword_range = ctx.make_range(line, first_col, first_col + len("class"))

    class_ident = Ident(keyword_range, self_name)
    method_stmts = [Stmt(method) for method in methods]
    return ClassDef(class_ident, method_stmts, properties, assigns)


def build_def(ctx, py_def, type_line, def_name, self_name=None, pdt_arg_types=None):
    """Build a Def TreeView named ``def_name`` from a Python function AST node.

    Args:
        ctx: Source context used to create ranges.
        py_def: The function definition AST node.
        type_line: Type comment line for the function, or None.
        def_name: Name given to the resulting Def.
        self_name: Type name of ``self`` when the function is a method.
        pdt_arg_types: Optional per-argument types forwarded to
            ``build_param_list``.
    """
    statements = py_def.body
    line = py_def.lineno
    first_col = py_def.col_offset
    def_range = ctx.make_range(line, first_col, first_col + len("def"))

    params = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)

    returns_node = getattr(py_def, "returns", None)
    return_type = build_expr(ctx, returns_node) if returns_node is not None else None

    decl = Decl(def_range, params, return_type)
    if type_line is not None:
        # Merge the types from the type comment into the declaration.
        is_method = self_name is not None
        comment_decl = torch._C.parse_type_comment(type_line)
        decl = torch._C.merge_type_from_type_comment(decl, comment_decl, is_method)

    return Def(Ident(def_range, def_name), decl, build_stmts(ctx, statements))


_vararg_kwarg_err = (
    "Compiled functions can't take variable number of arguments "
    "or use keyword-only arguments with defaults"
)


def build_param_list(ctx, py_args, self_name, pdt_arg_types=None):
    """Build the list of Param TreeViews for a function's ``ast.arguments``.

    Args:
        ctx: Source context used to create ranges.
        py_args: The ``ast.arguments`` node of the function.
        self_name: Type name of ``self`` when the function is a method.
        pdt_arg_types: Optional mapping from argument name to a type inferred
            by profile directed typing; falsy entries are treated as absent.

    Returns:
        Params for the positional arguments followed by the keyword-only ones.

    Raises:
        NotSupportedError: If the function takes ``**kwargs`` or ``*args``, or
            gives a default to a keyword-only argument.
    """
    # `**kwargs` is checked before `*args`, so it is the one reported when a
    # function declares both.
    for variadic in (py_args.kwarg, py_args.vararg):
        if variadic is not None:
            ctx_range = ctx.make_range(
                variadic.lineno,
                variadic.col_offset - 1,
                variadic.col_offset + len(variadic.arg),
            )
            raise NotSupportedError(ctx_range, _vararg_kwarg_err)

    # kw_defaults holds one entry per keyword-only argument, with None for
    # arguments that have no default; the first real default is rejected.
    for default in py_args.kw_defaults:
        if default is not None:
            ctx_range = build_expr(ctx, default).range()
            raise NotSupportedError(ctx_range, _vararg_kwarg_err)

    def pdt_type_for(arg):
        # Type inferred by profile directed typing, or None when unavailable.
        if not pdt_arg_types:
            return None
        traced = pdt_arg_types[arg.arg]
        return traced if traced else None

    result = [
        build_param(
            ctx, arg, self_name, kwarg_only=False, pdt_arg_type=pdt_type_for(arg)
        )
        for arg in py_args.args
    ]
    result += [
        build_param(
            ctx, arg, self_name, kwarg_only=True, pdt_arg_type=pdt_type_for(arg)
        )
        for arg in py_args.kwonlyargs
    ]
    return result


def build_param(ctx, py_arg, self_name, kwarg_only, pdt_arg_type=None):
    """Build a Param TreeView for one ``ast.arg``.

    The type is chosen in this order: the explicit annotation, then
    ``pdt_arg_type``, then ``self_name`` for a parameter literally named
    ``self`` (when ``self_name`` is given), otherwise an empty annotation.
    """
    # NB: In Python3 py_arg is a pair of (str arg, expr? annotation)
    name = py_arg.arg
    line = py_arg.lineno
    first_col = py_arg.col_offset
    name_range = ctx.make_range(line, first_col, first_col + len(name))

    explicit = getattr(py_arg, "annotation", None)
    if explicit is not None:
        type_expr = build_expr(ctx, explicit)
    elif pdt_arg_type:
        type_expr = Var(Ident(name_range, pdt_arg_type))
    elif self_name is not None and name == "self":
        type_expr = Var(Ident(name_range, self_name))
    else:
        type_expr = EmptyTypeAnnotation(name_range)

    return Param(type_expr, Ident(name_range, name), kwarg_only)


def build_ignore_context_manager(ctx, stmt):
    """Replace a ``torch.jit._IgnoreContextManager`` block with a call.

    The body of the ``with`` statement is moved into a generated function
    decorated with ``@torch.jit.ignore``. The function is defined with
    ``exec`` and stored in this module's globals. The return value is the AST
    of a statement that calls it.

    Args:
        ctx: Source context; ``ctx.filename`` is used to derive a unique
            function name.
        stmt: The ``ast.With`` node. Keyword arguments of its first context
            expression must be string constants of the form ``"inp:<type>"``
            or ``"out:<type>"``.

    Returns:
        The AST of ``<outs> = torch.jit.frontend.<func>(<ins>)``, or of the
        bare call when there are no outputs.
    """
    InputType = namedtuple("InputType", ["name", "ann"])
    OutputType = namedtuple("OutputType", ["name", "ann"])

    def process_ins_outs(args):
        # parse the context manager to figure out inputs and outputs
        # with their annotated types
        # TODO: add input, output validator
        inputs = []
        outputs = []
        for arg in args:
            var_name = arg.arg
            # Split on the first ":" only, so an annotation that itself
            # contains ":" stays intact.
            var_decl_type, var_ann = arg.value.value.split(":", 1)
            if var_decl_type == "inp":
                inputs.append(InputType(var_name, var_ann))
            elif var_decl_type == "out":
                outputs.append(OutputType(var_name, var_ann))
        return inputs, outputs

    def create_unique_name_ext(ctx, stmt):
        # extension will be based on the full path filename plus
        # the line number of original context manager
        fn = re.sub(r"[^a-zA-Z0-9_]", "_", ctx.filename)
        return f"{fn}_{stmt.lineno}"

    def build_return_ann_stmt(outputs):
        # Returns (return annotation text, return statement text).
        if not outputs:
            return " -> None", "return "
        names = ", ".join([var.name for var in outputs])
        if len(outputs) == 1:
            return " -> " + outputs[0].ann, "return " + names
        anns = ", ".join([var.ann for var in outputs])
        return " -> Tuple[" + anns + "]", "return " + names

    def build_args(args):
        return ", ".join([arg.name for arg in args])

    inputs, outputs = process_ins_outs(stmt.items[0].context_expr.keywords)

    # build the replacement function str with given inputs and outputs
    ignore_function_name = "func_ignore_" + create_unique_name_ext(ctx, stmt)
    ignore_function_str = "\ndef " + ignore_function_name
    ignore_function_str += (
        "(" + ", ".join([var.name + " :" + var.ann for var in inputs]) + ")"
    )

    return_ann, return_stmt = build_return_ann_stmt(outputs)
    ignore_function_str += return_ann + ": pass"

    # first create the functionDef object from just declaration
    ignore_function = ast.parse(ignore_function_str).body[0]

    # Use the body of the context manager followed by the return statement as
    # the function body. A new list is built so the caller's `stmt.body` is
    # not modified.
    return_node = ast.parse(return_stmt).body[0]
    ignore_function.body = [*stmt.body, return_node]  # type: ignore[attr-defined]

    # registers the custom function in the global context
    # NOTE: this executes source generated from the code being compiled.
    ignore_func_str = "@torch.jit.ignore\n" + astunparse.unparse(ignore_function)
    ignore_func_str += f'\nglobals()["{ignore_function_name}"] = {ignore_function_name}'
    exec(ignore_func_str)  # noqa: P204

    # build the statements as:
    # <out_1>, <out_2>, ... = torch.jit.frontend.<func>(<in_1>, <in_2>)
    # this function will be registered in torch.jit.frontend module by default
    assign_str = (
        f"torch.jit.frontend.{ignore_function_name}(" + build_args(inputs) + ")"
    )
    if outputs:
        assign_str = build_args(outputs) + " = " + assign_str

    return ast.parse(assign_str).body[0]


def get_default_args(fn):
    """Return ``{parameter name: default value}`` for the parameters of ``fn``
    that declare a default. ``None`` yields an empty dict.
    """
    if fn is None:
        return {}

    defaults = {}
    for name, param in inspect.signature(fn).parameters.items():
        if param.default is inspect.Parameter.empty:
            continue
        defaults[name] = param.default
    return defaults


def get_default_args_for_class(cls):
    """
    Collect the default arguments of every non-static method defined on a class.

    Args:
        cls: type - The class type to inspect for default arguments.
    Returns:
        A Dict[str, Dict[str, Any]] which maps each method name to a Dict[str, Any]
        that maps each argument name to its default value.
    """

    def is_own_non_static_method(member):
        # Static methods are excluded because those are compiled separately as
        # if they were independent script functions.
        if not (inspect.ismethod(member) or inspect.isfunction(member)):
            return False
        if is_static_fn(cls, member.__name__):
            return False
        return member.__name__ in cls.__dict__

    # Property defaults do not need to be considered because setters cannot
    # be invoked without a value.
    return {
        method_name: get_default_args(method_impl)
        for method_name, method_impl in inspect.getmembers(
            cls, predicate=is_own_non_static_method
        )
    }


class WithItemBuilder(Builder):
    """Builder for the individual items of a ``with`` statement."""

    @staticmethod
    def build_withitem(ctx, item):
        """Build a WithItem TreeView from an ``ast.withitem``.

        The optional ``as`` target is converted only when present; otherwise
        ``None`` is passed.
        """
        context_expr = item.context_expr
        line = context_expr.lineno
        start = context_expr.col_offset
        # NOTE(review): the range width is the length of the pretty name
        # ("with statements"), not of the `with` token or of the context
        # expression. Confirm whether this is intended.
        end = start + len(pretty_node_names[ast.With])
        target = item.optional_vars
        item_range = ctx.make_range(line, start, end)

        target_expr = build_expr(ctx, target) if target else None
        return WithItem(item_range, build_expr(ctx, context_expr), target_expr)


class StmtBuilder(Builder):
    """Converts Python ``ast`` statement nodes into TreeView statements.

    Each ``build_<NodeClassName>`` static method handles one statement kind
    and is selected by ``Builder.__call__``. All handlers take the source
    context ``ctx`` (used to create source ranges) and the AST node.
    """

    # Operator token for each supported augmented-assignment operator.
    augassign_map = {
        ast.Add: "+",
        ast.Sub: "-",
        ast.Mult: "*",
        ast.Div: "/",
        ast.Mod: "%",
        ast.BitOr: "|",
        ast.BitAnd: "&",
        ast.BitXor: "^",
        ast.LShift: "<<",
        ast.RShift: ">>",
        ast.Pow: "**",
    }

    @staticmethod
    def build_Expr(ctx, stmt):
        """Build an ExprStmt, or return None for a docstring-style string node."""
        value = stmt.value
        # NOTE(review): this matches only nodes whose class is named "Str".
        # On Python versions where string literals parse as ast.Constant, this
        # branch does not match and the literal is built as a normal
        # expression statement. Confirm whether that is intended.
        if value.__class__.__name__ == "Str":
            # If a statement is a string literal expression,
            # then it is a docstring. Just ignore it.
            return None
        return ExprStmt(build_expr(ctx, value))

    @staticmethod
    def build_Assign(ctx, stmt):
        """Build an Assign from a plain (possibly chained) assignment."""
        rhs = build_expr(ctx, stmt.value)
        lhs = [build_expr(ctx, x) for x in stmt.targets]
        return Assign(lhs, rhs)

    @staticmethod
    def build_AnnAssign(ctx, stmt):
        """Build an Assign carrying a type from an annotated assignment.

        Raises:
            UnsupportedNodeError: If the annotation has no assigned value.
            ValueError: If ``self.<attr>`` is annotated outside ``__init__``.
        """
        if stmt.value is None:
            raise UnsupportedNodeError(ctx, stmt, reason="without assigned value")

        # Disallow type annotations on instance attributes outside of __init__
        target = stmt.target
        if (
            isinstance(target, ast.Attribute)
            # The base may be any expression (e.g. `self.a.b`); only a plain
            # name has an `id` to compare.
            and isinstance(target.value, ast.Name)
            and target.value.id == "self"
            and ctx.funcname != "__init__"
        ):
            start = stmt.col_offset
            end = start + len(f"self.{target.attr}")
            if hasattr(stmt.annotation, "id"):
                end += len(f": {stmt.annotation.id}")
            sr = ctx.make_range(stmt.lineno, start, end)
            raise ValueError(
                "Type annotations on instance attributes must be declared in "
                f"__init__, not '{ctx.funcname}': {sr}"
            )

        rhs = build_expr(ctx, stmt.value)
        lhs = build_expr(ctx, stmt.target)
        the_type = build_expr(ctx, stmt.annotation)
        return Assign([lhs], rhs, the_type)

    @staticmethod
    def build_Delete(ctx, stmt):
        """Build a Delete whose range covers the ``del`` keyword."""
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("del"))

        return Delete(r, [build_expr(ctx, target) for target in stmt.targets])

    @staticmethod
    def build_Return(ctx, stmt):
        """Build a Return; a bare ``return`` carries ``None`` as its value."""
        r = ctx.make_range(
            stmt.lineno, stmt.col_offset, stmt.col_offset + len("return")
        )
        return Return(r, None if stmt.value is None else build_expr(ctx, stmt.value))

    @staticmethod
    def build_Raise(ctx, stmt):
        """Build a Raise from ``raise <expr>``.

        Raises:
            NotSupportedError: For a bare ``raise``, which has no exception
                expression to convert.
        """
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("raise"))
        if stmt.exc is None:
            # Report the problem at the `raise` keyword instead of handing
            # None to the expression builder.
            raise NotSupportedError(r, "bare raise statements aren't supported")
        expr = build_expr(ctx, stmt.exc)
        return Raise(r, expr)

    @staticmethod
    def build_Assert(ctx, stmt):
        """Build an Assert; the message is ``None`` when absent."""
        r = ctx.make_range(
            stmt.lineno, stmt.col_offset, stmt.col_offset + len("assert")
        )
        test = build_expr(ctx, stmt.test)
        msg = build_expr(ctx, stmt.msg) if stmt.msg is not None else None
        return Assert(r, test, msg)

    @staticmethod
    def build_AugAssign(ctx, stmt):
        """Build an AugAssign for an operator listed in ``augassign_map``.

        Raises:
            NotSupportedError: If the operator is not in ``augassign_map``.
        """
        lhs = build_expr(ctx, stmt.target)
        rhs = build_expr(ctx, stmt.value)
        op = type(stmt.op)
        if op not in StmtBuilder.augassign_map:
            raise NotSupportedError(
                find_before(ctx, rhs.range().start, "=", offsets=(-1, 0)),
                "unsupported kind of augmented assignment: " + op.__name__,
            )
        return AugAssign(lhs, StmtBuilder.augassign_map[op], rhs)

    @staticmethod
    def build_While(ctx, stmt):
        """Build a While.

        Raises:
            NotSupportedError: If the loop has an ``else`` branch.
        """
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("while"))
        if stmt.orelse:
            # TODO: try to recover the location of else:? Python doesn't give us useful
            # annotations in this case. Until then, point at the `while`
            # keyword, as build_For does for `for`.
            raise NotSupportedError(r, "else branches of while loops aren't supported")
        return While(r, build_expr(ctx, stmt.test), build_stmts(ctx, stmt.body))

    @staticmethod
    def build_For(ctx, stmt):
        """Build a For.

        Raises:
            NotSupportedError: If the loop has an ``else`` branch.
        """
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("for"))
        if stmt.orelse:
            raise NotSupportedError(r, "else branches of for loops aren't supported")

        return For(
            r,
            [build_expr(ctx, stmt.target)],
            [build_expr(ctx, stmt.iter)],
            build_stmts(ctx, stmt.body),
        )

    @staticmethod
    def build_If(ctx, stmt):
        """Build an If from the test, the body and the ``else`` statements."""
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("if"))
        return If(
            r,
            build_expr(ctx, stmt.test),
            build_stmts(ctx, stmt.body),
            build_stmts(ctx, stmt.orelse),
        )

    @staticmethod
    def build_Print(ctx, stmt):
        """Build a call to ``print`` from a ``Print`` statement node.

        Raises:
            NotSupportedError: If the statement specifies a destination.
        """
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("print"))
        if stmt.dest:
            raise NotSupportedError(
                r, "print statements with non-default destinations aren't supported"
            )
        args = [build_expr(ctx, val) for val in stmt.values]
        return ExprStmt(Apply(Var(Ident(r, "print")), args, []))

    @staticmethod
    def build_Pass(ctx, stmt):
        """Build a Pass."""
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("pass"))
        return Pass(r)

    @staticmethod
    def build_Break(ctx, stmt):
        """Build a Break."""
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("break"))
        return Break(r)

    @staticmethod
    def build_Continue(ctx, stmt):
        """Build a Continue."""
        r = ctx.make_range(
            stmt.lineno, stmt.col_offset, stmt.col_offset + len("continue")
        )
        return Continue(r)

    @staticmethod
    def build_With(ctx, stmt):
        """Build a With, or rewrite a ``torch.jit._IgnoreContextManager`` block.

        An ignore-context-manager block is replaced by the statement that
        ``build_ignore_context_manager`` returns, which is then built like any
        other statement.

        Raises:
            RuntimeError: If the ignore context manager is used while
                ``astunparse`` is not installed.
        """
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("with"))
        # Handle ignore context manager
        if is_torch_jit_ignore_context_manager(stmt):
            if not _IS_ASTUNPARSE_INSTALLED:
                raise RuntimeError(
                    "torch.jit._IgnoreContextManager requires installing Python library "
                    "`astunparse`, please install it in your Python environment"
                )
            assign_ast = build_ignore_context_manager(ctx, stmt)
            return build_stmt(ctx, assign_ast)
        return With(r, build_withitems(ctx, stmt.items), build_stmts(ctx, stmt.body))


class ExprBuilder(Builder):
    """Builds tree-view expressions from Python ``ast`` expression nodes.

    Each ``build_*`` static method takes the source context and an ``ast``
    node and returns the corresponding tree-view object. The class-level
    tables below map ``ast`` operator classes to the operator token strings
    that are handed to the tree-view constructors.
    """

    # Binary arithmetic and bitwise operators, looked up by build_BinOp.
    binop_map = {
        ast.Add: "+",
        ast.Sub: "-",
        ast.Mult: "*",
        ast.Div: "/",
        ast.Pow: "**",
        ast.Mod: "%",
        ast.FloorDiv: "//",
        ast.BitAnd: "&amp;",
        ast.BitXor: "^",
        ast.BitOr: "|",
        ast.LShift: "&lt;&lt;",
        ast.RShift: "&gt;&gt;",
    }

    # Matrix multiplication is registered separately from the literal above.
    binop_map[ast.MatMult] = "@"

    # Unary operators looked up by build_UnaryOp. ast.UAdd has no entry, so
    # unary plus is rejected there.
    unop_map = {
        ast.Not: "not",
        ast.USub: "-",
        ast.Invert: "~",
    }

    # Boolean operators looked up by build_BoolOp.
    boolop_map = {
        ast.And: "and",
        ast.Or: "or",
    }

    # Comparison operators looked up by build_Compare. The ast.NotIn token is
    # only used to make the lookup succeed: build_Compare emits `not in` as a
    # "not" applied to an "in" expression.
    cmpop_map = {
        ast.Eq: "==",
        ast.NotEq: "!=",
        ast.LtE: "&lt;=",
        ast.Lt: "&lt;",
        ast.GtE: "&gt;=",
        ast.Gt: "&gt;",
        ast.Is: "is",
        ast.IsNot: "is not",
        ast.In: "in",
        ast.NotIn: "not in",
    }

    @staticmethod
    def build_Attribute(ctx, expr):
        """Translate an attribute access into a ``Select`` tree view.

        ``expr.attr`` is a plain string without position information, so the
        range of the attribute name is recovered by scanning the source text
        that follows the base expression.
        """
        target = build_expr(ctx, expr.value)
        # The scan indexes the UTF-8 encoded bytes of the source.
        # NOTE(review): presumably because range offsets are byte offsets;
        # confirm against the range implementation.
        encoded = ctx.source.encode("utf-8")
        # Step over the single character right after the base expression
        # (expected to be the "."), then over any whitespace before the name.
        name_start = target.range().end + 1
        while chr(encoded[name_start]) in string.whitespace:
            name_start += 1
        name_range = ctx.make_raw_range(name_start, name_start + len(expr.attr))
        return Select(target, Ident(name_range, expr.attr))

    @staticmethod
    def build_Call(ctx, expr):
        """Translate a call expression into an ``Apply`` tree view.

        Raises:
            NotSupportedError: for a keyword entry without a name.
        """
        callee = build_expr(ctx, expr.func)
        positional = [build_expr(ctx, arg) for arg in expr.args]
        # ``starargs`` is looked up defensively: not every Call node has it.
        star = getattr(expr, "starargs", None)
        if star:
            star_expr = build_expr(ctx, star)
            positional.append(Starred(star_expr.range(), star_expr))
        keywords = []
        for keyword in expr.keywords:
            value = build_expr(ctx, keyword.value)
            # XXX: the keyword name reuses the range of its value; a more
            # precise range for the name itself could be computed here.
            if not keyword.arg:
                raise NotSupportedError(
                    value.range(), "keyword-arg expansion is not supported"
                )
            name = Ident(value.range(), keyword.arg)
            keywords.append(Attribute(name, value))
        return Apply(callee, positional, keywords)

    @staticmethod
    def build_Ellipsis(ctx, expr):
        """Translate an ``...`` literal into a ``Dots`` tree view."""
        token_width = len("...")
        dots_range = ctx.make_range(
            expr.lineno, expr.col_offset, expr.col_offset + token_width
        )
        return Dots(dots_range)

    @staticmethod
    def build_Name(ctx, expr):
        """Translate a name reference into a ``Var`` or a literal tree view.

        The names ``True``, ``False``, ``None`` and ``Ellipsis`` become the
        matching literal views; every other name becomes a ``Var``.

        Raises:
            NotSupportedError: if the name starts with the reserved prefix.
        """
        name = expr.id
        name_range = ctx.make_range(
            expr.lineno, expr.col_offset, expr.col_offset + len(name)
        )
        if name.startswith(_reserved_prefix):
            raise NotSupportedError(
                name_range,
                "names of variables used in JIT-ed functions "
                "can't start with " + _reserved_prefix,
            )
        # Names that denote constants map straight to their literal view.
        literal_views = {
            "True": TrueLiteral,
            "False": FalseLiteral,
            "None": NoneLiteral,
            "Ellipsis": Dots,
        }
        literal = literal_views.get(name)
        if literal is not None:
            return literal(name_range)
        return Var(Ident(name_range, name))

    @staticmethod
    def build_NameConstant(ctx, expr):
        """Translate a ``True``/``False``/``None``/``Ellipsis`` constant node.

        The range width is the length of the constant's ``str()`` form.

        Raises:
            ValueError: if ``expr.value`` is none of the supported constants.
        """
        constant = expr.value
        width = len(str(constant))
        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + width)
        if constant is True:
            return TrueLiteral(r)
        if constant is False:
            return FalseLiteral(r)
        if constant is None:
            return NoneLiteral(r)
        if constant == Ellipsis:
            return Dots(r)
        raise ValueError("Name constant value unsupported: " + str(constant))

    @staticmethod
    def build_BinOp(ctx, expr):
        """Translate a binary operation into a ``BinOp`` tree view.

        Raises:
            FrontendError: for ``/`` when the context does not use true
                division.
            NotSupportedError: if the operator has no entry in ``binop_map``.
        """
        left = build_expr(ctx, expr.left)
        right = build_expr(ctx, expr.right)
        op_type = type(expr.op)

        def between_operands():
            # Errors are reported on the span between the two operands,
            # which is where the operator is written.
            return ctx.make_raw_range(left.range().end, right.range().start)

        if op_type == ast.Div and not ctx.uses_true_division:
            raise FrontendError(
                between_operands(),
                "Division of ints in TorchScript uses Python 3 true "
                "division semantics. Please put `from __future__ "
                "import division` at the top of your file",
            )
        token = ExprBuilder.binop_map.get(op_type)
        if token is None:
            raise NotSupportedError(
                between_operands(),
                "unsupported binary operator: " + op_type.__name__,
            )
        return BinOp(token, left, right)

    @staticmethod
    def build_UnaryOp(ctx, expr):
        """Translate a unary operation into a ``UnaryOp`` tree view.

        The range of the result covers the operator token, measured from the
        start of the expression.

        Raises:
            NotSupportedError: if the operator has no entry in ``unop_map``.
        """
        sub_expr = build_expr(ctx, expr.operand)
        op = type(expr.op)
        op_token = ExprBuilder.unop_map.get(op)
        if op_token is None:
            # ``expr`` is a Python AST node and has no ``range()`` method, so
            # the error location is built from its position attributes. The
            # token text is unknown on this path; point at the character that
            # starts the expression, which is where the operator is written.
            err_range = ctx.make_range(
                expr.lineno, expr.col_offset, expr.col_offset + 1
            )
            raise NotSupportedError(
                err_range, "unsupported unary operator: " + op.__name__
            )
        r = ctx.make_range(
            expr.lineno, expr.col_offset, expr.col_offset + len(op_token)
        )
        return UnaryOp(r, op_token, sub_expr)

    @staticmethod
    def build_BoolOp(ctx, expr):
        """Translate an ``and``/``or`` chain into nested ``BinOp`` tree views.

        All operands share one operator. The chain is folded from the left,
        so three operands produce a ``BinOp`` whose left side is the ``BinOp``
        of the first two.

        Raises:
            AssertionError: if the node carries fewer than two operands.
            NotSupportedError: if the operator has no entry in ``boolop_map``.
        """
        if len(expr.values) &lt; 2:
            raise AssertionError(
                "expected at least 2 values in BoolOp, but got " + str(len(expr.values))
            )
        sub_exprs = [build_expr(ctx, sub_expr) for sub_expr in expr.values]
        op = type(expr.op)
        op_token = ExprBuilder.boolop_map.get(op)
        if op_token is None:
            # Report the error on the span between the first two operands.
            err_range = ctx.make_raw_range(
                sub_exprs[0].range().end, sub_exprs[1].range().start
            )
            raise NotSupportedError(
                err_range, "unsupported boolean operator: " + op.__name__
            )
        # Left fold: each step wraps the accumulated expression.
        lhs = sub_exprs[0]
        for rhs in sub_exprs[1:]:
            lhs = BinOp(op_token, lhs, rhs)
        return lhs

    @staticmethod
    def build_IfExp(ctx, expr):
        """Translate a conditional expression into a ``TernaryIf`` tree view."""
        condition = build_expr(ctx, expr.test)
        when_true = build_expr(ctx, expr.body)
        when_false = build_expr(ctx, expr.orelse)
        return TernaryIf(condition, when_true, when_false)

    @staticmethod
    def build_Compare(ctx, expr):
        """Translate a (possibly chained) comparison into tree views.

        Every adjacent pair of operands yields one comparison. When there are
        several, they are joined with ``and`` from left to right.

        Raises:
            NotSupportedError: if an operator has no entry in ``cmpop_map``.
        """
        operands = [build_expr(ctx, node) for node in [expr.left, *expr.comparators]]
        combined = None
        for left, op_node, right in zip(operands, expr.ops, operands[1:]):
            op_type = type(op_node)
            token = ExprBuilder.cmpop_map.get(op_type)
            # The span between the operands stands in for the operator.
            gap = ctx.make_raw_range(left.range().end, right.range().start)
            if token is None:
                raise NotSupportedError(
                    gap, "unsupported comparison operator: " + op_type.__name__
                )

            if op_type == ast.NotIn:
                # NB: `not in` has no tree view of its own; it is expressed
                # as `not` wrapped around an `in` comparison.
                current = UnaryOp(gap, "not", BinOp("in", left, right))
            else:
                current = BinOp(token, left, right)

            combined = current if combined is None else BinOp("and", combined, current)
        return combined

    @staticmethod
    def build_Subscript(ctx, expr):
        """Translate a subscript expression into a ``Subscript`` tree view.

        The shape of ``expr.slice`` decides how the indices are built:

        * ``ast.Index``: a wrapped index; a tuple value is expanded into one
          index per element, and an empty tuple becomes an empty
          ``TupleLiteral``.
        * ``ast.Slice``: a single ``SliceExpr``.
        * ``ast.ExtSlice``: a mix of indices, slices and ellipses.
        * anything else on Python 3.9 or newer: the slice is itself the index
          expression (a tuple is expanded element by element).

        Raises:
            NotSupportedError: for a tuple index or an unknown element type
                inside an ``ast.ExtSlice``, and for any other slice type on
                Python versions older than 3.9.
        """

        def build_SliceExpr(ctx, base, slice_expr):
            """Build a ``SliceExpr``; omitted bounds and step are ``None``."""
            lower = (
                build_expr(ctx, slice_expr.lower)
                if slice_expr.lower is not None
                else None
            )
            upper = (
                build_expr(ctx, slice_expr.upper)
                if slice_expr.upper is not None
                else None
            )
            step = (
                build_expr(ctx, slice_expr.step)
                if slice_expr.step is not None
                else None
            )
            # The slice reuses the range of the subscripted base expression.
            return SliceExpr(base.range(), lower, upper, step)

        def build_Index(ctx, base, index_expr):
            """Build one wrapped index; tuples are rejected here."""
            if isinstance(index_expr.value, ast.Tuple):
                raise NotSupportedError(
                    base.range(),
                    "slicing multiple dimensions with " "tuples not supported yet",
                )
            return build_expr(ctx, index_expr.value)

        def build_ExtSlice(ctx, base, extslice):
            """Build the list of indices for a multi-dimensional subscript."""
            sub_exprs = []
            for expr in extslice.dims:
                sub_type = type(expr)
                if sub_type is ast.Index:
                    sub_exprs.append(build_Index(ctx, base, expr))
                elif sub_type is ast.Slice:
                    sub_exprs.append(build_SliceExpr(ctx, base, expr))
                elif sub_type is ast.Ellipsis:
                    sub_exprs.append(Dots(base.range()))
                else:
                    raise NotSupportedError(
                        base.range(),
                        f"slicing multiple dimensions with {sub_type} not supported",
                    )
            return sub_exprs

        base = build_expr(ctx, expr.value)
        sub_type = type(expr.slice)
        if sub_type is ast.Index:
            if isinstance(expr.slice.value, ast.Tuple):
                # N-dimensional indexing using Tuple: x[(i, j, k)] is equivalent to x[i, j, k]
                # XXX: Indexing using a list is **different**! It triggers advanced indexing.
                indices = [
                    build_expr(ctx, index_expr) for index_expr in expr.slice.value.elts
                ]
                if not indices:
                    # `col_offset` is an int, but `end_col_offset` is
                    # `Optional[int]`. The magic number 2 (the width of `()`)
                    # is here to make sure we can parse `()` on any machine
                    r = ctx.make_range(
                        expr.lineno,
                        expr.slice.value.col_offset,
                        expr.slice.value.col_offset + 2,
                    )
                    tup = TupleLiteral(r, [])
                    indices.append(tup)
                return Subscript(base, indices)
            else:
                return Subscript(base, [build_expr(ctx, expr.slice.value)])
        elif sub_type is ast.Slice:
            return Subscript(base, [build_SliceExpr(ctx, base, expr.slice)])
        elif sub_type is ast.ExtSlice:
            return Subscript(base, build_ExtSlice(ctx, base, expr.slice))
        elif sys.version_info &gt;= (
            3,
            9,
        ):  # In Python 3.9 array indices are not wrapped in ast.Index
            if sub_type is ast.Tuple:
                # N-dimensional indexing using Tuple: x[(i, j, k)] is equivalent to x[i, j, k]
                indices = []
                for index_expr in expr.slice.elts:
                    if isinstance(index_expr, ast.Slice):
                        indices.append(build_SliceExpr(ctx, base, index_expr))
                    else:
                        indices.append(build_expr(ctx, index_expr))
                # Special-case logic for `typing.Tuple[()]`
                if not indices:
                    # Same magic number as in the ast.Index branch: 2 is the
                    # width of `()`
                    r = ctx.make_range(
                        expr.lineno, expr.slice.col_offset, expr.slice.col_offset + 2
                    )
                    tup = TupleLiteral(r, [])
                    indices.append(tup)
                return Subscript(base, indices)
            # Any other expression is used directly as the single index.
            return Subscript(base, [build_expr(ctx, expr.slice)])
        else:  # Ellipsis (can only happen in Python 2)
            raise NotSupportedError(base.range(), "ellipsis is not supported")

    @staticmethod
    def build_List(ctx, expr):
        """Translate a list display into a ``ListLiteral`` tree view.

        The range marks the first character of the expression.
        """
        opening = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
        elements = [build_expr(ctx, element) for element in expr.elts]
        return ListLiteral(opening, elements)

    @staticmethod
    def build_Tuple(ctx, expr):
        """Translate a tuple display into a ``TupleLiteral`` tree view.

        The range marks the first character of the expression.
        """
        opening = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
        elements = [build_expr(ctx, element) for element in expr.elts]
        return TupleLiteral(opening, elements)

    @staticmethod
    def build_Dict(ctx, expr):
        """Translate a dict display into a ``DictLiteral`` tree view.

        The range marks the first character of the expression.

        Raises:
            NotSupportedError: if the display contains a ``**mapping``
                expansion at any position.
        """
        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
        # A ``**mapping`` entry is stored with ``None`` in the keys list, and
        # it may follow ordinary entries, so every key has to be inspected.
        if any(key is None for key in expr.keys):
            raise NotSupportedError(
                r, "Dict expansion (e.g. `{**dict}`) is not supported"
            )
        return DictLiteral(
            r,
            [build_expr(ctx, key) for key in expr.keys],
            [build_expr(ctx, value) for value in expr.values],
        )

    @staticmethod
    def build_Num(ctx, expr):
        """Translate a numeric literal into a ``Const`` tree view.

        The constant is stored as the ``str()`` form of the number, and the
        range width is the length of that text.
        """
        # ``value`` is the supported attribute of constant nodes; ``n`` is a
        # deprecated alias and is only consulted for nodes lacking ``value``.
        number = expr.value if hasattr(expr, "value") else expr.n
        value = str(number)
        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + len(value))
        return Const(r, value)

    @staticmethod
    def build_Constant(ctx, expr):
        """Dispatch a constant node to the builder matching its value type.

        Raises:
            FrontendError: if the value is of a type no builder handles.
        """
        value = expr.value
        # NB: bool has to be recognised before the numeric check below
        # because bool is a subclass of int.
        if value is None or isinstance(value, bool):
            return ExprBuilder.build_NameConstant(ctx, expr)

        handlers = (
            ((int, float, complex), ExprBuilder.build_Num),
            (str, ExprBuilder.build_Str),
            (type(Ellipsis), ExprBuilder.build_Ellipsis),
        )
        for accepted, handler in handlers:
            if isinstance(value, accepted):
                return handler(ctx, expr)

        error_range = ctx.make_range(
            expr.lineno, expr.col_offset, expr.col_offset + len(str(value))
        )
        raise FrontendError(error_range, "Unknown Constant expression type")

    @staticmethod
    def build_Str(ctx, expr):
        """Translate a string literal into a ``StringLiteral`` tree view."""
        # ``value`` is the supported attribute of constant nodes; ``s`` is a
        # deprecated alias and is only consulted for nodes lacking ``value``.
        text = expr.value if hasattr(expr, "value") else expr.s
        value = str(text)
        # NOTE(review): the range end adds one to the text length, presumably
        # to account for a quote character - confirm.
        r = ctx.make_range(
            expr.lineno, expr.col_offset, expr.col_offset + len(value) + 1
        )
        return StringLiteral(r, value)

    @staticmethod
    def build_JoinedStr(ctx, expr):
        """Translate an f-string into a ``str.format`` call tree view.

        Literal parts are copied into the format string as they are, and each
        interpolated expression contributes a ``{}`` placeholder plus one
        positional argument.

        Raises:
            NotSupportedError: for a conversion (such as ``!r``), a format
                spec, or a part that is neither an interpolation nor a
                string literal.
        """
        pieces = []
        args = []
        for value in expr.values:
            r = ctx.make_range(value.lineno, value.col_offset, value.col_offset + 1)
            if isinstance(value, ast.FormattedValue):
                if value.conversion != -1:
                    raise NotSupportedError(r, "Don't support conversion in JoinedStr")
                if value.format_spec is not None:
                    raise NotSupportedError(r, "Don't support formatting in JoinedStr")
                pieces.append("{}")
                args.append(build_expr(ctx, value.value))
            elif isinstance(value, ast.Constant) and isinstance(value.value, str):
                # String literals are ``ast.Constant`` nodes on Python 3.8+;
                # this avoids the deprecated ``ast.Str`` class and its ``s``
                # attribute.
                pieces.append(value.value)
            else:
                raise NotSupportedError(r, "Unsupported value in JoinedStr")

        r = ctx.make_range(expr.lineno, expr.col_offset, expr.col_offset + 1)
        template = "".join(pieces)
        return Apply(Select(StringLiteral(r, template), Ident(r, "format")), args, [])

    @staticmethod
    def build_ListComp(ctx, stmt):
        """Translate a list comprehension into a ``ListComp`` tree view.

        Raises:
            NotSupportedError: unless there is exactly one generator and it
                has no ``if`` clauses.
        """
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset)
        generators = stmt.generators
        if len(generators) != 1:
            raise NotSupportedError(r, "Only a single generator is currently supported")

        only = generators[0]
        if len(only.ifs) != 0:
            raise NotSupportedError(r, "Comprehension ifs are not supported yet")

        element = build_expr(ctx, stmt.elt)
        target = build_expr(ctx, only.target)
        iterable = build_expr(ctx, only.iter)
        return ListComp(r, element, target, iterable)

    @staticmethod
    def build_GeneratorExp(ctx, stmt):
        """Build a generator expression exactly like a list comprehension."""
        as_list_comp = ExprBuilder.build_ListComp(ctx, stmt)
        return as_list_comp

    @staticmethod
    def build_DictComp(ctx, stmt):
        """Translate a dict comprehension into a ``DictComp`` tree view.

        Raises:
            NotSupportedError: unless there is exactly one generator and it
                has no ``if`` clauses.
        """
        r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset)
        generators = stmt.generators
        if len(generators) != 1:
            raise NotSupportedError(r, "Only a single generator is currently supported")

        only = generators[0]
        if len(only.ifs) != 0:
            raise NotSupportedError(r, "Comprehension ifs are not supported yet")

        key = build_expr(ctx, stmt.key)
        value = build_expr(ctx, stmt.value)
        target = build_expr(ctx, only.target)
        iterable = build_expr(ctx, only.iter)
        return DictComp(r, key, value, target, iterable)

    @staticmethod
    def build_Starred(ctx, expr):
        """Translate a starred expression into a ``Starred`` tree view.

        The range marks the first character of the expression.
        """
        star_range = ctx.make_range(
            expr.lineno, expr.col_offset, expr.col_offset + 1
        )
        inner = build_expr(ctx, expr.value)
        return Starred(star_range, inner)


# Module-level builder instances. The build_* methods above invoke them as
# plain callables, e.g. ``build_expr(ctx, node)``.
# NOTE(review): dispatch to the matching ``build_XXX`` method presumably
# happens in the ``Builder`` base class, which is defined elsewhere in this
# file - confirm.
build_expr = ExprBuilder()
build_stmt = StmtBuilder()
build_withitem = WithItemBuilder()


def find_before(ctx, pos, substr, offsets=(0, 0)):
    """Return a raw range for the last occurrence of ``substr`` before ``pos``.

    Only ``ctx.source[:pos]`` is searched. ``offsets[0]`` is added to the
    start of the match and ``offsets[1]`` to its end before the range is
    built with ``ctx.make_raw_range``.

    Raises:
        ValueError: if ``substr`` does not occur before ``pos``.
    """
    match_start = ctx.source.rindex(substr, 0, pos)
    match_end = match_start + len(substr)
    return ctx.make_raw_range(match_start + offsets[0], match_end + offsets[1])
</pre></body></html>