megengine.quantization.LSQ.load_state_dict

LSQ.load_state_dict(state_dict, strict=True)

Loads a dictionary created by state_dict into this module. If strict is True, the keys of state_dict must exactly match the keys returned by this module's state_dict().
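A minimal sketch of what "exactly match" means, modeled with plain Python sets rather than MegEngine's own code. The helper check_keys and its split into missing and unexpected keys are illustrative assumptions, not part of the MegEngine API:

```python
from typing import Iterable, Set, Tuple


def check_keys(
    module_keys: Iterable[str], given_keys: Iterable[str], strict: bool = True
) -> Tuple[Set[str], Set[str]]:
    """Toy model of strict key matching (hypothetical helper, not MegEngine code)."""
    expected = set(module_keys)
    given = set(given_keys)
    missing = expected - given      # the module has them, state_dict does not
    unexpected = given - expected   # state_dict has them, the module does not
    if strict and (missing or unexpected):
        raise KeyError(
            f"strict=True violated: missing={sorted(missing)}, "
            f"unexpected={sorted(unexpected)}"
        )
    return missing, unexpected
```

With strict=False the mismatches are merely reported, e.g. check_keys(["w", "b"], ["w", "extra"], strict=False) returns ({"b"}, {"extra"}).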

Users can also pass a closure of type Function[key: str, var: Tensor] -> Optional[np.ndarray] instead of a dict, in order to handle complex situations. For example, to load everything except the final linear classifier:

state_dict = {...}  #  Dict[str, np.ndarray]
model.load_state_dict({
    k: None if k.startswith('fc') else v
    for k, v in state_dict.items()
}, strict=False)

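Here a value of None (or, for a closure, a return value of None) means parameter k is skipped. strict=False is passed because, under strict=True, skipped parameters are treated as an error.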
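The dict form and the closure form can be read as one mechanism. The sketch below is a simplified, hypothetical model (toy_load, as_closure and the plain-dict params are not MegEngine APIs): it wraps a dict in a closure that looks each key up, and treats a None result as "leave this parameter unchanged".

```python
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union

# Hypothetical alias: (key, current value) -> new value, or None to skip
Closure = Callable[[str, Any], Optional[Any]]


def as_closure(state_dict: Union[Dict[str, Any], Closure]) -> Closure:
    """Normalize a dict or a callable into a single closure."""
    if isinstance(state_dict, dict):
        # A dict behaves like a closure that looks the key up; absent keys give None
        return lambda key, _var: state_dict.get(key)
    if callable(state_dict):
        return state_dict
    raise ValueError(f"expected a dict or a callable, got {type(state_dict)}")


def toy_load(
    params: Dict[str, Any], state_dict, strict: bool = True
) -> Tuple[Set[str], Set[str]]:
    """Toy loader over a plain dict of parameters (not MegEngine's implementation)."""
    closure = as_closure(state_dict)
    loaded: Set[str] = set()
    skipped: Set[str] = set()
    for key, var in list(params.items()):
        value = closure(key, var)
        if value is None:
            skipped.add(key)  # None: keep the current value
            continue
        params[key] = value
        loaded.add(key)
    if strict and skipped:
        raise KeyError(f"strict=True but parameters were skipped: {sorted(skipped)}")
    return loaded, skipped


params = {"conv.weight": 0.0, "fc.weight": 0.0, "fc.bias": 0.0}
source = {"conv.weight": 1.0, "fc.weight": 2.0, "fc.bias": 3.0}
loaded, skipped = toy_load(
    params,
    {k: None if k.startswith("fc") else v for k, v in source.items()},
    strict=False,
)
# loaded == {"conv.weight"}, skipped == {"fc.weight", "fc.bias"}
```

In this model, the dict comprehension from the example above and an equivalent closure such as lambda k, v: None if k.startswith("fc") else source[k] produce the same result.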

To avoid a shape mismatch (e.g. when loading PyTorch weights), we can reshape each array to the target tensor's shape before loading:

state_dict = {...}
def reshape_accordingly(k, v):
    # v is the module's current tensor; match the stored array to its shape
    return state_dict[k].reshape(v.shape)
model.load_state_dict(reshape_accordingly)
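To see what this closure does with concrete arrays, here is a small sketch that uses NumPy arrays as stand-ins for the module's tensors (only v.shape matters to the closure). The source dict, the key names and the (1, C, 1, 1) target bias layout are assumptions chosen for illustration:

```python
import numpy as np

# Hypothetical source weights, e.g. exported from PyTorch: bias stored as (C,)
source = {
    "conv.weight": np.ones((4, 3, 3, 3), dtype=np.float32),
    "conv.bias": np.arange(4, dtype=np.float32),
}

# Stand-ins for the module's own tensors; the closure only reads .shape.
# Assumption: the target bias layout is (1, C, 1, 1).
targets = {
    "conv.weight": np.zeros((4, 3, 3, 3), dtype=np.float32),
    "conv.bias": np.zeros((1, 4, 1, 1), dtype=np.float32),
}


def reshape_accordingly(k, v):
    # reshape changes only the layout; the element counts must already agree
    return source[k].reshape(v.shape)


reshaped = {k: reshape_accordingly(k, v) for k, v in targets.items()}
```

With a real model, the same function would be passed directly to model.load_state_dict. Note that reshape cannot repair a genuine mismatch: if the element counts differ, NumPy raises a ValueError.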

We can also perform in-place re-initialization or pruning:

import megengine.module as M

def reinit_and_pruning(k, v):
    # re-initialize every bias tensor in place
    if 'bias' in k:
        M.init.zeros_(v)
    if 'conv' in k: