from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.instancenorm import InstanceNorm2d

from torchvision.ops import Conv2dNormActivation

from ...transforms._presets import OpticalFlow
from ...utils import _log_api_usage_once
from .._api import Weights, WeightsEnum
from .._utils import handle_legacy_interface
from ._utils import grid_sample, make_coords_grid, upsample_flow


__all__ = (
    "RAFT",
    "raft_large",
    "raft_small",
    "Raft_Large_Weights",
    "Raft_Small_Weights",
)


class ResidualBlock(nn.Module):
    """Slightly modified Residual block with extra relu and biases."""

    def __init__(self, in_channels, out_channels, *, norm_layer, stride=1):
        super().__init__()

        # bias=True is kept even though a norm layer follows each conv, so that
        # the original pretrained weights remain loadable.
        self.convnormrelu1 = Conv2dNormActivation(
            in_channels, out_channels, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True
        )
        self.convnormrelu2 = Conv2dNormActivation(
            out_channels, out_channels, norm_layer=norm_layer, kernel_size=3, bias=True
        )

        if stride == 1:
            self.downsample = nn.Identity()
        else:
            self.downsample = Conv2dNormActivation(
                in_channels,
                out_channels,
                norm_layer=norm_layer,
                kernel_size=1,
                stride=stride,
                bias=True,
                activation_layer=None,
            )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        y = x
        y = self.convnormrelu1(y)
        y = self.convnormrelu2(y)

        x = self.downsample(x)

        return self.relu(x + y)


class BottleneckBlock(nn.Module):
    """Slightly modified BottleNeck block (extra relu and biases)"""

    def __init__(self, in_channels, out_channels, *, norm_layer, stride=1):
        super().__init__()

        self.convnormrelu1 = Conv2dNormActivation(
            in_channels, out_channels // 4, norm_layer=norm_layer, kernel_size=1, bias=True
        )
        self.convnormrelu2 = Conv2dNormActivation(
            out_channels // 4, out_channels // 4, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True
        )
        self.convnormrelu3 = Conv2dNormActivation(
            out_channels // 4, out_channels, norm_layer=norm_layer, kernel_size=1, bias=True
        )
        self.relu = nn.ReLU(inplace=True)

        if stride == 1:
            self.downsample = nn.Identity()
        else:
            self.downsample = Conv2dNormActivation(
                in_channels,
                out_channels,
                norm_layer=norm_layer,
                kernel_size=1,
                stride=stride,
                bias=True,
                activation_layer=None,
            )

    def forward(self, x):
        y = x
        y = self.convnormrelu1(y)
        y = self.convnormrelu2(y)
        y = self.convnormrelu3(y)

        x = self.downsample(x)

        return self.relu(x + y)

class FeatureEncoder(nn.Module):
    """The feature encoder, used both as the actual feature encoder, and as the context encoder.

    It must downsample its input by 8.
    """

    def __init__(self, *, block=ResidualBlock, layers=(64, 64, 96, 128, 256), norm_layer=nn.InstanceNorm2d):
        super().__init__()

        if len(layers) != 5:
            raise ValueError(f"The expected number of layers is 5, instead got {len(layers)}")

        # Downsample by 2 here, then by 2 in layer2 and layer3: 8 in total.
        self.convnormrelu = Conv2dNormActivation(
            3, layers[0], norm_layer=norm_layer, kernel_size=7, stride=2, bias=True
        )

        self.layer1 = self._make_2_blocks(block, layers[0], layers[1], norm_layer=norm_layer, first_stride=1)
        self.layer2 = self._make_2_blocks(block, layers[1], layers[2], norm_layer=norm_layer, first_stride=2)
        self.layer3 = self._make_2_blocks(block, layers[2], layers[3], norm_layer=norm_layer, first_stride=2)

        self.conv = nn.Conv2d(layers[3], layers[4], kernel_size=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_2_blocks(self, block, in_channels, out_channels, norm_layer, first_stride):
        block1 = block(in_channels, out_channels, norm_layer=norm_layer, stride=first_stride)
        block2 = block(out_channels, out_channels, norm_layer=norm_layer, stride=1)
        return nn.Sequential(block1, block2)

    def forward(self, x):
        x = self.convnormrelu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv(x)

        return x

class MotionEncoder(nn.Module):
    """The motion encoder, part of the update block.

    Takes the current predicted flow and the correlation features as input and returns an encoded version of these.
    """

    def __init__(self, *, in_channels_corr, corr_layers=(256, 192), flow_layers=(128, 64), out_channels=128):
        super().__init__()

        if len(flow_layers) != 2:
            raise ValueError(f"The expected number of flow_layers is 2, instead got {len(flow_layers)}")
        if len(corr_layers) not in (1, 2):
            raise ValueError(f"The number of corr_layers should be 1 or 2, instead got {len(corr_layers)}")

        self.convcorr1 = Conv2dNormActivation(in_channels_corr, corr_layers[0], norm_layer=None, kernel_size=1)
        if len(corr_layers) == 2:
            self.convcorr2 = Conv2dNormActivation(corr_layers[0], corr_layers[1], norm_layer=None, kernel_size=3)
        else:
            self.convcorr2 = nn.Identity()

        self.convflow1 = Conv2dNormActivation(2, flow_layers[0], norm_layer=None, kernel_size=7)
        self.convflow2 = Conv2dNormActivation(flow_layers[0], flow_layers[1], norm_layer=None, kernel_size=3)

        # out_channels - 2 because the original flow (2 channels) is concatenated back at the end
        self.conv = Conv2dNormActivation(
            corr_layers[-1] + flow_layers[-1], out_channels - 2, norm_layer=None, kernel_size=3
        )

        self.out_channels = out_channels

    def forward(self, flow, corr_features):
        corr = self.convcorr1(corr_features)
        corr = self.convcorr2(corr)

        flow_orig = flow
        flow = self.convflow1(flow)
        flow = self.convflow2(flow)

        corr_flow = torch.cat([corr, flow], dim=1)
        corr_flow = self.conv(corr_flow)
        return torch.cat([corr_flow, flow_orig], dim=1)


class ConvGRU(nn.Module):
    """Convolutional Gru unit."""

    def __init__(self, *, input_size, hidden_size, kernel_size, padding):
        super().__init__()
        self.convz = nn.Conv2d(hidden_size + input_size, hidden_size, kernel_size=kernel_size, padding=padding)
        self.convr = nn.Conv2d(hidden_size + input_size, hidden_size, kernel_size=kernel_size, padding=padding)
        self.convq = nn.Conv2d(hidden_size + input_size, hidden_size, kernel_size=kernel_size, padding=padding)

    def forward(self, h, x):
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz(hx))  # update gate
        r = torch.sigmoid(self.convr(hx))  # reset gate
        q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))  # candidate hidden state
        h = (1 - z) * h + z * q
        return h


def _pass_through_h(h: Tensor, _: Tensor) -> Tensor:
    # Declared at module level (instead of as a lambda) for torchscript
    return h

class RecurrentBlock(nn.Module):
    """Recurrent block, part of the update block.

    Takes the current hidden state and the concatenation of (motion encoder output, context) as input.
    Returns an updated hidden state.
    """

    def __init__(self, *, input_size, hidden_size, kernel_size=((1, 5), (5, 1)), padding=((0, 2), (2, 0))):
        super().__init__()

        if len(kernel_size) != len(padding):
            raise ValueError(
                f"kernel_size should have the same length as padding, instead got len(kernel_size) = {len(kernel_size)} and len(padding) = {len(padding)}"
            )
        if len(kernel_size) not in (1, 2):
            raise ValueError(f"kernel_size should have a length of 1 or 2, instead got {len(kernel_size)}")

        self.convgru1 = ConvGRU(
            input_size=input_size, hidden_size=hidden_size, kernel_size=kernel_size[0], padding=padding[0]
        )
        if len(kernel_size) == 2:
            self.convgru2 = ConvGRU(
                input_size=input_size, hidden_size=hidden_size, kernel_size=kernel_size[1], padding=padding[1]
            )
        else:
            self.convgru2 = _pass_through_h

        self.hidden_size = hidden_size

    def forward(self, h, x):
        h = self.convgru1(h, x)
        h = self.convgru2(h, x)
        return h


class FlowHead(nn.Module):
    """Flow head, part of the update block.

    Takes the hidden state of the recurrent unit as input, and outputs the predicted "delta flow".
    """

    def __init__(self, *, in_channels, hidden_size):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, hidden_size, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_size, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.conv2(self.relu(self.conv1(x)))

class UpdateBlock(nn.Module):
    """The update block which contains the motion encoder, the recurrent block, and the flow head.

    It must expose a ``hidden_state_size`` attribute which is the hidden state size of its recurrent block.
    """

    def __init__(self, *, motion_encoder, recurrent_block, flow_head):
        super().__init__()
        self.motion_encoder = motion_encoder
        self.recurrent_block = recurrent_block
        self.flow_head = flow_head

        self.hidden_state_size = recurrent_block.hidden_size

    def forward(self, hidden_state, context, corr_features, flow):
        motion_features = self.motion_encoder(flow, corr_features)
        x = torch.cat([context, motion_features], dim=1)

        hidden_state = self.recurrent_block(hidden_state, x)
        delta_flow = self.flow_head(hidden_state)
        return hidden_state, delta_flow

class MaskPredictor(nn.Module):
    """Mask predictor to be used when upsampling the predicted flow.

    It takes the hidden state of the recurrent unit as input and outputs the mask.
    This is not used in the raft-small model.
    """

    def __init__(self, *, in_channels, hidden_size, multiplier=0.25):
        super().__init__()
        self.convrelu = Conv2dNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3)
        # 8 * 8 * 9 because the predicted flow is downsampled by 8, and each of the
        # upsampled pixels is a convex combination of its 9 coarse neighbors
        # (see paper section 3.3 and appendix B).
        self.conv = nn.Conv2d(hidden_size, 8 * 8 * 9, 1, padding=0)

        # The original implementation uses a 0.25 factor to downweight the gradients of that branch.
        self.multiplier = multiplier

    def forward(self, x):
        x = self.convrelu(x)
        x = self.conv(x)
        return self.multiplier * x

class CorrBlock(nn.Module):
    """The correlation block.

    Creates a correlation pyramid with ``num_levels`` levels from the outputs of the feature encoder,
    and then indexes from this pyramid to create correlation features.
    The "indexing" of a given centroid pixel x' is done by concatenating its surrounding neighbors that
    are within a ``radius``, according to the infinity norm (see paper section 3.2).
    Note: typo in the paper, it should be infinity norm, not 1-norm.
    """

    def __init__(self, *, num_levels: int = 4, radius: int = 4):
        super().__init__()
        self.num_levels = num_levels
        self.radius = radius

        self.corr_pyramid: List[Tensor] = [torch.tensor(0)]  # placeholder, overridden in build_pyramid

        # The neighborhood of a centroid pixel is a square with sides of length
        # 2 * radius + 1, sampled at each of the num_levels pyramid levels.
        self.out_channels = num_levels * (2 * radius + 1) ** 2

    def build_pyramid(self, fmap1, fmap2):
        """Build the correlation pyramid from two feature maps.

        The correlation volume is first computed as the dot product of each pair (pixel_in_fmap1, pixel_in_fmap2)
        The last 2 dimensions of the correlation volume are then pooled num_levels times at different resolutions
        to build the correlation pyramid.
        """
        if fmap1.shape != fmap2.shape:
            raise ValueError(
                f"Input feature maps should have the same shape, instead got {fmap1.shape} (fmap1.shape) != {fmap2.shape} (fmap2.shape)"
            )
        corr_volume = self._compute_corr_volume(fmap1, fmap2)

        batch_size, h, w, num_channels, _, _ = corr_volume.shape  # _, _ = h, w
        corr_volume = corr_volume.reshape(batch_size * h * w, num_channels, h, w)
        self.corr_pyramid = [corr_volume]
        for _ in range(self.num_levels - 1):
            corr_volume = F.avg_pool2d(corr_volume, kernel_size=2, stride=2)
            self.corr_pyramid.append(corr_volume)

    def index_pyramid(self, centroids_coords):
        """Return correlation features by indexing from the pyramid."""
        neighborhood_side_len = 2 * self.radius + 1  # see note in __init__ about out_channels
        di = torch.linspace(-self.radius, self.radius, neighborhood_side_len)
        dj = torch.linspace(-self.radius, self.radius, neighborhood_side_len)
        delta = torch.stack(torch.meshgrid(di, dj, indexing="ij"), dim=-1).to(centroids_coords.device)
        delta = delta.view(1, neighborhood_side_len, neighborhood_side_len, 2)

        batch_size, _, h, w = centroids_coords.shape  # _ = 2
        centroids_coords = centroids_coords.permute(0, 2, 3, 1).reshape(batch_size * h * w, 1, 1, 2)

        indexed_pyramid = []
        for corr_volume in self.corr_pyramid:
            sampling_coords = centroids_coords + delta  # end shape is (batch_size * h * w, side_len, side_len, 2)
            indexed_corr_volume = grid_sample(corr_volume, sampling_coords, align_corners=True, mode="bilinear").view(
                batch_size, h, w, -1
            )
            indexed_pyramid.append(indexed_corr_volume)
            centroids_coords = centroids_coords / 2  # the pyramid is twice as coarse at each level

        corr_features = torch.cat(indexed_pyramid, dim=-1).permute(0, 3, 1, 2).contiguous()

        expected_output_shape = (batch_size, self.out_channels, h, w)
        if corr_features.shape != expected_output_shape:
            raise ValueError(
                f"Output shape of index pyramid is incorrect. Should be {expected_output_shape}, got {corr_features.shape}"
            )

        return corr_features

    def _compute_corr_volume(self, fmap1, fmap2):
        batch_size, num_channels, h, w = fmap1.shape
        fmap1 = fmap1.view(batch_size, num_channels, h * w)
        fmap2 = fmap2.view(batch_size, num_channels, h * w)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch_size, h, w, 1, h, w)
        return corr / torch.sqrt(torch.tensor(num_channels))

class RAFT(nn.Module):
    def __init__(self, *, feature_encoder, context_encoder, corr_block, update_block, mask_predictor=None):
        """RAFT model from
        `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.

        Args:
            feature_encoder (nn.Module): The feature encoder. It must downsample the input by 8.
                Its input is the concatenation of ``image1`` and ``image2``.
            context_encoder (nn.Module): The context encoder. It must downsample the input by 8.
                Its input is ``image1``. As in the original implementation, its output will be split into 2 parts:

                - one part will be used as the actual "context", passed to the recurrent unit of the ``update_block``
                - one part will be used to initialize the hidden state of the recurrent unit of
                  the ``update_block``

                These 2 parts are split according to the ``hidden_state_size`` of the ``update_block``, so the number
                of output channels of the ``context_encoder`` must be strictly greater than ``hidden_state_size``.

            corr_block (nn.Module): The correlation block, which creates a correlation pyramid from the output of the
                ``feature_encoder``, and then indexes from this pyramid to create correlation features. It must expose
                2 methods:

                - a ``build_pyramid`` method that takes ``feature_map_1`` and ``feature_map_2`` as input (these are the
                  output of the ``feature_encoder``).
                - an ``index_pyramid`` method that takes the coordinates of the centroid pixels as input, and returns
                  the correlation features. See paper section 3.2.

                It must expose an ``out_channels`` attribute.

            update_block (nn.Module): The update block, which contains the motion encoder, the recurrent unit, and the
                flow head. It takes as input the hidden state of its recurrent unit, the context, the correlation
                features, and the current predicted flow. It outputs an updated hidden state, and the ``delta_flow``
                prediction (see paper appendix A). It must expose a ``hidden_state_size`` attribute.
            mask_predictor (nn.Module, optional): Predicts the mask that will be used to upsample the predicted flow.
                The output channel must be 8 * 8 * 9 - see paper section 3.3, and Appendix B.
                If ``None`` (default), the flow is upsampled using interpolation.
        """
        super().__init__()
        _log_api_usage_once(self)

        self.feature_encoder = feature_encoder
        self.context_encoder = context_encoder
        self.corr_block = corr_block
        self.update_block = update_block

        self.mask_predictor = mask_predictor

        if not hasattr(self.update_block, "hidden_state_size"):
            raise ValueError("The update_block parameter should expose a 'hidden_state_size' attribute.")

    def forward(self, image1, image2, num_flow_updates: int = 12) -> List[Tensor]:

        batch_size, _, h, w = image1.shape
        if (h, w) != image2.shape[-2:]:
            raise ValueError(f"input images should have the same shape, instead got ({h}, {w}) != {image2.shape[-2:]}")
        if h % 8 != 0 or w % 8 != 0:
            raise ValueError(f"input image H and W should be divisible by 8, instead got {h} (h) and {w} (w)")

        fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0))
        fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
        if fmap1.shape[-2:] != (h // 8, w // 8):
            raise ValueError("The feature encoder should downsample H and W by 8")

        self.corr_block.build_pyramid(fmap1, fmap2)

        context_out = self.context_encoder(image1)
        if context_out.shape[-2:] != (h // 8, w // 8):
            raise ValueError("The context encoder should downsample H and W by 8")

        # As in the original paper, the actual output of the context encoder is split in 2 parts:
        # - one part is used to initialize the hidden state of the recurrent units of the update block
        # - the rest is the "actual" context.
        hidden_state_size = self.update_block.hidden_state_size
        out_channels_context = context_out.shape[1] - hidden_state_size
        if out_channels_context <= 0:
            raise ValueError(
                f"The context encoder outputs {context_out.shape[1]} channels, but it should have strictly more than hidden_state={hidden_state_size} channels"
            )
        hidden_state, context = torch.split(context_out, [hidden_state_size, out_channels_context], dim=1)
        hidden_state = torch.tanh(hidden_state)
        context = F.relu(context)

        coords0 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)
        coords1 = make_coords_grid(batch_size, h // 8, w // 8).to(fmap1.device)

        flow_predictions = []
        for _ in range(num_flow_updates):
            coords1 = coords1.detach()  # don't backpropagate through this branch, see paper
            corr_features = self.corr_block.index_pyramid(centroids_coords=coords1)

            flow = coords1 - coords0
            hidden_state, delta_flow = self.update_block(hidden_state, context, corr_features, flow)

            coords1 = coords1 + delta_flow

            up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state)
            upsampled_flow = upsample_flow(flow=(coords1 - coords0), up_mask=up_mask)
            flow_predictions.append(upsampled_flow)

        return flow_predictions


_COMMON_META = {
    "min_size": (128, 128),
}

class Raft_Large_Weights(WeightsEnum):
    """The metrics reported here are as follows.

    ``epe`` is the "end-point-error" and indicates how far (in pixels) the
    predicted flow is from its true value. This is averaged over all pixels
    of all images. ``per_image_epe`` is similar, but the average is different:
    the epe is first computed on each image independently, and then averaged
    over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe")
    in the original paper, and it's only used on Kitti. ``fl-all`` is also a
    Kitti-specific metric, defined by the author of the dataset and used for the
    Kitti leaderboard. It corresponds to the average of pixels whose epe is
    either <3px, or <5% of flow's 2-norm.
    """

    C_T_V1 = Weights(
        url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth",
        transforms=OpticalFlow,
        meta={
            **_COMMON_META,
            "num_params": 5257536,
            "recipe": "https://github.com/princeton-vl/RAFT",
            "_metrics": {
                "Sintel-Train-Cleanpass": {"epe": 1.4411},
                "Sintel-Train-Finalpass": {"epe": 2.7894},
                "Kitti-Train": {"per_image_epe": 5.0172, "fl_all": 17.4506},
            },
            "_docs": """These weights were ported from the original paper. They
            are trained on :class:`~torchvision.datasets.FlyingChairs` +
            :class:`~torchvision.datasets.FlyingThings3D`.""",
        },
    )

    C_T_V2 = Weights(
        url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
        transforms=OpticalFlow,
        meta={
            **_COMMON_META,
            "num_params": 5257536,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
            "_metrics": {
                "Sintel-Train-Cleanpass": {"epe": 1.3822},
                "Sintel-Train-Finalpass": {"epe": 2.7161},
                "Kitti-Train": {"per_image_epe": 4.5118, "fl_all": 16.0679},
            },
            "_docs": """These weights were trained from scratch on
            :class:`~torchvision.datasets.FlyingChairs` +
            :class:`~torchvision.datasets.FlyingThings3D`.""",
        },
    )

    C_T_SKHT_V1 = Weights(
        url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth",
        transforms=OpticalFlow,
        meta={
            **_COMMON_META,
            "num_params": 5257536,
            "recipe": "https://github.com/princeton-vl/RAFT",
            "_metrics": {
                "Sintel-Test-Cleanpass": {"epe": 1.94},
                "Sintel-Test-Finalpass": {"epe": 3.18},
            },
            "_docs": """
                These weights were ported from the original paper. They are
                trained on :class:`~torchvision.datasets.FlyingChairs` +
                :class:`~torchvision.datasets.FlyingThings3D` and fine-tuned on
                Sintel. The Sintel fine-tuning step is a combination of
                :class:`~torchvision.datasets.Sintel`,
                :class:`~torchvision.datasets.KittiFlow`,
                :class:`~torchvision.datasets.HD1K`, and
                :class:`~torchvision.datasets.FlyingThings3D` (clean pass).
            """,
        },
    )

    C_T_SKHT_V2 = Weights(
        url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth",
        transforms=OpticalFlow,
        meta={
            **_COMMON_META,
            "num_params": 5257536,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
            "_metrics": {
                "Sintel-Test-Cleanpass": {"epe": 1.819},
                "Sintel-Test-Finalpass": {"epe": 3.067},
            },
            "_docs": """
                These weights were trained from scratch. They are
                pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
                :class:`~torchvision.datasets.FlyingThings3D` and then
                fine-tuned on Sintel. The Sintel fine-tuning step is a
                combination of :class:`~torchvision.datasets.Sintel`,
                :class:`~torchvision.datasets.KittiFlow`,
                :class:`~torchvision.datasets.HD1K`, and
                :class:`~torchvision.datasets.FlyingThings3D` (clean pass).
            """,
        },
    )

    C_T_SKHT_K_V1 = Weights(
        url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth",
        transforms=OpticalFlow,
        meta={
            **_COMMON_META,
            "num_params": 5257536,
            "recipe": "https://github.com/princeton-vl/RAFT",
            "_metrics": {
                "Kitti-Test": {"fl_all": 5.10},
            },
            "_docs": """
                These weights were ported from the original paper. They are
                pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
                :class:`~torchvision.datasets.FlyingThings3D`,
                fine-tuned on Sintel, and then fine-tuned on
                :class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning
                step was described above.
            """,
        },
    )

    C_T_SKHT_K_V2 = Weights(
        url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth",
        transforms=OpticalFlow,
        meta={
            **_COMMON_META,
            "num_params": 5257536,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
            "_metrics": {
                "Kitti-Test": {"fl_all": 5.19},
            },
            "_docs": """
                These weights were trained from scratch. They are
                pre-trained on :class:`~torchvision.datasets.FlyingChairs` +
                :class:`~torchvision.datasets.FlyingThings3D`,
                fine-tuned on Sintel, and then fine-tuned on
                :class:`~torchvision.datasets.KittiFlow`. The Sintel fine-tuning
                step was described above.
            """,
        },
    )

    DEFAULT = C_T_SKHT_V2


class Raft_Small_Weights(WeightsEnum):
    """The metrics reported here are as follows.

    ``epe`` is the "end-point-error" and indicates how far (in pixels) the
    predicted flow is from its true value. This is averaged over all pixels
    of all images. ``per_image_epe`` is similar, but the average is different:
    the epe is first computed on each image independently, and then averaged
    over all images. This corresponds to "Fl-epe" (sometimes written "F1-epe")
    in the original paper, and it's only used on Kitti. ``fl-all`` is also a
    Kitti-specific metric, defined by the author of the dataset and used for the
    Kitti leaderboard. It corresponds to the average of pixels whose epe is
    either <3px, or <5% of flow's 2-norm.
    """

    C_T_V1 = Weights(
        url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth",
        transforms=OpticalFlow,
        meta={
            **_COMMON_META,
            "num_params": 990162,
            "recipe": "https://github.com/princeton-vl/RAFT",
            "_metrics": {
                "Sintel-Train-Cleanpass": {"epe": 2.1231},
                "Sintel-Train-Finalpass": {"epe": 3.2790},
                "Kitti-Train": {"per_image_epe": 7.6557, "fl_all": 25.2801},
            },
            "_docs": """These weights were ported from the original paper. They
            are trained on :class:`~torchvision.datasets.FlyingChairs` +
            :class:`~torchvision.datasets.FlyingThings3D`.""",
        },
    )

    C_T_V2 = Weights(
        url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth",
        transforms=OpticalFlow,
        meta={
            **_COMMON_META,
            "num_params": 990162,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/optical_flow",
            "_metrics": {
                "Sintel-Train-Cleanpass": {"epe": 1.9901},
                "Sintel-Train-Finalpass": {"epe": 3.2831},
                "Kitti-Train": {"per_image_epe": 7.5978, "fl_all": 25.2369},
            },
            "_docs": """These weights were trained from scratch on
            :class:`~torchvision.datasets.FlyingChairs` +
            :class:`~torchvision.datasets.FlyingThings3D`.""",
        },
    )

    DEFAULT = C_T_V2


def _raft(
    *,
    weights=None,
    progress=False,
    # Feature encoder
    feature_encoder_layers,
    feature_encoder_block,
    feature_encoder_norm_layer,
    # Context encoder
    context_encoder_layers,
    context_encoder_block,
    context_encoder_norm_layer,
    # Correlation block
    corr_block_num_levels,
    corr_block_radius,
    # Motion encoder
    motion_encoder_corr_layers,
    motion_encoder_flow_layers,
    motion_encoder_out_channels,
    # Recurrent block
    recurrent_block_hidden_state_size,
    recurrent_block_kernel_size,
    recurrent_block_padding,
    # Flow head
    flow_head_hidden_size,
    # Mask predictor
    use_mask_predictor,
    **kwargs,
):
    feature_encoder = kwargs.pop("feature_encoder", None) or FeatureEncoder(
        block=feature_encoder_block, layers=feature_encoder_layers, norm_layer=feature_encoder_norm_layer
    )
    context_encoder = kwargs.pop("context_encoder", None) or FeatureEncoder(
        block=context_encoder_block, layers=context_encoder_layers, norm_layer=context_encoder_norm_layer
    )

    corr_block = kwargs.pop("corr_block", None) or CorrBlock(
        num_levels=corr_block_num_levels, radius=corr_block_radius
    )

    update_block = kwargs.pop("update_block", None)
    if update_block is None:
        motion_encoder = MotionEncoder(
            in_channels_corr=corr_block.out_channels,
            corr_layers=motion_encoder_corr_layers,
            flow_layers=motion_encoder_flow_layers,
            out_channels=motion_encoder_out_channels,
        )

        # See comments in the forward pass of the RAFT class about what the different inputs are
        out_channels_context = context_encoder_layers[-1] - recurrent_block_hidden_state_size
        recurrent_block = RecurrentBlock(
            input_size=motion_encoder.out_channels + out_channels_context,
            hidden_size=recurrent_block_hidden_state_size,
            kernel_size=recurrent_block_kernel_size,
            padding=recurrent_block_padding,
        )

        flow_head = FlowHead(in_channels=recurrent_block_hidden_state_size, hidden_size=flow_head_hidden_size)

        update_block = UpdateBlock(motion_encoder=motion_encoder, recurrent_block=recurrent_block, flow_head=flow_head)

    mask_predictor = kwargs.pop("mask_predictor", None)
    if mask_predictor is None and use_mask_predictor:
        mask_predictor = MaskPredictor(
            in_channels=recurrent_block_hidden_state_size,
            hidden_size=256,
            multiplier=0.25,  # see comment in MaskPredictor about this
        )

    model = RAFT(
        feature_encoder=feature_encoder,
        context_encoder=context_encoder,
        corr_block=corr_block,
        update_block=update_block,
        mask_predictor=mask_predictor,
        **kwargs,  # not really needed, all params should be consumed by now
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model

@handle_legacy_interface(weights=("pretrained", Raft_Large_Weights.C_T_SKHT_V2))
def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs) -> RAFT:
    """RAFT model from
    `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.

    Please see the example below for a tutorial on how to use this model.

    Args:
        weights (:class:`~torchvision.models.optical_flow.Raft_Large_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.optical_flow.Raft_Large_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.optical_flow.Raft_Large_Weights
        :members:
    """
    weights = Raft_Large_Weights.verify(weights)

    return _raft(
        weights=weights,
        progress=progress,
        # Feature encoder
        feature_encoder_layers=(64, 64, 96, 128, 256),
        feature_encoder_block=ResidualBlock,
        feature_encoder_norm_layer=InstanceNorm2d,
        # Context encoder
        context_encoder_layers=(64, 64, 96, 128, 256),
        context_encoder_block=ResidualBlock,
        context_encoder_norm_layer=BatchNorm2d,
        # Correlation block
        corr_block_num_levels=4,
        corr_block_radius=4,
        # Motion encoder
        motion_encoder_corr_layers=(256, 192),
        motion_encoder_flow_layers=(128, 64),
        motion_encoder_out_channels=128,
        # Recurrent block
        recurrent_block_hidden_state_size=128,
        recurrent_block_kernel_size=((1, 5), (5, 1)),
        recurrent_block_padding=((0, 2), (2, 0)),
        # Flow head
        flow_head_hidden_size=256,
        # Mask predictor
        use_mask_predictor=True,
        **kwargs,
    )

@handle_legacy_interface(weights=("pretrained", Raft_Small_Weights.C_T_V2))
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs) -> RAFT:
    """RAFT "small" model from
    `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`__.

    Please see the example below for a tutorial on how to use this model.

    Args:
        weights (:class:`~torchvision.models.optical_flow.Raft_Small_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.optical_flow.Raft_Small_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.optical_flow.RAFT``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/optical_flow/raft.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.optical_flow.Raft_Small_Weights
        :members:
    """
    weights = Raft_Small_Weights.verify(weights)

    return _raft(
        weights=weights,
        progress=progress,
        # Feature encoder
        feature_encoder_layers=(32, 32, 64, 96, 128),
        feature_encoder_block=BottleneckBlock,
        feature_encoder_norm_layer=InstanceNorm2d,
        # Context encoder
        context_encoder_layers=(32, 32, 64, 96, 160),
        context_encoder_block=BottleneckBlock,
        context_encoder_norm_layer=None,
        # Correlation block
        corr_block_num_levels=4,
        corr_block_radius=3,
        # Motion encoder
        motion_encoder_corr_layers=(96,),
        motion_encoder_flow_layers=(64, 32),
        motion_encoder_out_channels=82,
        # Recurrent block
        recurrent_block_hidden_state_size=96,
        recurrent_block_kernel_size=(3,),
        recurrent_block_padding=(1,),
        # Flow head
        flow_head_hidden_size=128,
        # Mask predictor
        use_mask_predictor=False,
        **kwargs,
    )