megengine.traced_module.TracedModule.load_state_dict
TracedModule.load_state_dict(state_dict, strict=True)
Loads a given dictionary created by state_dict() into this module. If strict is True, the keys of the state_dict argument must exactly match the keys returned by this module's state_dict().

Users can also pass a closure:
Function[key: str, var: Tensor] -> Optional[np.ndarray]
as a state_dict, in order to handle complex situations. For example, load everything except for the final linear classifier:

    state_dict = {...}  # Dict[str, np.ndarray]
    model.load_state_dict({
        k: None if k.startswith('fc') else v
        for k, v in state_dict.items()
    }, strict=False)

Here returning None means skipping parameter k.

To prevent shape mismatch (e.g. when loading PyTorch weights), we can reshape before loading:

    state_dict = {...}
    def reshape_accordingly(k, v):
        return state_dict[k].reshape(v.shape)
    model.load_state_dict(reshape_accordingly)

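The dict form and the closure form described above can be treated as one mechanism. The following is a minimal pure-Python sketch of that idea, not MegEngine's implementation: `load_state`, `Array`, and `Closure` are hypothetical names, plain lists stand in for np.ndarray, and the exact behaviour of strict (rejecting both skipped and unused keys) is an assumption made for illustration.

```python
from typing import Callable, Dict, List, Optional, Set, Union

Array = List[float]  # stand-in for np.ndarray in this sketch
Closure = Callable[[str, Array], Optional[Array]]


def load_state(
    params: Dict[str, Array],
    state_dict: Union[Dict[str, Optional[Array]], Closure],
    strict: bool = True,
) -> Set[str]:
    """Hypothetical loader: updates `params` in place, returns skipped keys."""
    unused: Set[str] = set()
    if callable(state_dict):
        closure = state_dict
    else:
        source = state_dict
        unused = set(source) - set(params)

        def closure(key: str, _current: Array) -> Optional[Array]:
            # Wrap the dict so both input forms share one code path.
            return source.get(key)

    skipped: Set[str] = set()
    for key, current in params.items():
        new_value = closure(key, current)
        if new_value is None:
            skipped.add(key)  # None means "leave this parameter untouched"
        else:
            params[key] = list(new_value)

    # Assumption: strict mode rejects both skipped and unused keys.
    if strict and (skipped or unused):
        raise KeyError(f"skipped={sorted(skipped)}, unused={sorted(unused)}")
    return skipped


params = {"conv.weight": [0.0, 0.0], "fc.weight": [0.0]}
checkpoint = {"conv.weight": [1.0, 2.0], "fc.weight": [9.0]}

# Dict form: map the classifier key to None so it is skipped.
skipped = load_state(
    params,
    {k: None if k.startswith("fc") else v for k, v in checkpoint.items()},
    strict=False,
)
```

In this sketch the closure receives the current value as its second argument, which is what lets a callback such as reshape_accordingly read the target shape before returning new data.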
We can also perform in-place re-initialization or pruning:

    def reinit_and_pruning(k, v):
        if 'bias' in k:
            M.init.zero_(v)
        if 'conv' in k: