import numpy as np
import pytest
import warnings

import pickle

from sklearn.utils.metaestimators import if_delegate_has_method
from sklearn.utils.metaestimators import available_if


class Prefix:
    def func(self):
        pass


class MockMetaEstimator:
    """This is a mock meta estimator"""

    a_prefix = Prefix()

    @if_delegate_has_method(delegate="a_prefix")
    def func(self):
        """This is a mock delegated function"""
        pass


@pytest.mark.filterwarnings("ignore:if_delegate_has_method was deprecated")
def test_delegated_docstring():
    assert "This is a mock delegated function" in str(
        MockMetaEstimator.__dict__["func"].__doc__
    )
    assert "This is a mock delegated function" in str(MockMetaEstimator.func.__doc__)
    assert "This is a mock delegated function" in str(MockMetaEstimator().func.__doc__)


class MetaEst:
    """A mock meta estimator"""

    def __init__(self, sub_est, better_sub_est=None):
        self.sub_est = sub_est
        self.better_sub_est = better_sub_est

    @if_delegate_has_method(delegate="sub_est")
    def predict(self):
        pass


class MetaEstTestTuple(MetaEst):
    """A mock meta estimator to test passing a tuple of delegates"""

    @if_delegate_has_method(delegate=("sub_est", "better_sub_est"))
    def predict(self):
        pass


class MetaEstTestList(MetaEst):
    """A mock meta estimator to test passing a list of delegates"""

    @if_delegate_has_method(delegate=["sub_est", "better_sub_est"])
    def predict(self):
        pass


class HasPredict:
    """A mock sub-estimator with predict method"""

    def predict(self):
        pass


class HasNoPredict:
    """A mock sub-estimator with no predict method"""

    pass


class HasPredictAsNDArray:
    """A mock sub-estimator where predict is a NumPy array"""

    predict = np.ones((10, 2), dtype=np.int64)


@pytest.mark.filterwarnings("ignore:if_delegate_has_method was deprecated")
def test_if_delegate_has_method():
    assert hasattr(MetaEst(HasPredict()), "predict")
    assert not hasattr(MetaEst(HasNoPredict()), "predict")
    assert not hasattr(MetaEstTestTuple(HasNoPredict(), HasNoPredict()), "predict")
    assert hasattr(MetaEstTestTuple(HasPredict(), HasNoPredict()), "predict")
    assert not hasattr(MetaEstTestTuple(HasNoPredict(), HasPredict()), "predict")
    assert not hasattr(MetaEstTestList(HasNoPredict(), HasPredict()), "predict")
    assert hasattr(MetaEstTestList(HasPredict(), HasPredict()), "predict")


class AvailableParameterEstimator:
    """This estimator's `available` parameter toggles the presence of a method"""

    def __init__(self, available=True, return_value=1):
        self.available = available
        self.return_value = return_value

    @available_if(lambda est: est.available)
    def available_func(self):
        """This is a mock available_if function"""
        return self.return_value


def test_available_if_docstring():
    assert "This is a mock available_if function" in str(
        AvailableParameterEstimator.__dict__["available_func"].__doc__
    )
    assert "This is a mock available_if function" in str(
        AvailableParameterEstimator.available_func.__doc__
    )
    assert "This is a mock available_if function" in str(
        AvailableParameterEstimator().available_func.__doc__
    )


def test_available_if():
    assert hasattr(AvailableParameterEstimator(), "available_func")
    assert not hasattr(AvailableParameterEstimator(available=False), "available_func")


def test_available_if_unbound_method():
    # This is a non regression test for:
    # https://github.com/scikit-learn/scikit-learn/issues/20614
    # to make sure that decorated functions can be used as an unbound method,
    # for instance when monkeypatching.
    est = AvailableParameterEstimator()
    AvailableParameterEstimator.available_func(est)

    est = AvailableParameterEstimator(available=False)
    with pytest.raises(
        AttributeError,
        match="This 'AvailableParameterEstimator' has no attribute 'available_func'",
    ):
        AvailableParameterEstimator.available_func(est)


@pytest.mark.filterwarnings("ignore:if_delegate_has_method was deprecated")
def test_if_delegate_has_method_numpy_array():
    """Check that we can check for an attribute that is a NumPy array.

    This is a non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/21144
    """
    estimator = MetaEst(HasPredictAsNDArray())
    assert hasattr(estimator, "predict")


def test_if_delegate_has_method_deprecated():
    """Check the deprecation warning of if_delegate_has_method"""
    # don't warn when creating the decorator
    with warnings.catch_warnings():
        warnings.simplefilter("error", FutureWarning)
        _ = if_delegate_has_method(delegate="predict")

    # Only when calling it
    with pytest.warns(FutureWarning, match="if_delegate_has_method was deprecated"):
        hasattr(MetaEst(HasPredict()), "predict")


def test_available_if_methods_can_be_pickled():
    """Check that available_if methods can be pickled.

    Non-regression test for #21344.
    """
    return_value = 10
    est = AvailableParameterEstimator(available=True, return_value=return_value)
    pickled_bytes = pickle.dumps(est.available_func)
    unpickled_func = pickle.loads(pickled_bytes)
    assert unpickled_func() == return_value
