megengine.module.GroupNorm.load_state_dict

GroupNorm.load_state_dict(state_dict, strict=True)

向当前模块中加载由 state_dict 创建的给定字典。若 strictTruestate_dict 的键则必须与 state_dict 返回的键准确匹配。

为了处理复杂情况,用户可以传入闭包 Function[key: str, var: Tensor] -> Optional[np.ndarray] 作为 state_dict 。例如,欲加载除了最后线性分类器外的所有部分:

state_dict = {...}  #  Dict[str, np.ndarray]
model.load_state_dict({
    k: None if k.startswith('fc') else v
    for k, v in state_dict.items()
}, strict=False)

这里返回 None 意味着忽略参数 k

为了防止形状不匹配(例如加载PyTorch权重),我们可以在加载之前重塑(reshape):

state_dict = {...}
def reshape_accordingly(k, v):
    return state_dict[k].reshape(v.shape)
model.load_state_dict(reshape_accordingly)

我们还可以进行原位重初始化或修剪(pruning):

def reinit_and_pruning(k, v):
    if 'bias' in k:
        M.init.zero_(v)
    if 'conv' in k: