Source code for megengine.module.batch_matmul_activation

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

from ..functional import matmul, relu
from ..tensor import Parameter
from . import init
from .module import Module


class BatchMatMulActivation(Module):
    r"""Batched MatMul with activation (only relu supported), no transpose
    anywhere.

    :param batch: number of independent matmul slices stacked in the weight.
    :param in_features: size of each input sample.
    :param out_features: size of each output sample.
    :param bias: whether to add a learnable bias. Default: True.
    :param nonlinear_mode: activation applied after the matmul. Only
        ``"RELU"`` triggers an activation; any other value (including the
        default ``"IDENTITY"``) leaves the result unchanged.
    """

    def __init__(
        self,
        batch: int,
        in_features: int,
        out_features: int,
        bias: bool = True,
        nonlinear_mode="IDENTITY",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.batch = batch
        self.out_features = out_features
        self.in_features = in_features
        # One (out_features, in_features) weight matrix per batch slice.
        self.weight = Parameter(
            np.zeros((batch, out_features, in_features), dtype=np.float32)
        )
        self.bias = None
        if bias:
            self.bias = Parameter(np.zeros((out_features,), dtype=np.float32))
        self.nonlinear_mode = nonlinear_mode
        self.reset_parameters()

    def _get_fanin(self):
        # Fan-in of each output unit; used to scale the init std below.
        return self.in_features

    def reset_parameters(self) -> None:
        """Re-initialize weight ~ N(0, 1/fanin) and zero the bias."""
        std = np.sqrt(1 / self._get_fanin())
        init.normal_(self.weight, 0.0, std)
        if self.bias is not None:
            init.zeros_(self.bias)

    def _calc_linear(self, x, weight, bias):
        # weight @ x — the input is multiplied on the right, no transpose.
        out = matmul(weight, x)
        if self.bias is not None:
            # NOTE(review): bias has shape (out_features,), so broadcasting
            # aligns it with the LAST axis of `out`, not the out_features
            # axis — confirm callers' input layout makes this correct.
            out += bias
        if self.nonlinear_mode == "RELU":
            out = relu(out)
        return out

    def forward(self, x):
        return self._calc_linear(x, self.weight, self.bias)

    def _module_info_string(self) -> str:
        return "batch={}, in_features={}, out_features={}, bias={}".format(
            self.batch, self.in_features, self.out_features, self.bias is not None
        )