Save and Load Models (S&L)#
During model development, we often need to save and load models, for example:
To keep an unexpected interruption from wiping out training progress, make a habit of saving the model at regular intervals, for example every few epochs;
If training runs too long, the model may overfit the training data set, so it helps to save multiple checkpoints and pick the best one;
In some cases, we need to load the parameters and other required information of a pre-trained model in order to resume training or fine-tune it.
MegEngine wraps Python's built-in pickle module to implement binary serialization and deserialization of Python object structures (such as Module objects). The core interfaces to know are megengine.save and megengine.load:
>>> megengine.save(model, PATH)
>>> model = megengine.load(PATH)
The syntax above saves and loads the entire model concisely and intuitively, but it is not recommended. The recommended approach is to save and load state_dict objects, or to use checkpoints. The following sections explain each approach in more detail and provide best practices for saving and loading models in several scenarios. You can skip the concepts you already know and jump straight to the code for your use case.
Method | Use case
save/load entire model | Not recommended under any circumstances ❌
save/load model state dictionary | Suitable for inference ✅; does not meet the requirements for resuming training 😅
save/load checkpoint | Suitable for inference or resuming training 💡
Export static graph models (Dump) | Suitable for inference; aims at high-performance deployment 🚀
Note
When using the pickle module, the corresponding operations are also called pickling and unpickling.
pickle module protocol compatibility
The data stream protocol used by the pickle module may differ between Python versions, so a MegEngine model saved with a newer version of Python may fail to load in an older one. There are two solutions:
When calling megengine.save, specify a more widely supported protocol version (such as version 4) through the pickle_protocol parameter;
Both megengine.save and megengine.load accept a pickle_module parameter that selects the pickle module to use, for example installing pickle5 and using it in place of Python's built-in pickle module:
>>> import pickle5 as pickle
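To make the first option concrete, here is a minimal sketch that uses the standard-library pickle directly as a stand-in for megengine.save (it does not call MegEngine). The `state` dictionary is made-up sample data. The sketch pins the protocol to version 4 and reads the protocol number back from the stream header.

```python
import pickle

# Made-up sample data standing in for a model state dictionary.
state = {"fc1.bias": [0.0, 0.1], "epoch": 3}

# Pin the protocol explicitly instead of relying on the interpreter default.
payload = pickle.dumps(state, protocol=4)

# Streams written with protocol 2 or newer begin with the PROTO opcode (0x80),
# followed by one byte holding the protocol number.
written_protocol = payload[1]

restored = pickle.loads(payload)
```

With MegEngine, the equivalent choice would be made through the pickle_protocol parameter of megengine.save described above.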
The pickle module is not safe!
A malicious person can construct pickle data that executes arbitrary code when it is unpickled;
Therefore, never unpickle data that comes from an untrusted source or that may have been tampered with.
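The following harmless sketch illustrates why unpickling can run code. It uses only the standard-library pickle, and the `Payload` class is a hypothetical name introduced for illustration. Its `__reduce__` method tells pickle to store a function call rather than an object, so loading the bytes calls that function. Here the function is the built-in `sorted`; an attacker would choose something harmful instead.

```python
import pickle


class Payload:
    """An object whose pickled form asks the loader to call a function."""

    def __reduce__(self):
        # pickle stores "call sorted([3, 1, 2])" instead of a Payload object.
        return (sorted, ([3, 1, 2],))


data = pickle.dumps(Payload())

# Loading runs the callable chosen by whoever produced the bytes.
result = pickle.loads(data)
```

Note that `result` is the return value of the call, not a `Payload` instance: the loader did whatever the data told it to do.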
Below is the ConvNet model used in the examples:
import megengine
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim

class ConvNet(M.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = M.Conv2d(1, 10, 5)
        self.pool1 = M.MaxPool2d(2, 2)
        self.conv2 = M.Conv2d(10, 20, 5)
        self.pool2 = M.MaxPool2d(2, 2)
        self.fc1 = M.Linear(320, 50)
        self.fc2 = M.Linear(50, 10)

    def forward(self, input):
        x = self.pool1(F.relu(self.conv1(input)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = F.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x

model = ConvNet()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.1)
save/load entire model#
save:
>>> megengine.save(model, PATH)
load:
>>> model = megengine.load(PATH)
>>> model.eval()
Note
We do not recommend this method because of a limitation of pickle itself. For a specific class, such as a user-defined ConvNet model class, pickle does not serialize the model class itself when saving the model. Instead, it binds the class to the path of the source file that contains its definition, such as project/model.py. pickle needs this path when loading the model. So if you refactor the project later in development (for example, rename model.py), loading the model will fail.
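This behavior can be observed with the standard library alone. In the sketch below, `fractions.Fraction` stands in for a user-defined model class such as ConvNet (an assumption made so that the example is self-contained). The pickled bytes contain the module name and the class name, but no class source code.

```python
import pickle
from fractions import Fraction

# Fraction stands in for a user-defined model class such as ConvNet.
data = pickle.dumps(Fraction(1, 3))

# The stream names the defining module and the class; it holds no class source code.
has_module_name = b"fractions" in data
has_class_name = b"Fraction" in data

# Loading works only because fractions.Fraction can still be imported here.
restored = pickle.loads(data)
```

If the module were renamed or moved after the data was written, the loader could no longer resolve that reference.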
Warning
If you still use this method to load a model for inference, remember to call model.eval() first to switch to evaluation mode.
save/load model state dictionary#
save:
>>> megengine.save(model.state_dict(), PATH)
load:
>>> model = ConvNet()
>>> model.load_state_dict(megengine.load(PATH))
>>> model.eval()
When saving a model for inference only, all that needs to be saved is the model's learned parameters. Rather than saving the entire model, we recommend saving the model's state dictionary state_dict, which is more flexible when restoring the model later.
Warning
Unlike loading the entire model, megengine.load() here returns a state dictionary object, which must then be loaded into the model through the model.load_state_dict() method. model = megengine.load(PATH) cannot be used; the state dictionary must first be deserialized and then passed to the model.load_state_dict() method;
After loading the state dictionary successfully, remember to call model.eval() to switch the model to evaluation mode.
Note
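By convention, models are saved with the .pkl file extension, for example mge_checkpoint_xxx.pkl.
Note the difference between .pkl and .mge files
An .mge file is usually the file obtained by running Export serialized model file (Dump) on a MegEngine model, and it is used for inference deployment.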
what is a state dictionary#
Because saving or loading an entire model directly with pickle ties the model to source file paths, we instead use native Python data structures to record the model's internal state, which makes serialization and deserialization straightforward. In Module base class concept and interface introduction, we mentioned that each Module has a state dictionary member, which records the Tensor information inside the model (i.e. its Parameter and Buffer members):
>>> for tensor in model.state_dict():
...     print(tensor, "\t", model.state_dict()[tensor].shape)
conv1.bias (1, 10, 1, 1)
conv1.weight (10, 1, 5, 5)
conv2.bias (1, 20, 1, 1)
conv2.weight (20, 10, 5, 5)
fc1.bias (50,)
fc1.weight (50, 320)
fc2.bias (10,)
fc2.weight (10, 50)
The state dictionary is a plain Python dictionary object, so it can be saved and loaded easily with pickle.
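The round trip can be sketched without MegEngine. `TinyModel` below is a hypothetical stand-in for a Module: it exposes `state_dict()` and `load_state_dict()` with the same intent, but stores plain lists instead of Tensors. Because only a dictionary of built-in types is pickled, the serialized data never refers to the class.

```python
import pickle


class TinyModel:
    """Hypothetical stand-in for a Module; it holds two named parameters."""

    def __init__(self):
        self.params = {"fc.weight": [0.0, 0.0], "fc.bias": [0.0]}

    def state_dict(self):
        # Return plain data only: names mapped to copies of the values.
        return {name: list(value) for name, value in self.params.items()}

    def load_state_dict(self, state):
        for name in self.params:
            self.params[name] = list(state[name])


trained = TinyModel()
trained.params["fc.weight"] = [0.5, -1.25]

# Only a dict of lists is serialized, so the stream never refers to TinyModel.
blob = pickle.dumps(trained.state_dict())

# Restoring needs a freshly constructed model to receive the values.
fresh = TinyModel()
fresh.load_state_dict(pickle.loads(blob))
```

This is why the class can be renamed or moved between saving and loading, as long as the parameter names still match.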
Note
Each Optimizer also has a state dictionary, which contains the optimizer's state and the hyperparameters used. If you later need to restore the model and continue training, saving only the model's state dictionary is not enough: you also need to save the optimizer's state dictionary and similar information, which is the "checkpoint" technique described below.
See also
Further explanation about state dictionary: Module state dictionary / Optimizer state dictionary
save/load checkpoint#
save:
megengine.save({
    "epoch": epoch,
    "model_state_dict": model.state_dict(),  # same key that is read back when loading
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss,
    # ...
}, PATH)
load:
model = ConvNet()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.1)  # same arguments as during training
checkpoint = megengine.load(PATH)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]
model.eval()
# - or -
model.train()
The purpose of saving a checkpoint is to be able to return to the same state as during training, which requires restoring not only the Module state dictionary but also the Optimizer state dictionary. Depending on your needs, you can also record the epoch reached and the latest loss.
After the checkpoint is loaded, set the model to training or evaluation mode, depending on whether you want to continue training or use it for inference.
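The toy loop below sketches why the optimizer state belongs in the checkpoint. It uses the standard-library pickle in place of megengine.save and megengine.load, and the `weight` and `momentum_buffer` fields are hypothetical stand-ins for a model parameter and an optimizer state. A run that is interrupted and then resumed from its checkpoint ends in exactly the same state as an uninterrupted run.

```python
import os
import pickle
import tempfile


def train(total_epochs, path, stop_after=None):
    """Toy loop: resume from `path` if it exists, checkpoint after every epoch."""
    state = {"epoch": 0, "weight": 0.0, "momentum_buffer": 0.0}
    if os.path.exists(path):
        with open(path, "rb") as f:
            state = pickle.load(f)

    for epoch in range(state["epoch"], total_epochs):
        if stop_after is not None and epoch == stop_after:
            return state  # simulate an interruption
        # The update depends on the buffer carried over from earlier epochs.
        state["momentum_buffer"] = 0.5 * state["momentum_buffer"] + 1.0
        state["weight"] += 0.1 * state["momentum_buffer"]
        state["epoch"] = epoch + 1
        with open(path, "wb") as f:
            pickle.dump(state, f)
    return state


with tempfile.TemporaryDirectory() as tmp:
    interrupted_path = os.path.join(tmp, "interrupted.pkl")
    train(5, interrupted_path, stop_after=2)  # stops once 2 epochs are done
    resumed = train(5, interrupted_path)      # continues from epoch 2

    straight = train(5, os.path.join(tmp, "straight.pkl"))
```

If the checkpoint had stored only `weight`, the resumed run would restart with an empty momentum buffer and drift away from the uninterrupted one.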
Warning
Saving a full checkpoint takes more disk space than saving only the model's state dictionary, so you do not need to save checkpoints if you are sure you will only run inference in the future. Alternatively, use different saving frequencies, such as saving a state dictionary every 10 epochs and a full checkpoint every 100 epochs, depending on your needs.
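One possible way to express such a schedule is a small helper that decides what to write at each epoch. The function name `what_to_save` and its parameters are hypothetical, and the sketch assumes epochs are counted from 1.

```python
def what_to_save(epoch, state_every=10, checkpoint_every=100):
    """Return "checkpoint", "state_dict", or None for a 1-based epoch number."""
    if epoch % checkpoint_every == 0:
        return "checkpoint"
    if epoch % state_every == 0:
        return "state_dict"
    return None


schedule = {epoch: what_to_save(epoch) for epoch in (7, 10, 50, 100, 200)}
```

The full checkpoint is checked first, so an epoch that matches both intervals produces a checkpoint rather than a state dictionary.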
See also
Refer to how to save and load checkpoints in the official ResNet model:
official/vision/classification/resnet
The relevant interfaces can be found in train/test/inference.py.
Export static graph models#
To deploy the final trained model to a production environment, the last step of model development is to export a static graph:
import numpy as np
import megengine
import megengine.functional as F
from megengine import jit

model = ConvNet()
model.load_state_dict(megengine.load(PATH))
model.eval()

@jit.trace(symbolic=True, capture_as_const=True)
def infer_func(data, *, model):
    pred = model(data)
    pred_normalized = F.softmax(pred)
    return pred_normalized

# Run once on dummy input so the graph is traced before dumping
data = megengine.Tensor(np.random.randn(1, 1, 28, 28))
output = infer_func(data, model=model)

infer_func.dump(PATH, arg_names=["data"])
See also
See: Export serialized model file (Dump) for a more specific explanation.