import unittest
from contextlib import contextmanager

from numba import njit
from numba.core import errors, cpu, utils, typing
from numba.core.descriptors import TargetDescriptor
from numba.core.dispatcher import TargetConfigurationStack
from numba.core.retarget import BasicRetarget
from numba.core.extending import overload
from numba.core.target_extension import (
    dispatcher_registry,
    CPUDispatcher,
    CPU,
    target_registry,
    jit_registry,
)


# ------------ A custom target ------------

# Fully qualified name under which the custom target is registered:
# "<this module's name>.CustomCPU".
CUSTOM_TARGET = f"{__name__}.CustomCPU"


class CustomCPU(CPU):
    """Custom target class that derives from the ``CPU`` target.

    It adds nothing of its own; it exists so that a distinct target can be
    registered under ``CUSTOM_TARGET``.
    """


# Nested contexts to help with isolating bits of compilation
class _NestedContext(object):
    _typing_context = None
    _target_context = None

    @contextmanager
    def nested(self, typing_context, target_context):
        old_nested = self._typing_context, self._target_context
        try:
            self._typing_context = typing_context
            self._target_context = target_context
            yield
        finally:
            self._typing_context, self._target_context = old_nested


# Implement a CustomCPU TargetDescriptor, this one borrows bits from the CPU
class CustomTargetDescr(TargetDescriptor):
    """Target descriptor for the custom target, assembled from CPU pieces.

    The ``typing_context`` and ``target_context`` properties return the
    contexts installed through ``nested_context`` while such an override is
    active, and otherwise fall back to the top-level contexts defined below.
    """
    options = cpu.CPUTargetOptions
    # A single holder object stored on the class, so it is shared by every
    # instance of this descriptor.
    _nested = _NestedContext()

    @utils.cached_property
    def _toplevel_typing_context(self):
        # Top-level typing context, created on first access.
        return typing.Context()

    @utils.cached_property
    def _toplevel_target_context(self):
        # Top-level target context, created on first access: a CPUContext
        # built from this descriptor's typing context and target name.
        return cpu.CPUContext(self.typing_context, self._target_name)

    @property
    def typing_context(self):
        """
        The typing context for the custom target.
        """
        override = self._nested._typing_context
        if override is None:
            return self._toplevel_typing_context
        return override

    @property
    def target_context(self):
        """
        The target context for the custom target.
        """
        override = self._nested._target_context
        if override is None:
            return self._toplevel_target_context
        return override

    def nested_context(self, typing_context, target_context):
        """
        Return a context manager that temporarily replaces both contexts
        with the given ones and reinstates the previous ones on exit.
        """
        return self._nested.nested(typing_context, target_context)


# Module-level descriptor instance for the custom target; it is attached to
# CustomCPUDispatcher as its ``targetdescr``.
custom_target = CustomTargetDescr(CUSTOM_TARGET)


class CustomCPUDispatcher(CPUDispatcher):
    """``CPUDispatcher`` subclass bound to the custom target descriptor."""

    targetdescr = custom_target


# Register the custom target class under its name, then register the
# dispatcher class keyed by the entry looked up for that same name.
# The second statement depends on the first having run.
target_registry[CUSTOM_TARGET] = CustomCPU
dispatcher_registry[target_registry[CUSTOM_TARGET]] = CustomCPUDispatcher


def custom_jit(*args, **kwargs):
    """Apply ``njit`` with the target fixed to ``CUSTOM_TARGET``.

    Positional and keyword arguments are forwarded to ``njit`` unchanged.
    The caller must not pass ``target`` or ``_target``; doing so trips an
    assertion because the target is chosen here.
    """
    for reserved in ('target', '_target'):
        assert reserved not in kwargs
    return njit(*args, _target=CUSTOM_TARGET, **kwargs)


# Associate custom_jit with the entry registered under CUSTOM_TARGET in
# target_registry.
jit_registry[target_registry[CUSTOM_TARGET]] = custom_jit

# ------------ For switching target ------------


class CustomCPURetarget(BasicRetarget):
    """``BasicRetarget`` subclass whose output target is ``CUSTOM_TARGET``."""

    @property
    def output_target(self):
        """Return the name of the custom target."""
        return CUSTOM_TARGET

    def compile_retarget(self, cpu_disp):
        """Return ``cpu_disp.py_func`` wrapped by ``njit`` for the custom target.

        ``cpu_disp`` is only read for its ``py_func`` attribute.
        """
        to_custom_target = njit(_target=CUSTOM_TARGET)
        return to_custom_target(cpu_disp.py_func)


class TestRetargeting(unittest.TestCase):
    """Tests for switching the compilation target with ``CustomCPURetarget``.

    Each test defines a flexible-target ``foo`` and calls it either directly
    or inside ``switch_target()``. Calls that reach a function pinned to the
    ``"cpu"`` target while retargeted are expected to raise a ``NumbaError``.
    """

    def setUp(self):
        # Generate fresh functions for each test method to avoid caching

        @njit(_target="cpu")
        def fixed_target(x):
            """
            This has a fixed target to "cpu".
            Cannot be used in CUSTOM_TARGET target.
            """
            return x + 10

        @njit
        def flex_call_fixed(x):
            """
            This has a flexible target, but uses a fixed target function.
            Cannot be used in CUSTOM_TARGET target.
            """
            return fixed_target(x) + 100

        @njit
        def flex_target(x):
            """
            This has a flexible target.
            Can be used in CUSTOM_TARGET target.
            """
            return x + 1000

        # Save these functions for use. They are listed explicitly instead of
        # via locals(), which would also capture ``self`` and make the test
        # instance reference itself.
        self.functions = {
            "fixed_target": fixed_target,
            "flex_call_fixed": flex_call_fixed,
            "flex_target": flex_target,
        }
        # Refresh the retarget function
        self.retarget = CustomCPURetarget()

    def switch_target(self):
        """Return the context manager that switches to ``self.retarget``."""
        return TargetConfigurationStack.switch_target(self.retarget)

    @contextmanager
    def check_retarget_error(self):
        """Expect the body to raise a ``NumbaError`` naming the target clash."""
        with self.assertRaises(errors.NumbaError) as raises:
            yield
        self.assertIn(f"{CUSTOM_TARGET} != cpu", str(raises.exception))

    def check_non_empty_cache(self):
        """Assert that the retarget cache recorded at least one hit or miss."""
        # Retargeting occurred. The cache must NOT be empty
        stats = self.retarget.cache.stats()
        # Because multiple function compilations are triggered, we don't know
        # precisely how many cache hit/miss there are.
        self.assertGreater(stats['hit'] + stats['miss'], 0)

    def test_case0(self):
        """
        Without switching target, fixed and flexible functions both work and
        the retarget cache is untouched.
        """
        fixed_target = self.functions["fixed_target"]
        flex_target = self.functions["flex_target"]

        @njit
        def foo(x):
            x = fixed_target(x)
            x = flex_target(x)
            return x

        r = foo(123)
        self.assertEqual(r, 123 + 10 + 1000)
        # No retargeting occurred. The cache must be empty
        stats = self.retarget.cache.stats()
        self.assertEqual(stats, dict(hit=0, miss=0))

    def test_case1(self):
        """
        A call chain made only of flexible-target functions works under
        switch_target and populates the retarget cache.
        """
        flex_target = self.functions["flex_target"]

        @njit
        def foo(x):
            x = flex_target(x)
            return x

        with self.switch_target():
            r = foo(123)
        self.assertEqual(r, 123 + 1000)
        self.check_non_empty_cache()

    def test_case2(self):
        """
        The non-nested call into fixed_target should raise error.
        """
        fixed_target = self.functions["fixed_target"]
        flex_target = self.functions["flex_target"]

        @njit
        def foo(x):
            x = fixed_target(x)
            x = flex_target(x)
            return x

        with self.check_retarget_error():
            with self.switch_target():
                foo(123)

    def test_case3(self):
        """
        The nested call into fixed_target should raise error
        """
        flex_target = self.functions["flex_target"]
        flex_call_fixed = self.functions["flex_call_fixed"]

        @njit
        def foo(x):
            x = flex_call_fixed(x)  # calls fixed_target indirectly
            x = flex_target(x)
            return x

        with self.check_retarget_error():
            with self.switch_target():
                foo(123)

    def test_case4(self):
        """
        Same as case3 but flex_call_fixed() is invoked outside of CUSTOM_TARGET
        target before the switch_target.
        """
        flex_target = self.functions["flex_target"]
        flex_call_fixed = self.functions["flex_call_fixed"]

        r = flex_call_fixed(123)
        self.assertEqual(r, 123 + 100 + 10)

        @njit
        def foo(x):
            x = flex_call_fixed(x)  # calls fixed_target indirectly
            x = flex_target(x)
            return x

        with self.check_retarget_error():
            with self.switch_target():
                foo(123)

    def test_case5(self):
        """
        Tests overload resolution with target switching
        """

        def overloaded_func(x):
            pass

        @overload(overloaded_func, target=CUSTOM_TARGET)
        def ol_overloaded_func_custom_target(x):
            def impl(x):
                return 62830
            return impl

        @overload(overloaded_func, target='cpu')
        def ol_overloaded_func_cpu(x):
            def impl(x):
                return 31415
            return impl

        @njit
        def foo(x):
            return x + overloaded_func(x)

        # On the default target the 'cpu' overload is expected to be chosen.
        r = foo(123)
        self.assertEqual(r, 123 + 31415)

        # After switching, the CUSTOM_TARGET overload is expected instead.
        with self.switch_target():
            r = foo(123)
            self.assertEqual(r, 123 + 62830)

        self.check_non_empty_cache()
