"""
Tests for mutable structs, a.k.a. structref.
"""
import warnings

import numpy as np

from numba import typed, njit, errors
from numba.core import types
from numba.experimental import structref
from numba.extending import overload_method, overload_attribute
from numba.tests.support import (
    MemoryLeakMixin, TestCase, temp_directory, override_config,
)


@structref.register
class MySimplerStructType(types.StructRef):
    """StructRef type for the tests that exercise the lowest-level uses of
    structref.
    """


# Concrete low-level struct type: an intp array plus an intp scalar.
my_struct_ty = MySimplerStructType(
    fields=[
        ("values", types.intp[:]),
        ("counter", types.intp),
    ]
)

# Instances of MySimplerStructType are boxed as the generic StructRefProxy.
structref.define_boxing(MySimplerStructType, structref.StructRefProxy)


class MyStruct(structref.StructRefProxy):
    """Python proxy whose attributes and methods are hand-written wrappers
    around jitted helper functions.
    """

    def __new__(cls, values, counter):
        # Overriding ``__new__`` gives the constructor named parameters (the
        # default one only takes ``*args``), which allows keyword arguments.
        # The actual construction is still done by ``StructRefProxy.__new__``.
        return super().__new__(cls, values, counter)

    # Manually written wrappers for the attributes and methods follow.

    @property
    def values(self):
        """The ``values`` field, read through the jitted getter."""
        return get_values(self)

    @values.setter
    def values(self, val):
        set_values(self, val)

    @property
    def counter(self):
        """The ``counter`` field, read through the jitted getter."""
        return get_counter(self)

    def testme(self, arg):
        """Return ``values * arg + counter``."""
        scaled = self.values * arg
        return scaled + self.counter

    @property
    def prop(self):
        """The ``(values, counter)`` pair."""
        vals = self.values
        ctr = self.counter
        return vals, ctr


@structref.register
class MyStructType(types.StructRef):
    """StructRef type for the tests that exercise the higher-level uses of
    structref.
    """


# `define_proxy` registers `MyStruct` as the PyObject proxy used to create
# Numba-allocated structrefs. Once registered, the `MyStruct` class is usable
# from both jit-compiled code and interpreted code.
structref.define_proxy(MyStruct, MyStructType, ['values', 'counter'])


@njit
def my_struct(values, counter):
    """Allocate a ``my_struct_ty`` structref and fill in both fields."""
    instance = structref.new(my_struct_ty)
    my_struct_init(instance, values, counter)
    return instance


@njit
def my_struct_init(self, values, counter):
    """Set the ``values`` and ``counter`` fields of ``self``."""
    self.values = values
    self.counter = counter


@njit
def ctor_by_intrinsic(vs, ctr):
    """Build a struct through ``my_struct`` and then modify both fields.

    ``values`` is added to itself and ``counter`` is multiplied by ``ctr``.
    """
    # `counter` is passed by keyword here.
    st = my_struct(vs, counter=ctr)
    # NOTE(review): augmented assignment on the array field; presumably this
    # updates the array in place, which would also affect `vs` -- confirm.
    st.values += st.values
    st.counter *= ctr
    return st


@njit
def ctor_by_class(vs, ctr):
    """Build a struct by calling the ``MyStruct`` class with keywords."""
    instance = MyStruct(values=vs, counter=ctr)
    return instance


@njit
def get_values(st):
    """Return the ``values`` field of ``st``."""
    return st.values


@njit
def set_values(st, val):
    """Assign ``val`` to the ``values`` field of ``st``."""
    st.values = val


@njit
def get_counter(st):
    """Return the ``counter`` field of ``st``."""
    return st.counter


@njit
def compute_fields(st):
    """Return the sum of the ``values`` and ``counter`` fields of ``st``."""
    vals = st.values
    ctr = st.counter
    return vals + ctr


class TestStructRefBasic(MemoryLeakMixin, TestCase):
    """Basic construction, field access and container use of structrefs."""

    def test_structref_type(self):
        """``types.StructRef`` records field types and validates its input."""
        sr = types.StructRef([('a', types.int64)])
        self.assertEqual(sr.field_dict['a'], types.int64)
        sr = types.StructRef([('a', types.int64), ('b', types.float64)])
        self.assertEqual(sr.field_dict['a'], types.int64)
        self.assertEqual(sr.field_dict['b'], types.float64)
        # bad case: the field name is not a str
        with self.assertRaisesRegex(ValueError,
                                    "expecting a str for field name"):
            types.StructRef([(1, types.int64)])
        # bad case: the field type is not a Numba type
        with self.assertRaisesRegex(ValueError,
                                    "expecting a Numba Type for field type"):
            types.StructRef([('a', 123)])

    def test_invalid_uses(self):
        """The ``StructRef`` base class itself cannot be registered."""
        with self.assertRaisesRegex(ValueError, "cannot register"):
            structref.register(types.StructRef)
        with self.assertRaisesRegex(ValueError, "cannot register"):
            structref.define_boxing(types.StructRef, MyStruct)

    def test_MySimplerStructType(self):
        """Low-level struct: construct in jit code, then read it back."""
        vs = np.arange(10, dtype=np.intp)
        ctr = 13

        # Computed before the call, which doubles the `values` field.
        first_expected = vs + vs
        first_got = ctor_by_intrinsic(vs, ctr)
        # the returned instance is a structref.StructRefProxy
        # but not a MyStruct
        self.assertNotIsInstance(first_got, MyStruct)
        self.assertPreciseEqual(first_expected, get_values(first_got))

        # `counter` was multiplied by `ctr` inside ctor_by_intrinsic.
        second_expected = first_expected + (ctr * ctr)
        second_got = compute_fields(first_got)
        self.assertPreciseEqual(second_expected, second_got)

    def test_MySimplerStructType_wrapper_has_no_attrs(self):
        """The generic proxy does not expose the fields as attributes."""
        vs = np.arange(10, dtype=np.intp)
        ctr = 13
        wrapper = ctor_by_intrinsic(vs, ctr)
        self.assertIsInstance(wrapper, structref.StructRefProxy)
        with self.assertRaisesRegex(AttributeError, 'values'):
            wrapper.values
        with self.assertRaisesRegex(AttributeError, 'counter'):
            wrapper.counter

    def test_MyStructType(self):
        """High-level struct: jit code returns a ``MyStruct`` instance."""
        vs = np.arange(10, dtype=np.float64)
        ctr = 11

        first_expected_arr = vs.copy()
        first_got = ctor_by_class(vs, ctr)
        self.assertIsInstance(first_got, MyStruct)
        self.assertPreciseEqual(first_expected_arr, first_got.values)

        second_expected = first_expected_arr + ctr
        second_got = compute_fields(first_got)
        self.assertPreciseEqual(second_expected, second_got)
        self.assertEqual(first_got.counter, ctr)

    def test_MyStructType_mixed_types(self):
        """The same constructor accepts differently typed field values."""
        # structref constructor is generic
        @njit
        def mixed_type(x, y, m, n):
            return MyStruct(x, y), MyStruct(m, n)

        a, b = mixed_type(1, 2.3, 3.4j, (4,))
        self.assertEqual(a.values, 1)
        self.assertEqual(a.counter, 2.3)
        self.assertEqual(b.values, 3.4j)
        self.assertEqual(b.counter, (4,))

    def test_MyStructType_in_dict(self):
        """Structs can be stored in, replaced in and mutated via a typed
        dict.
        """
        td = typed.Dict()
        td['a'] = MyStruct(1, 2.3)
        self.assertEqual(td['a'].values, 1)
        self.assertEqual(td['a'].counter, 2.3)
        # overwrite
        td['a'] = MyStruct(2, 3.3)
        self.assertEqual(td['a'].values, 2)
        self.assertEqual(td['a'].counter, 3.3)
        # mutate
        td['a'].values += 10
        self.assertEqual(td['a'].values, 12)    # changed
        self.assertEqual(td['a'].counter, 3.3)  # unchanged
        # insert
        td['b'] = MyStruct(4, 5.6)
        self.assertEqual(len(td), 2)
        self.assertEqual(td['b'].values, 4)
        self.assertEqual(td['b'].counter, 5.6)
        # the existing entry is not affected by the insertion
        self.assertEqual(td['a'].values, 12)
        self.assertEqual(td['a'].counter, 3.3)

    def test_MyStructType_in_dict_mixed_type_error(self):
        """Storing a struct with different field types is rejected."""
        self.disable_leak_check()

        td = typed.Dict()
        td['a'] = MyStruct(1, 2.3)
        self.assertEqual(td['a'].values, 1)
        self.assertEqual(td['a'].counter, 2.3)

        # ERROR: store different types
        with self.assertRaisesRegex(errors.TypingError,
                                    r"Cannot cast numba.MyStructType"):
            # The first stored struct had an integer first field and a float
            # second field; here the first field is a float and the second
            # field is an integer.
            td['b'] = MyStruct(2.3, 1)


@overload_method(MyStructType, "testme")
def _ol_mystructtype_testme(self, arg):
    """Jit-side ``testme``: same formula as ``MyStruct.testme``."""
    def testme_impl(self, arg):
        scaled = self.values * arg
        return scaled + self.counter
    return testme_impl


@overload_attribute(MyStructType, "prop")
def _ol_mystructtype_prop(self):
    """Jit-side ``prop``: the ``(values, counter)`` pair."""
    def prop_getter(self):
        return self.values, self.counter
    return prop_getter


class TestStructRefExtending(MemoryLeakMixin, TestCase):
    """Compares the overloads registered on ``MyStructType`` against the
    interpreted behaviour of ``MyStruct``.
    """

    def test_overload_method(self):
        """``testme`` gives the same result compiled and interpreted."""
        @njit
        def check(x):
            arr = np.arange(10, dtype=np.float64)
            count = 11
            struct = MyStruct(arr, count)
            return struct.testme(x)

        arg = 3
        self.assertPreciseEqual(check(arg), check.py_func(arg))

    def test_overload_attribute(self):
        """``prop`` gives the same result compiled and interpreted."""
        @njit
        def check():
            arr = np.arange(10, dtype=np.float64)
            count = 11
            struct = MyStruct(arr, count)
            return struct.prop

        self.assertPreciseEqual(check(), check.py_func())


def caching_test_make(x, y):
    """Create a ``MyStruct`` from ``x`` (values) and ``y`` (counter)."""
    return MyStruct(values=x, counter=y)


def caching_test_use(struct, z):
    """Call ``testme`` on ``struct`` with ``z`` and return the result."""
    result = struct.testme(z)
    return result


class TestStructRefCaching(MemoryLeakMixin, TestCase):
    """Caching (``cache=True``) of jitted functions that make and use
    structrefs.
    """

    # NOTE(review): setUp/tearDown do not call super(), so any setup or
    # teardown defined by MemoryLeakMixin/TestCase is skipped here -- confirm
    # that this is intentional.

    def setUp(self):
        # Point CACHE_DIR at a temp directory named after this class.
        self._cache_dir = temp_directory(TestStructRefCaching.__name__)
        self._cache_override = override_config('CACHE_DIR', self._cache_dir)
        self._cache_override.__enter__()
        # Snapshot the current warning filters so that tearDown restores
        # exactly the pre-test state. (`warnings.resetwarnings()` would
        # instead discard every filter, including those installed by the test
        # runner or by -W options.)
        self._warnings_ctx = warnings.catch_warnings()
        self._warnings_ctx.__enter__()
        # Turn warnings into errors, except those coming from typeguard.
        warnings.simplefilter("error")
        warnings.filterwarnings(action="ignore", module="typeguard")

    def tearDown(self):
        try:
            self._cache_override.__exit__(None, None, None)
        finally:
            # Restore the warning filters even if the line above raises.
            self._warnings_ctx.__exit__(None, None, None)

    def test_structref_caching(self):
        """The first run compiles (cache miss); the second run hits the
        cache.
        """
        def assert_cached(stats):
            self.assertEqual(len(stats.cache_hits), 1)
            self.assertEqual(len(stats.cache_misses), 0)

        def assert_not_cached(stats):
            self.assertEqual(len(stats.cache_hits), 0)
            self.assertEqual(len(stats.cache_misses), 1)

        def check(cached):
            # Fresh dispatchers on every call, so their stats start empty.
            check_make = njit(cache=True)(caching_test_make)
            check_use = njit(cache=True)(caching_test_use)
            vs = np.random.random(3)
            ctr = 17
            factor = 3
            st = check_make(vs, ctr)
            got = check_use(st, factor)
            expect = vs * factor + ctr
            self.assertPreciseEqual(got, expect)

            if cached:
                assert_cached(check_make.stats)
                assert_cached(check_use.stats)
            else:
                assert_not_cached(check_make.stats)
                assert_not_cached(check_use.stats)

        check(cached=False)
        check(cached=True)


@structref.register
class PolygonStructType(types.StructRef):
    """StructRef type whose ``parent`` field is an optional instance of this
    same type.
    """

    def preprocess_fields(self, fields):
        """Return the fixed field layout; the ``fields`` argument is unused."""
        # A temporary name is assigned first so that ``self`` can be wrapped
        # in an Optional below.
        self.name = f"numba.PolygonStructType#{id(self)}"
        optional_int = types.Optional(types.int64)
        optional_self = types.Optional(self)
        return (
            ('value', optional_int),
            ('parent', optional_self),
        )


# NOTE(review): PolygonStructType.preprocess_fields ignores its argument and
# returns a fixed layout, so the `types.Any` entries here look like
# placeholders -- confirm.
polygon_struct_type = PolygonStructType(
    fields=(
        ('value', types.Any),
        ('parent', types.Any),
    )
)


class PolygonStruct(structref.StructRefProxy):
    """Python proxy exposing ``value`` and ``parent`` as read-only
    properties backed by jitted getters.
    """

    def __new__(cls, value, parent):
        # Named parameters in place of ``*args``; construction is delegated
        # to ``StructRefProxy.__new__``.
        return super().__new__(cls, value, parent)

    @property
    def value(self):
        """The ``value`` field."""
        return PolygonStruct_get_value(self)

    @property
    def parent(self):
        """The ``parent`` field."""
        return PolygonStruct_get_parent(self)


@njit
def PolygonStruct_get_value(self):
    """Return the ``value`` field of ``self``."""
    return self.value


@njit
def PolygonStruct_get_parent(self):
    """Return the ``parent`` field of ``self``."""
    return self.parent


# Register `PolygonStruct` as the Python proxy class for `PolygonStructType`.
structref.define_proxy(PolygonStruct, PolygonStructType, ["value", "parent"])


@overload_method(PolygonStructType, "flip")
def _ol_polygon_struct_flip(self):
    """Provide a ``flip`` method that negates ``value`` in place."""
    def impl(self):
        # `value` may be None, in which case there is nothing to negate.
        if self.value is not None:
            self.value = -self.value
    return impl


@overload_attribute(PolygonStructType, "prop")
def _ol_polygon_struct_prop(self):
    """Jit-side ``prop``: the ``(value, parent)`` pair."""
    def prop_getter(self):
        return self.value, self.parent
    return prop_getter


class TestStructRefForwardTyping(MemoryLeakMixin, TestCase):
    """Exercises ``PolygonStruct``, whose ``parent`` field refers to its own
    struct type.
    """

    def test_same_type_assignment(self):
        """One struct can be stored in the ``parent`` field of another."""
        @njit
        def check(x):
            child = PolygonStruct(None, None)
            parent = PolygonStruct(None, None)
            child.value = x
            child.parent = parent
            parent.value = x
            return child.parent.value

        arg = 11
        self.assertPreciseEqual(check(arg), arg)

    def test_overload_method(self):
        """``flip`` works on a struct and on the struct in its ``parent``."""
        @njit
        def check(x):
            child = PolygonStruct(None, None)
            parent = PolygonStruct(None, None)
            child.value = x
            child.parent = parent
            parent.value = x
            child.flip()
            child.parent.flip()
            return child.parent.value

        arg = 3
        self.assertPreciseEqual(check(arg), -arg)

    def test_overload_attribute(self):
        """``prop`` is available in jit code; its first item is ``value``."""
        @njit
        def check():
            poly = PolygonStruct(5, None)
            return poly.prop[0]

        self.assertPreciseEqual(check(), 5)
