megengine.module.deformable_psroi_pooling 源代码

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from ..functional import deformable_psroi_pooling
from .module import Module


[文档]class DeformablePSROIPooling(Module): def __init__( self, no_trans, part_size, pooled_h, pooled_w, sample_per_part, spatial_scale, trans_std: float = 0.1, **kwargs ): super().__init__(**kwargs) self.no_trans = no_trans self.part_size = part_size self.pooled_h = pooled_h self.pooled_w = pooled_w self.sample_per_part = sample_per_part self.spatial_scale = spatial_scale self.trans_std = trans_std
[文档] def forward(self, inp, rois, trans): return deformable_psroi_pooling( inp, rois, trans, self.no_trans, self.part_size, self.pooled_h, self.pooled_w, self.sample_per_part, self.spatial_scale, self.trans_std, )