megengine.module.init.calculate_fan_in_and_fan_out

calculate_fan_in_and_fan_out(tensor)

Calculates the fan_in and fan_out values for the given weight tensor. This function assumes the input tensor is stored in NCHW format.
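The computation can be illustrated with a small pure-Python sketch that works on a shape tuple instead of a Tensor. It assumes that, for a weight in NCHW format, N is the number of output channels and C the number of input channels, and that the remaining axes form the kernel's receptive field. The helper name `fan_in_and_fan_out` is hypothetical and is not part of the MegEngine API.

```python
from functools import reduce
from operator import mul
from typing import Sequence, Tuple


def fan_in_and_fan_out(shape: Sequence[int]) -> Tuple[int, int]:
    """Hypothetical helper: fan_in/fan_out for a weight shaped (O, I, *kernel).

    This is a sketch of the documented behaviour, not MegEngine's own code.
    """
    if len(shape) < 2:
        raise ValueError("fan_in and fan_out need a shape with at least 2 dimensions")
    num_out, num_in = shape[0], shape[1]
    # Product of the kernel (spatial) axes; for a 2-D Linear weight this is 1.
    receptive_field = reduce(mul, shape[2:], 1)
    return num_in * receptive_field, num_out * receptive_field


# Linear weight (out_features=128, in_features=64)
print(fan_in_and_fan_out((128, 64)))      # (64, 128)
# Conv2d weight (O=64, I=3, K=3, K=3): fan_in = 3*9, fan_out = 64*9
print(fan_in_and_fan_out((64, 3, 3, 3)))  # (27, 576)
```

The sketch returns integers; the documented return type of the real function is Tuple[float, float].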

Note

The grouped conv2d kernel shape in MegEngine is (G, O/G, I/G, K, K). By default, this function calculates fan_out = O/G * K * K, whereas PyTorch uses fan_out = O * K * K.
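The difference can be made concrete with a minimal sketch. It assumes that a 5-D MegEngine kernel is handled by ignoring the leading G axis, which reproduces the fan_out = O/G * K * K value stated above, and that PyTorch stores the same grouped kernel as (O, I/G, K, K). The helper `fans_from_shape` is hypothetical and covers only conv2d kernels; it does not handle conv3d.

```python
from functools import reduce
from operator import mul
from typing import Sequence, Tuple


def fans_from_shape(shape: Sequence[int]) -> Tuple[int, int]:
    """Hypothetical helper: fan_in/fan_out for a (grouped) conv2d kernel shape."""
    if len(shape) == 5:
        # Assumption: a 5-D shape is a MegEngine grouped conv2d kernel
        # (G, O/G, I/G, K, K); the leading group axis G is not counted.
        shape = shape[1:]
    receptive_field = reduce(mul, shape[2:], 1)  # K * K
    return shape[1] * receptive_field, shape[0] * receptive_field


G, I, O, K = 4, 32, 64, 3
megengine_kernel = (G, O // G, I // G, K, K)  # (4, 16, 8, 3, 3)
pytorch_kernel = (O, I // G, K, K)            # (64, 8, 3, 3)

print(fans_from_shape(megengine_kernel))  # (72, 144): fan_out = O/G * K * K
print(fans_from_shape(pytorch_kernel))    # (72, 576): fan_out = O * K * K
```

In this sketch fan_in (I/G * K * K) agrees between the two layouts, while fan_out differs by a factor of G, so initializers that depend on fan_out (for example, fan_out-mode Kaiming initialization) would produce different scales for grouped convolutions.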

Parameters

tensor (Tensor) – weight tensor in NCHW format.

Return type

Tuple[float, float]