# Source code for megengine.amp.convert_format
from copy import deepcopy
from .. import functional as F
from ..core import _config
from ..module import Module
from ..tensor import Tensor
def _is_nchw_format(param: Tensor):
# TODO: use better condition
return (param.ndim == 4 or param.ndim == 5) and param.format != "nhwc"
def convert_tensor_format(x: Tensor, inplace: bool = True):
    """Convert NCHW Tensor to NHWC Tensor.

    Tensors rejected by ``_is_nchw_format`` (rank other than 4/5, or whose
    ``format`` is already ``"nhwc"``) are returned unchanged; no error is
    raised for them.

    Args:
        x: the tensor to convert.
        inplace: if True, overwrite ``x`` via ``x[...] = ...`` and return
            ``x`` itself; otherwise leave ``x`` untouched and return a newly
            constructed tensor.

    Returns:
        ``x`` (unchanged, or overwritten in place), or a new ``"nhwc"``
        tensor when ``inplace`` is False and a conversion happened.
    """
    # Past this guard, _is_nchw_format guarantees ndim in (4, 5) and
    # format != "nhwc", so no separate rank/format validation is needed
    # (the former ValueError branch here could never be reached).
    if not _is_nchw_format(x):
        return x

    # NOTE(review): per the original note, the host value should still be
    # valid here, so this should not incur a device-to-host copy -- confirm.
    data = x.numpy()
    converted = Tensor(data, format="nhwc")
    if inplace:
        # Overwrites x's contents in place; per the original note, this
        # reset destroys any existing backward grad.
        x[...] = converted
        return x
    # NOTE(review): this tensor is rebuilt from host data, so it presumably
    # has no autodiff link to ``x`` (an earlier comment claimed this path
    # "maintains grad") -- confirm against MegEngine's Tensor semantics.
    return converted
def convert_module_format(module: Module, inplace: bool = True):
    """Convert NCHW Module to NHWC Module.

    Every tensor yielded by ``named_tensors()`` is passed through
    ``convert_tensor_format`` with ``inplace=True``.

    Args:
        module: the module whose tensors are converted.
        inplace: if True, ``module`` itself is modified; otherwise a
            ``deepcopy`` of it is converted and ``module`` is left untouched.

    Returns:
        ``module`` when ``inplace`` is True, otherwise the converted copy.
    """
    target = module if inplace else deepcopy(module)
    # The return value of convert_tensor_format is discarded: with
    # inplace=True the conversion is applied to each tensor object itself.
    for _name, tensor in target.named_tensors():
        convert_tensor_format(tensor, inplace=True)
    return target