U
    3d                  
   @   s|  d dl Z d dlZd dlmZ d dlmZ d dlmZm	Z	 d dl
mZ d dlmZmZ d dlmZmZmZ d dlmZ d d	lmZ d d
lmZmZmZ d dlmZ d dlmZmZ d dlm Z  d dl!m"Z"m#Z# d dl$m%Z%m&Z& d dl'm(Z( d dl)m*Z* d dl+m,Z,m-Z- d dlm.Z. d dl/m0Z0 d dl1m2Z2 d dl3m4Z4 d dl5m6Z6 d dl7m8Z8m9Z9m:Z:m;Z;m<Z<m=Z= d dl>m?Z? d dlm@Z@ dZAe jBdddd ZCe jDEd d!d"ge jDEd#d$d%gd&d' ZFd(d) ZGe jDEd#d$d%gd*d+ ZHe jDEd d!d"ge jDEd#d$d%gd,d- ZIe jDEd d!d"ge jDEd#d$d%gd.d/ ZJe jDEd d!d"ge jDEd#d$d%ge jDEd0eKd1d2d3 ZLd4d5 ZMd6d7 ZNe jDEd d!d"gd8d9 ZOd:d; ZPd<d= ZQd>d? ZRe jDEd#d$d%gd@dA ZSe jDEd#d$d%gdBdC ZTe jDEd#d$d%gdDdE ZUe jDEdFejVWdGXdHdId1ejVWdGXdHdId1dJgdKdL ZYe jBdMdN ZZe jBdOdP Z[dQdR Z\e jDEdSe ]e*dTdUd1e ]e*dTdUdVgdWdX Z^dYdZ Z_d[d\ Z`e jBddd]d^ Zae jBddd_d` Zbdadb Zce jDEdcdddegdfdg Zde jDEdhdIdige jDEdjdkdlgdmdn Zedodp Zfe jDEdqdrdsgdtdu Zgdvdw Zhe jDEdcdddegdxdy Zidzd{ Zje jDEd|ekelgd}d~ Zme jDEd|ekelgdd Zne jDEddddgdd Zoe jDEd d!d"ge jDEd#d$d%gdd Zpe jDEdddgdd Zqe jDEddgeA ereAgdd Zsdd Ztdd Zue jDEd d!d"ge jDEd#d$d%gdd Zvdd Zwdd ZxdS )    N)assert_allclose)sparse)BaseEstimatorclone)DummyClassifier)LeaveOneOuttrain_test_split)assert_array_almost_equalassert_almost_equalassert_array_equal)softmax)NotFittedError)make_classification
make_blobs	load_iris)LabelEncoder)KFoldcross_val_predict)MultinomialNB)RandomForestClassifierVotingClassifier)LogisticRegressionLinearRegression)DecisionTreeClassifier)	LinearSVC)Pipelinemake_pipeline)StandardScaler)IsotonicRegression)DictVectorizer)SimpleImputer)brier_score_loss)_CalibratedClassifier_SigmoidCalibration_sigmoid_calibrationCalibratedClassifierCVCalibrationDisplaycalibration_curve)CheckingClassifier)_convert_container   module)Zscopec                  C   s   t tddd\} }| |fS )N   *   	n_samples
n_featuresrandom_state)r   	N_SAMPLES)Xy r5   B/tmp/pip-unpacked-wheel-zrfo1fqw/sklearn/tests/test_calibration.pydata4   s    r7   methodsigmoidisotonicensembleTFc              	   C   s(  t d }| \}}tjjddj|jd}|| 8 }|d | |d | |d |   }}}	||d  ||d   }
}tddj|||	d}|	|
d d df }t
||jd |d	}tt ||| W 5 Q R X ||
ft|t|
ffD ],\}}t
||d
|d}|j|||	d |	|d d df }t||t||ksNt|j||d |	d |	|d d df }t|| |j|d| d |	d |	|d d df }t|| |j||d d |	d |	|d d df }|dkrt|d|  qt||t|d d |kstqd S )N   r-   seedsizeTZforce_alphasample_weight   cvr;      r8   rF   r;   r9   )r2   nprandomRandomStateuniformr@   minr   fitpredict_probar%   pytestraises
ValueErrorr   
csr_matrixr!   AssertionErrorr	   )r7   r8   r;   r/   r3   r4   rC   X_trainy_trainsw_trainX_testy_testclfprob_pos_clfcal_clfZthis_X_trainthis_X_testprob_pos_cal_clfZprob_pos_cal_clf_relabeledr5   r5   r6   test_calibration:   sF    (
 




 r_   c                 C   s<   | \}}t dd}||| |jd j}t|ts8td S )Nr<   rF   r   )r%   rN   calibrated_classifiers_	estimator
isinstancer   rT   )r7   r3   r4   	calib_clfZbase_estr5   r5   r6   "test_calibration_default_estimatorw   s
    
re   c                 C   sp   | \}}d}t |d}t||d}t|jt s2t|jj|ksBt||| |rV|nd}t|j|ksltd S )NrG   )n_splitsrE   rD   )	r   r%   rc   rF   rT   rf   rN   lenra   )r7   r;   r3   r4   splitsZkfoldrd   Zexpected_n_clfr5   r5   r6   test_calibration_cv_splitter   s    
ri   c                 C   s   t d }| \}}tjjddjt|d}|d | |d | |d |   }}}	||d  }
tdd}t|||d}|j|||	d |	|
}||| |	|
}tj
|| }|dkstd S )	Nr<   r-   r=   r?   r1   )r8   r;   rB   皙?)r2   rI   rJ   rK   rL   rg   r   r%   rN   rO   ZlinalgZnormrT   )r7   r8   r;   r/   r3   r4   rC   rU   rV   rW   rX   rb   calibrated_clfZprobs_with_swZprobs_without_swZdiffr5   r5   r6   test_sample_weight   s    (


rm   c                 C   s   | \}}t ||dd\}}}}tdd}	t|	|d|d}
|
|| |
|}t|	|d|d}||| ||}t|| dS )zTest parallel calibrationr-   rj   r<   )r8   Zn_jobsr;   rD   N)r   r   r%   rN   rO   r   )r7   r8   r;   r3   r4   rU   rX   rV   rY   rb   Zcal_clf_parallelZprobs_parallelZcal_clf_sequentialZprobs_sequentialr5   r5   r6   test_parallel_execution   s(    
   
   
rn   r>   r<   c                 C   s  dd }t dd}tdd|ddd	\}}d
||d
k< t|jd }|d d d
 |d d d
  }}	|dd d
 |dd d
  }
}|||	 t|| d|d}|||	 ||
}ttj	|ddt
t|
 d||
|  k rdk sn t||
|d||
| kst||t||
|d}||||d}|d| k sLttddd}|||	 ||
}||||d}t|| d|d}|||	 ||
}||||d}|d| k std S )Nc                 S   s*   t ||  }t || d |jd  S )Nr<   r   )rI   Zeyesumshape)y_trueZ
proba_pred	n_classesZY_onehotr5   r5   r6   multiclass_brier   s    z5test_calibration_multiclass.<locals>.multiclass_brier   rj   i  d   
         .@r/   r0   r1   ZcentersZcluster_stdr<   r   rD   rG   rH   Zaxis?gffffff?)rr   g?   r-   )n_estimatorsr1   )r   r   rI   uniquerp   rN   r%   rO   r   ro   onesrg   ZscorerT   r   decision_functionr   )r8   r;   r>   rs   rZ   r3   r4   rr   rU   rV   rX   rY   r\   probasZuncalibrated_brierZcalibrated_brierZ	clf_probsZcal_clf_probsr5   r5   r6   test_calibration_multiclass   sH    
    

""  

r   c                  C   sh   G dd d} t dddddd\}}t ||}|  }t||g|jd}||}t|d	|j  d S )
Nc                   @   s   e Zd Zdd ZdS )z9test_calibration_zero_probability.<locals>.ZeroCalibratorc                 S   s   t |jd S )Nr   )rI   zerosrp   selfr3   r5   r5   r6   predict
  s    zAtest_calibration_zero_probability.<locals>.ZeroCalibrator.predictN)__name__
__module____qualname__r   r5   r5   r5   r6   ZeroCalibrator  s   r   2   rv   rt   rw   rx   )rb   Zcalibratorsclasses      ?)r   r   rN   r"   classes_rO   r   Z
n_classes_)r   r3   r4   rZ   
calibratorr\   r   r5   r5   r6   !test_calibration_zero_probability  s"        
  
r   c               
   C   s  d} t d|  ddd\}}tjjddj|jd}|| 8 }|d|  |d|  |d|    }}}|| d	|   || d	|   || d	|     }}}	|d	|  d |d	|  d  }
}td
d}t|dd}t	
t ||| W 5 Q R X |||| ||
dddf }||
ft|t|
ffD ]\}}dD ]}t||dd}|	dfD ]v}|j|||d ||}||}|dddf }t|tddgtj|dd  t||t||ksZtqZq@q4dS )z*Test calibration for prefitted classifiersr      r,   r-   r.   r=   r?   Nr<   TrA   prefitr`   rD   )r:   r9   )r8   rF   rB   r   ry   )r   rI   rJ   rK   rL   r@   rM   r   r%   rP   rQ   r   rN   rO   r   rS   r   r   arrayZargmaxr!   rT   )r/   r3   r4   rC   rU   rV   rW   ZX_calibZy_calibZsw_calibrX   rY   rZ   Z	unfit_clfr[   Zthis_X_calibr]   r8   r\   swy_proby_predr^   r5   r5   r6   test_calibration_prefit  s>    (
"


"
 r   c                 C   s   | \}}t dd}t||ddd}||| ||}t|||ddd}|dkr^td	d
}nt }||| ||| ||}	||	}
t	|d d df |
 d S )Nrt   rj   r   FrH   r   )rF   r8   r:   Zclip)Zout_of_boundsrD   )
r   r%   rN   rO   r   r   r#   r   r   r   )r7   r8   r3   r4   rZ   r\   Z
cal_probasZunbiased_predsr   Zclf_dfZmanual_probasr5   r5   r6   test_calibration_ensemble_falseK  s    



r   c               	   C   s   t dddg} t dddg}t ddg}t|t| |d ddt |d	 |  |d    }t | || }t||d
 t	t
 t t | | f| W 5 Q R X dS )z0Test calibration values with Platt sigmoid modelrG   r   rD   gj=ɿgY90(?r   r   r,   N)rI   r   r	   r$   expr#   rN   r   rP   rQ   rR   vstack)ZexFZexYZAB_lin_libsvmZlin_probZsk_probr5   r5   r6   test_sigmoid_calibrationd  s    "r   c               	   C   sL  t ddddddg} t ddddddg}t| |d	d
\}}t|t|ksRtt|d	ksbtt|ddg t|ddg tt tdgdg W 5 Q R X t ddddddg}t ddddddg}t||d	dd\}}t|t|kstt|d	kstt|ddg t|ddg tt t||dd W 5 Q R X dS )z Check calibration_curve functionr   rD           rk   皙?皙??r   r<   n_binsg      ?quantiler   strategygUUUUUU?Z
percentile)r   N)	rI   r   r'   rg   rT   r
   rP   rQ   rR   )rq   r   	prob_true	prob_predZy_true2Zy_pred2Zprob_true_quantileZprob_pred_quantiler5   r5   r6   test_calibration_curveu  s.       
r   c               	   C   s   t ddddddg} t ddddddg}tt t| |d	d
d W 5 Q R X ttD t| |d	 d	dd\}}t| |d	d\}}t|| t|| W 5 Q R X dS )z6Tests the `normalize` parameter of `calibration_curve`r   rD   r   rk   r   r   r   r   r<   F)r   	normalizeTr   N)rI   r   rP   warnsFutureWarningr'   r
   )rq   r   Zprob_true_unnormalizedZprob_pred_unnormalizedr   r   r5   r5   r6   .test_calibration_curve_with_unnormalized_proba  s       

r   c                 C   sf   t dddddd\}}tj|d< tdt fdtd	d
fg}t|dd| d}||| || dS )z$Test that calibration can accept nanrv   r<   r   r-   )r/   r0   Zn_informativeZn_redundantr1   r   r   ZimputerrfrD   )r|   r:   )rF   r8   r;   N)	r   rI   nanr   r    r   r%   rN   r   )r;   r3   r4   rZ   Zclf_cr5   r5   r6   test_calibration_nan_imputer  s        

r   c                 C   sl   d}t dd|d\}}tddd}t|dt | d	}||| ||}t|jd
dt	|j
d  d S )Nr<   rv   rG   )r/   r0   rr   r   rt   Cr1   r9   rH   rD   ry   r   )r   r   r%   r   rN   rO   r	   ro   rI   r~   rp   )r;   Znum_classesr3   r4   rZ   Zclf_probZprobsr5   r5   r6   test_calibration_prob_sum  s       
r   c                 C   s   t jdd}t d}tddd}t|dt | d}||| t|j	D ]\}}|
|}| rt|d d |f t t| t |d d d |f dkstt |d d |d	 d f dkstqLt |d	|jd  sLtqLd S )
Nrv   rG   r   rt   r   r9   rH   r   rD   )rI   rJ   randnZaranger   r%   r   rN   	enumeratera   rO   r   r   rg   allrT   Zallcloserp   )r;   r3   r4   rZ   r\   icalibrated_classifierZprobar5   r5   r6   test_calibration_less_classes  s"    
   
 "(r   r3   r-      rG   r,   c                 C   sL   dddddddddddddddg}G dd dt }t| }|| | dS )z;Test that calibration accepts n-dimensional arrays as inputrD   r   c                   @   s    e Zd ZdZdd Zdd ZdS )z>test_calibration_accepts_ndarray.<locals>.MockTensorClassifierz*A toy estimator that accepts tensor inputsc                 S   s   t || _| S )N)rI   r}   r   )r   r3   r4   r5   r5   r6   rN     s    zBtest_calibration_accepts_ndarray.<locals>.MockTensorClassifier.fitc                 S   s   | |jd djddS )Nr   r   rD   ry   )Zreshaperp   ro   r   r5   r5   r6   r     s    zPtest_calibration_accepts_ndarray.<locals>.MockTensorClassifier.decision_functionN)r   r   r   __doc__rN   r   r5   r5   r5   r6   MockTensorClassifier  s   r   N)r   r%   rN   )r3   r4   r   rl   r5   r5   r6    test_calibration_accepts_ndarray  s    	"
r   c                  C   s.   dddddddddg} dddg}| |fS )	NZNYZadult)stateZageZTXVTchildrD   r   r5   )	dict_dataZtext_labelsr5   r5   r6   r     s    
r   c                 C   s,   | \}}t dt fdt fg}|||S )NZ
vectorizerrZ   )r   r   r   rN   )r   r3   r4   Zpipeline_prefitr5   r5   r6   dict_data_pipeline  s
    r   c                 C   sf   | \}}|}t |dd}||| t|j|j t|dr@tt|drNt|| || dS )aR  Test that calibration works in prefit pipeline with transformer

    `X` is not array-like, sparse matrix or dataframe at the start.
    See https://github.com/scikit-learn/scikit-learn/issues/8710

    Also test it can predict without running into validation errors.
    See https://github.com/scikit-learn/scikit-learn/issues/19637
    r   r`   n_features_in_N)r%   rN   r   r   hasattrrT   r   rO   )r   r   r3   r4   rZ   rd   r5   r5   r6   test_calibration_dict_pipeline  s    	
r   zclf, cvrD   r   r   c                 C   s   t ddddd\}}|dkr(| ||} t| |d}||| |dkrht|j| j |j| jkstn.t |j}t|j| |j|jd kstd S )	Nrv   rG   r<   rt   r/   r0   rr   r1   r   r`   rD   )	r   rN   r%   r   r   r   rT   r   rp   )rZ   rF   r3   r4   rd   r   r5   r5   r6   test_calibration_attributes+  s    	r   c               	   C   sp   t ddddd\} }tdd| |}t|dd	}d
}tjt|d" || d d d df | W 5 Q R X d S )Nrv   rG   r<   rt   r   rD   r   r   r`   zAX has 3 features, but LinearSVC is expecting 5 features as input.matchr   )r   r   rN   r%   rP   rQ   rR   )r3   r4   rZ   rd   msgr5   r5   r6   2test_calibration_inconsistent_prefit_n_features_inC  s    r   c                  C   sV   t ddddd\} }tdd tdD d	d
}|| | t|dd}|| | d S )Nrv   rG   r<   rt   r   c                 S   s   g | ]}d t | t fqS )lr)strr   ).0r   r5   r5   r6   
<listcomp>U  s     z5test_calibration_votingclassifier.<locals>.<listcomp>r   Zsoft)Z
estimatorsZvotingr   )rb   rF   )r   r   rangerN   r%   )r3   r4   Zvoterd   r5   r5   r6   !test_calibration_votingclassifierO  s    r   c                   C   s
   t ddS )NTZ
return_X_y)r   r5   r5   r5   r6   	iris_data_  s    r   c                 C   s    | \}}||dk  ||dk  fS )Nr<   r5   )r   r3   r4   r5   r5   r6   iris_data_binaryd  s    r   c           
   	   C   s   |\}}|\}}t  ||}d}tjt|d t||| W 5 Q R X t ||}	d}tjt|d t|	|| W 5 Q R X t }	tt	 t|	|| W 5 Q R X d S )Nz)'estimator' should be a fitted classifierr   z/response method predict_proba is not defined in)
r   rN   rP   rQ   rR   r&   from_estimatorr   r   r   )
pyplotr   r   r3   r4   ZX_binaryZy_binaryregr   rZ   r5   r5   r6   #test_calibration_display_validationj  s    r   constructor_namer   from_predictionsc              	   C   s   |\}}t  }||| ||}|dkrZd}tjt|d t||| W 5 Q R X n*d}tjt|d t|| W 5 Q R X d S )Nr   z"to be a binary classifier, but gotr   z-y should be a 1d array, got an array of shape)	r   rN   rO   rP   rQ   rR   r&   r   r   )r   r   r   r3   r4   rZ   r   r   r5   r5   r6   #test_calibration_display_non_binary}  s    
r   r   rv   r   rL   r   c                 C   sR  |\}}t  ||}tj|||||dd}||d d df }t||||d\}	}
t|j|	 t|j|
 t|j	| |j
dkstdd l}t|j|jjst|j dkstt|j|jjstt|j|jjst|j dkst|j dkstdd	g}|j  }t|t|ks0t|D ]}| |ks4tq4d S )
Nr   )r   r   alpharD   r   r   r   z.Mean predicted probability (Positive class: 1)z)Fraction of positives (Positive class: 1)Perfectly calibrated)r   rN   r&   r   rO   r'   r   r   r   r   estimator_namerT   Z
matplotlibrc   Zline_linesZLine2DZ	get_alphaax_ZaxesZAxesZfigure_figureZFigure
get_xlabel
get_ylabel
get_legend	get_textsrg   get_text)r   r   r   r   r3   r4   r   vizr   r   r   Zmplexpected_legend_labelslegend_labelslabelsr5   r5   r6    test_calibration_display_compute  sB            
r   c           	      C   sz   |\}}t t t }||| t|||}|jdg}|j 	 }t
|t
|ks\t|D ]}| |ks`tq`d S )Nr   )r   r   r   rN   r&   r   r   r   r   r   rg   rT   r   )	r   r   r3   r4   rZ   r   r   r   r   r5   r5   r6   $test_plot_calibration_curve_pipeline  s    
r   zname, expected_label)NZ_line1)my_estr   c           
      C   s   t ddddg}t ddddg}t g }t||||d}|  |d krRg n|g}|d |j  }t|t|kst	|D ]}	|	
 |kst	qd S )Nr   rD   r   r   皙?r   r   )rI   r   r&   plotappendr   r   r   rg   rT   r   )
r   nameZexpected_labelr   r   r   r   r   r   r   r5   r5   r6   'test_calibration_display_default_labels  s    

r   c           	      C   s   t ddddg}t ddddg}t g }d}t||||d}|j|ksPtd}|j|d	 |d
g}|j  }t	|t	|kst|D ]}|
 |kstqd S )Nr   rD   r   r   r   zname oner   zname twor   r   )rI   r   r&   r   rT   r   r   r   r   rg   r   )	r   r   r   r   r   r   r   r   r   r5   r5   r6   )test_calibration_display_label_class_plot  s    
r   c                 C   s  |\}}d}t  ||}||d d df }tt| }| dkrL|||fn||f}	||	d|i}
|
j|kspt|d |
  |dg}|
j	
  }t|t|kst|D ]}| |kstq|d d}|
j|d t|t|kst|D ]}| |kstqd S )	Nzmy hand-crafted namerD   r   r   r   r   Zanother_namer   )r   rN   rO   getattrr&   r   rT   closer   r   r   r   rg   r   )r   r   r   r3   r4   Zclf_namerZ   r   constructorparamsr   r   r   r   r5   r5   r6   ,test_calibration_display_name_multiple_calls  s*    


r   c           	      C   sj   |\}}t  ||}t ||}t|||}tj||||jd}|j d }|ddksftd S )N)ZaxrD   r   )	r   rN   r   r&   r   r   Zget_legend_handles_labelscountrT   )	r   r   r3   r4   r   dtr   Zviz2r   r5   r5   r6   !test_calibration_display_ref_line  s    r   dtype_y_strc              	   C   sh   t jd}t jdgd dgd  | d}|jdd|jd}d	}tjt|d
 t	|| W 5 Q R X dS )zKCheck error message when a `pos_label` is not specified with `str` targets.r-   spamr   eggsr<   dtyper   r?   zy_true takes value in {'eggs', 'spam'} and pos_label is not specified: either make y_true take value in {0, 1} or {-1, 1} or pass pos_label explicitlyr   N)
rI   rJ   rK   r   randintr@   rP   rQ   rR   r'   )r   rngy1y2err_msgr5   r5   r6   *test_calibration_curve_pos_label_error_str  s    r
  c                 C   s   t dddddddddg	}t jddg| d}|| }t dddd	d
ddddg	}t||dd\}}t|ddddg t||ddd\}}t|ddddg t|d| ddd\}}t|ddddg t|d| ddd\}}t|ddddg dS )z8Check the behaviour when passing explicitly `pos_label`.r   rD   r  Zeggr  rk   r   g333333?r   rz   gffffff?r   r   r      r   r   )r   	pos_labelN)rI   r   r'   r   )r   rq   r   Z
y_true_strr   r   _r5   r5   r6    test_calibration_curve_pos_label.  s    r  zpos_label, expected_pos_label)NrD   r   )rD   rD   c                 C   s   |\}}t  ||}tj||||d}||dd|f }t|||d\}	}
t|j|	 t|j|
 t|j	| |j
 d| dkst|j
 d| dkst|jjdg}|j
  }t|t|kst|D ]}| |kstqdS )z?Check the behaviour of `pos_label` in the `CalibrationDisplay`.)r  Nz,Mean predicted probability (Positive class: )z'Fraction of positives (Positive class: r   )r   rN   r&   r   rO   r'   r   r   r   r   r   r   rT   r   	__class__r   r   r   rg   r   )r   r   r  Zexpected_pos_labelr3   r4   r   r   r   r   r   r   r   r   r5   r5   r6   "test_calibration_display_pos_labelC  s*    

r  c                 C   sP  t dd\}}t |}|dd |dd  }}t|d }tj|jd d |jd f|jd}||dddddf< ||dddddf< tj|jd d |jd}||ddd< ||ddd< t }t	|| |dd	}t
|}	|	j|||d
 ||| t|	j|jD ]\}
}t|
jj|jj q|	|}||}t|| dS )zrCheck that passing repeating twice the dataset `X` is equivalent to
    passing a `sample_weight` with a factor 2.Tr   Nru   r<   r   rD   r  r8   r;   rF   rB   )r   r   fit_transformrI   	ones_liker   rp   r  r   r%   r   rN   zipra   r   rb   coef_rO   )r8   r;   r3   r4   rC   ZX_twiceZy_twicerb   calibrated_clf_without_weightscalibrated_clf_with_weightsest_with_weightsest_without_weightsy_pred_with_weightsy_pred_without_weightsr5   r5   r6   ?test_calibrated_classifier_cv_double_sample_weights_equivalenced  s>    $

r  fit_params_typelistr   c                 C   sH   |\}}t || t || d}tddgd}t|}|j||f| dS )zTests that fit_params are passed to the underlying base estimator.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/12384
    )abr   r!  Zexpected_fit_paramsN)r)   r(   r%   rN   )r  r7   r3   r4   
fit_paramsrZ   pc_clfr5   r5   r6    test_calibration_with_fit_params  s    r%  rC   r   c                 C   s.   |\}}t dd}t|}|j||| d dS )zMTests that sample_weight is passed to the underlying base
    estimator.
    T)Zexpected_sample_weightrB   N)r(   r%   rN   )rC   r7   r3   r4   rZ   r$  r5   r5   r6   2test_calibration_with_sample_weight_base_estimator  s    
r&  c              	   C   sZ   | \}}t |}G dd dt}| }t|}tt |j|||d W 5 Q R X dS )zCheck that even if the estimator doesn't support
    sample_weight, fitting with sample_weight still works.

    There should be a warning, since the sample_weight is not passed
    on to the estimator.
    c                       s   e Zd Z fddZ  ZS )zUtest_calibration_without_sample_weight_base_estimator.<locals>.ClfWithoutSampleWeightc                    s   d|kst t j||f|S )NrC   )rT   superrN   )r   r3   r4   r#  r  r5   r6   rN     s    zYtest_calibration_without_sample_weight_base_estimator.<locals>.ClfWithoutSampleWeight.fit)r   r   r   rN   __classcell__r5   r5   r(  r6   ClfWithoutSampleWeight  s   r*  rB   N)rI   r  r(   r%   rP   r   UserWarningrN   )r7   r3   r4   rC   r*  rZ   r$  r5   r5   r6   5test_calibration_without_sample_weight_base_estimator  s    
r,  c              	   C   sh   | \}}d|dd i}t |d}t|}dtt d }tjt|d |j||f| W 5 Q R X dS )z]fit_params having different length than data should raise the
    correct error message.
    r   NrG   r"  z>Found input variables with inconsistent numbers of samples: \[z, 5\]r   )r(   r%   r   r2   rP   rQ   rR   rN   )r7   r3   r4   r#  rZ   r$  r   r5   r5   r6   4test_calibration_with_fit_params_inconsistent_length  s    
r-  c                 C   s  t dd\}}t |}t|dd |dd f}t|dd |dd f}t|}d|ddd< t }t|| |dd	}t	|}|j
|||d
 |
|ddd |ddd  t|j|jD ]\}}	t|jj|	jj q||}
||}t|
| dS )z|Check that passing removing some sample from the dataset `X` is
    equivalent to passing a `sample_weight` with a factor 0.Tr   N(   r   Z   rD   r<   r  rB   )r   r   r  rI   r   ZhstackZ
zeros_liker   r%   r   rN   r  ra   r   rb   r  rO   )r8   r;   r3   r4   rC   rb   r  r  r  r  r  r  r5   r5   r6   >test_calibrated_classifier_cv_zeros_sample_weights_equivalence  s6    
 

r0  c              	   C   s8   t t t d}tjtdd |j|   W 5 Q R X dS )zUCheck that we raise an error is a user set both `base_estimator` and
    `estimator`.)base_estimatorrb   z%Both `base_estimator` and `estimator`r   N)r%   r   rP   rQ   rR   rN   )r7   r   r5   r5   r6   /test_calibrated_classifier_error_base_estimator  s     r2  c              	   C   s8   t t d}d}tjt|d |j|   W 5 Q R X dS )zPCheck that we raise a warning regarding the deprecation of
    `base_estimator`.)r1  z+`base_estimator` was renamed to `estimator`r   N)r%   r   rP   r   r   rN   )r7   r   Zwarn_msgr5   r5   r6   5test_calibrated_classifier_deprecation_base_estimator  s    r3  )yrP   ZnumpyrI   Znumpy.testingr   Zscipyr   Zsklearn.baser   r   Zsklearn.dummyr   Zsklearn.model_selectionr   r   Zsklearn.utils._testingr	   r
   r   Zsklearn.utils.extmathr   Zsklearn.exceptionsr   Zsklearn.datasetsr   r   r   Zsklearn.preprocessingr   r   r   Zsklearn.naive_bayesr   Zsklearn.ensembler   r   Zsklearn.linear_modelr   r   Zsklearn.treer   Zsklearn.svmr   Zsklearn.pipeliner   r   r   Zsklearn.isotonicr   Zsklearn.feature_extractionr   Zsklearn.imputer    Zsklearn.metricsr!   Zsklearn.calibrationr"   r#   r$   r%   r&   r'   Zsklearn.utils._mockingr(   r)   r2   Zfixturer7   markZparametrizer_   re   ri   rm   rn   r   r   r   r   r   r   r   r   r   r   r   rJ   rK   r   r   r   r   r   paramr   r   r   r   r   r   r   r   r   r   r   r   r   r   objectr
  r  r  r  r%  r~   r&  r,  r-  r0  r2  r3  r5   r5   r5   r6   <module>   s    

;

=/
 












) 

"


 0

+