from typing import List, Union, Tuple, Optional
from torchgen.model import (
    Type,
    BaseTy,
    BaseType,
    OptionalType,
    ListType,
    OperatorName,
    FunctionSchema,
    Return,
    TensorOptionsArguments,
    Argument,
)
from torchgen.api.types import (
    CType,
    BaseCppType,
    BaseCType,
    OptionalCType,
    NamedCType,
    deviceT,
    layoutT,
    VectorCType,
    boolT,
    longT,
    doubleT,
    ListCType,
    stringT,
    scalarT,
    scalarTypeT,
    memoryFormatT,
    SymIntT,
)


# Module-level slot for the value type used by the lazy codegen.
# It starts out as None; setValueT() rebinds it and getValueT() reads it
# (raising NotImplementedError while it is still unset).
_valueT = None


def getValueT() -> BaseCppType:
    """
    Return the value type stored by setValueT().

    Raises NotImplementedError when the stored value is still unset (falsy).
    """
    if _valueT:
        return _valueT

    raise NotImplementedError(
        "The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
    )


def setValueT(val: BaseCppType) -> None:
    """
    Store `val` as the module-wide value type that getValueT() returns.
    """
    # `global` is required because the module-level name is rebound here
    global _valueT
    _valueT = val


# this is a bad hack. I need to refactor the data model to represent each arg in the schema as an object,
# making it easier to represent special properties of an arg.
#
# process_ir_type() returns this type (wrapped in a BaseCType) for a list argument whose
# element type prints as "Tensor", instead of producing a vector of per-element values.
# NOTE(review): the two constructor arguments look like (namespace, type name), i.e.
# torch::lazy::Value - confirm against BaseCppType.
tensorListValueT = BaseCppType("torch::lazy", "Value")


def process_ir_type(
    typ: Type,
) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
    """
    Convert a type taken from NativeFunctions into the type used by lazy tensor codegen.

    The conversion currently amounts to
     (1) turning at::Tensors into lazy::Values
     (2) wrapping everything in a BaseCType
     (3) replacing cpp-reference types with cpp-value types (e.g. vector instead of IntArrayRef)

    For (1): lazy::Values wrap lazy::Nodes, which is how Lazy IR represents tensors.
    Optional[Tensor], List[Tensor], etc. get special handling - hence 'tensor-like'.
    Scalar and SymInt are mapped to the same value type as Tensor.

    Coverage is incomplete on purpose: unsupported types raise AssertionError, and
    more types are expected to be added as the codegen is used with more operators.
    """
    if isinstance(typ, BaseType):
        base_name = typ.name

        # at::scalar has special handling and is wrapped in a lazy::Value just like
        # at::tensor; SymInt maps to the value type as well. The value type is only
        # looked up here, so the other base types work even before setValueT() is called.
        if base_name in (BaseTy.Tensor, BaseTy.Scalar, BaseTy.SymInt):
            return BaseCType(getValueT())

        # remaining supported base types map one-to-one onto a plain C++ type
        plain_cpp_types = {
            BaseTy.ScalarType: scalarTypeT,
            BaseTy.int: longT,
            BaseTy.bool: boolT,
            BaseTy.float: doubleT,
            BaseTy.str: stringT,
            BaseTy.Device: deviceT,
            BaseTy.Layout: layoutT,
            BaseTy.MemoryFormat: memoryFormatT,
        }
        if base_name in plain_cpp_types:
            return BaseCType(plain_cpp_types[base_name])

        raise AssertionError(f"TODO add support for type {repr(typ)}")

    if isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem))

    if isinstance(typ, ListType):
        elem_text = str(typ.elem)
        if elem_text == "Tensor?":
            # TODO(whc) is this actually correct? or should it use a Vector like the generic case
            return ListCType(OptionalCType(BaseCType(getValueT())))
        if elem_text == "Tensor":
            # this is a TensorList which comes in from GetTensorList as a Value
            return BaseCType(tensorListValueT)
        return VectorCType(process_ir_type(typ.elem))

    raise AssertionError(f"unrecognized type {repr(typ)}")


def isValueType(typ: CType) -> bool:
    """
    Report whether an already-converted type is Value-like.

    This is the post-conversion counterpart of being Tensor-like: a base type counts
    when it equals the value type, scalarT or SymIntT, and an optional/list/vector
    counts when its element type does.
    """
    if isinstance(typ, BaseCType):
        # I am regretting my naming conventions, but now we are wrapping at::scalar in
        # lazy value, while preserving other 'scalar' types as scalars in the IR
        value_like = (getValueT(), scalarT, SymIntT)
        return any(typ.type == candidate for candidate in value_like)

    if isinstance(typ, (OptionalCType, ListCType, VectorCType)):
        return isValueType(typ.elem)

    return False


def isSymIntType(typ: Type) -> bool:
    """
    Return True only for a bare SymInt base type; wrapped (optional/list) types are not unwrapped.
    """
    if not isinstance(typ, BaseType):
        return False
    return typ.name == BaseTy.SymInt


def isWrappedScalarType(typ: Type) -> bool:
    """
    Report whether a type is (or wraps) a c10::scalar that will be wrapped in a lazy Value.

    The conversion literally replaces scalarT with valueT, which loses the information
    that the argument was a scalar. This predicate is used to record which arguments
    were wrapped scalars so that information can be kept.
    """
    inner = typ
    # peel Optional/List layers until a base type (or something unsupported) is reached
    while not isinstance(inner, BaseType):
        if not isinstance(inner, (OptionalType, ListType)):
            return False
        inner = inner.elem

    # I am regretting my naming conventions, but now we are wrapping at::scalar in
    # lazy value, while preserving other 'scalar' types as scalars in the IR
    return inner.name == BaseTy.Scalar


def isGeneratorType(typ: Type) -> bool:
    """
    Report whether a type is a Generator, possibly nested inside Optional wrappers.
    """
    inner = typ
    # only Optional layers are unwrapped; any other non-base type is not a generator
    while not isinstance(inner, BaseType):
        if not isinstance(inner, OptionalType):
            return False
        inner = inner.elem
    return inner.name == BaseTy.Generator


class LazyArgument:
    """
    A single schema argument together with the flags lazy codegen needs about it.

    Built from an `Argument`; it keeps the original name and type and derives the
    converted lazy type plus a handful of boolean classifications.
    """

    name: str
    orig_type: Type
    # converted type from process_ir_type(); None for generator arguments
    lazy_type_: Optional[CType]
    # true if the original type is an OptionalType
    is_optional: bool
    is_wrapped_scalar: bool
    is_generator: bool
    # set from isSymIntType(), which only matches a bare SymInt base type
    is_symint_or_list: bool

    # true if this argument is or contains a lazy IR value
    is_lazy_value: bool

    def __init__(self, arg: Argument):
        """
        Classify `arg` and compute its lazy type.

        Raises AssertionError if `arg` is a generator that is not optional, and
        propagates any error from process_ir_type() for unsupported types.
        """
        self.name = arg.name
        self.orig_type = arg.type
        self.is_optional = isinstance(arg.type, OptionalType)
        self.is_generator = isGeneratorType(arg.type)
        if self.is_generator:
            assert (
                self.is_optional
            ), "We expect all generators are optional since currently they are"
            # there is no handling for generators in TorchScript IR (or XLA)
            # so we fall back to eager if the (optional)generator has value, and otherwise
            # its null and safe to exclude from lazy IR
            self.lazy_type_ = None
        else:
            self.lazy_type_ = process_ir_type(arg.type)
        self.is_wrapped_scalar = isWrappedScalarType(arg.type)
        self.is_symint_or_list = isSymIntType(arg.type)

        # generators have no lazy type, so short-circuit before touching lazy_type
        self.is_lazy_value = not self.is_generator and isValueType(self.lazy_type)

    @property
    def lazy_type(self) -> CType:
        """
        The converted lazy type; raises AssertionError when there is none (generators).
        """
        assert (
            self.lazy_type_ is not None
        ), f"Attempted to access lazy_type for invalid argument {self.name}"
        return self.lazy_type_


# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
# but carries type information from a native FunctionSchema modified for use with IR nodes,
# and preserving original argument names.
class LazyIrSchema:
    # Operator name, copied from the FunctionSchema this schema was built from.
    name: "OperatorName"

    positional_args: Tuple[LazyArgument, ...]
    keyword_args: Tuple[LazyArgument, ...]

    # TODO: Need to handle collisions with argument names at some point
    returns: Tuple["Return", ...]

    # if this schema has a Generator arg, list its orig ctype/name but don't
    # build a LazyArgument since lazy IR doesn't support it
    generator_arg: Optional[NamedCType] = None

    def __init__(self, func: FunctionSchema):
        """
        Wrap every argument of `func` in a LazyArgument, split into positional and keyword groups.
        """
        schema_args = func.arguments

        self.name = func.name
        self.returns = func.returns

        # positional order: arguments before self, self itself, arguments after self
        positional: List[LazyArgument] = []
        if schema_args.pre_self_positional is not None:
            positional.extend(
                LazyArgument(arg) for arg in schema_args.pre_self_positional
            )
        if schema_args.self_arg is not None:
            # self_arg is a wrapper; the actual Argument lives on `.argument`
            positional.append(LazyArgument(schema_args.self_arg.argument))
        if schema_args.post_self_positional is not None:
            positional.extend(
                LazyArgument(arg) for arg in schema_args.post_self_positional
            )
        self.positional_args = tuple(positional)

        keyword: List[LazyArgument] = []
        keyword_fields = (
            "pre_tensor_options_kwarg_only",
            "tensor_options",
            "post_tensor_options_kwarg_only",
            "out",
        )
        for field_name in keyword_fields:
            group = getattr(schema_args, field_name)
            if group is None:
                continue
            if isinstance(group, TensorOptionsArguments):
                # expand the bundled tensor options into its individual arguments
                group = group.all()
            # first pass: remember the (single) generator argument of the schema
            for arg in group:
                if not isGeneratorType(arg.type):
                    continue
                assert (
                    self.generator_arg is None
                ), "We expect there is only one generator arg"
                self.generator_arg = NamedCType(arg.name, arg.type)
            # second pass: wrap every argument of the group
            keyword.extend(LazyArgument(arg) for arg in group)
        self.keyword_args = tuple(keyword)

    @property
    def node_name(self) -> str:
        """
        Camel-case name of the IR node for this op.

        The `overload_name` of the operation is appended as well, so
        `bitwise_and.Tensor` yields `BitwiseAndTensor`.
        """
        lowered = f"{self.name.name}_{self.name.overload_name}".lower()
        # an empty overload name leaves an empty trailing piece, which contributes nothing
        return "".join(piece.capitalize() for piece in lowered.split("_"))

    @property
    def aten_name(self) -> str:
        """String form of the operator's name, without the overload name."""
        return f"{self.name.name}"

    @property
    def base_name(self) -> str:
        """String form of the `base` attribute of the operator's name."""
        return f"{self.name.name.base}"

    def filtered_args(
        self,
        positional: bool = True,
        keyword: bool = True,
        values: bool = True,
        scalars: bool = True,
        generator: bool = False,
    ) -> List[LazyArgument]:
        """
        Return a filtered view of the arguments, keeping their original order.

        `positional` / `keyword` choose which argument groups are considered.
        `values` / `scalars` choose between arguments that are lazy values and those that are not.
        `generator` controls whether generator arguments are kept; it has no effect
        when only `values` is requested.
        """
        # Some parts of the code care about kwargs vs args (TS lowerings),
        # other parts care about whether they need to wrap the arg in a lazy value or leave it alone.
        # Generators are special cased, as they are needed for fallback/shape-inference but not supported
        # in TS lowerings and therefore also omitted from lazy IR.
        candidates: List[LazyArgument] = [
            *(self.positional_args if positional else ()),
            *(self.keyword_args if keyword else ()),
        ]

        if not values and not scalars:
            return []

        if values and scalars:
            if generator:
                return candidates
            return [arg for arg in candidates if not arg.is_generator]

        if values:
            return [arg for arg in candidates if arg.is_lazy_value]

        # scalars only
        if generator:
            return [arg for arg in candidates if not arg.is_lazy_value]
        return [
            arg for arg in candidates if not (arg.is_lazy_value or arg.is_generator)
        ]

    @property
    def positional_values(self) -> List[LazyArgument]:
        """Positional arguments that are lazy values."""
        return self.filtered_args(keyword=False, scalars=False)

    @property
    def positional_scalars(self) -> List[LazyArgument]:
        """Positional arguments that are neither lazy values nor generators."""
        return self.filtered_args(keyword=False, values=False)

    @property
    def keyword_values(self) -> List[LazyArgument]:
        """Keyword arguments that are lazy values."""
        return self.filtered_args(positional=False, scalars=False)

    @property
    def keyword_scalars(self) -> List[LazyArgument]:
        """Keyword arguments that are neither lazy values nor generators."""
        return self.filtered_args(positional=False, values=False)
