megengine.module.GroupNorm.load_state_dict¶
- GroupNorm.load_state_dict(state_dict, strict=True)¶
向当前模块中加载由
state_dict
创建的给定字典。若strict
为True
,state_dict
的键则必须与state_dict
返回的键准确匹配。为了处理复杂情况,用户可以传入闭包 Function[key: str, var: Tensor] -> Optional[np.ndarray] 作为 state_dict 。例如,欲加载除了最后线性分类器外的所有部分:
state_dict = {...} # Dict[str, np.ndarray] model.load_state_dict({ k: None if k.startswith('fc') else v for k, v in state_dict.items() }, strict=False)
这里返回
None
意味着忽略参数 k 。为了防止形状不匹配(例如加载PyTorch权重),我们可以在加载之前重塑(reshape):
state_dict = {...} def reshape_accordingly(k, v): return state_dict[k].reshape(v.shape) model.load_state_dict(reshape_accordingly)
我们还可以进行原位重初始化或修剪(pruning):
def reinit_and_pruning(k, v): if 'bias' in k: M.init.zero_(v) if 'conv' in k: