Use Hub to publish and load pre-trained models#
With the help of the hub module provided by MegEngine, researchers can easily:

- Publish their own open-source pre-trained models on GitHub or GitLab by adding a hubconf.py file;
- Load pre-trained models released by other researchers through the hub.load interface, which facilitates the reproduction of research results;
- Use a loaded pre-trained model for fine-tuning in transfer learning, or directly for prediction.
This section will use the ResNet series model as an example to show the model publishing and loading process.
Note
The “pre-trained model” here includes both the definition of the model and its pre-trained weights.
Compared to using Module.load_state_dict and megengine.load to deserialize and load the model’s state dictionary, megengine.hub.load additionally loads the model definition for the user beforehand (according to hubconf.py). The loaded model can also be exported as a serialized model file (Dump) for high-performance inference deployment scenarios.
Note
The Hub-related functions can also be used with an internal Git server; the corresponding parameters need to be configured when calling the related interfaces.
See also
The official MegEngine website provides the Model Center <https://megengine.org.cn/model-hub> section, which is based on the leading deep learning algorithms of the Megvii Research Institute and provides pre-trained models covering multiple business scenarios. In fact, a hubconf.py configuration has already been added to the official model library Models. If you wish to publish your research models to the official MegEngine Model Center, please refer to the README file of the Hub repository.
Publish pretrained models#
In the hubconf.py
file, at least one entry point (Entrypoint) needs to be provided, in the form of:
def entrypoint_name(*args, **kwargs):
    """Returns a model."""
    ...
- When called, an entry point usually returns a model (M.Module), or any other object you want to load through the Hub;
- The *args and **kwargs arguments will be passed to the real callable when the model is loaded;
- The docstring of the entry point is displayed when the hub.help interface is called.
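The discovery mechanism behind these entry points can be sketched in plain Python. The snippet below is only an illustration of the idea, not MegEngine's actual implementation: a fake hubconf module is built in memory, its public callables are listed (as hub.list does for a repository), and one of them is invoked with forwarded keyword arguments (as hub.load does).

```python
import types

# Build a stand-in for a repository's hubconf.py module in memory.
# (Illustration only: the real MegEngine hub fetches the repository
# and imports its actual hubconf.py file.)
hubconf = types.ModuleType("hubconf")

def _resnet18(**kwargs):
    """ResNet-18 model"""
    return ("resnet18-model", kwargs)  # placeholder for a real M.Module

hubconf.resnet18 = _resnet18

def list_entrypoints(module):
    # hub.list-style discovery: every public callable in the module
    # counts as an entry point.
    return sorted(
        name for name in dir(module)
        if not name.startswith("_") and callable(getattr(module, name))
    )

def load_entrypoint(module, name, *args, **kwargs):
    # hub.load-style dispatch: *args and **kwargs are forwarded
    # to the chosen entry point.
    return getattr(module, name)(*args, **kwargs)

print(list_entrypoints(hubconf))   # ['resnet18']
print(hubconf.resnet18.__doc__)    # the docstring hub.help would display
print(load_entrypoint(hubconf, "resnet18", num_classes=10))
```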
Provide an entry point#
Take the official ResNet model as an example, the model definition file is in official/vision/classification/resnet/model.py.
Models
├── official/vision/classification/resnet
│ └── model.py
└── hubconf.py
We can implement an entry point in hubconf.py like this:
from official.vision.classification.resnet.model import BasicBlock, Bottleneck, ResNet
def resnet18(**kwargs):
    """Resnet18 model"""
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
However, since ResNet’s model.py has already defined resnet18, resnet34, resnet50, resnet101, resnet152 and other common network structures in the Hub style, in practice hubconf.py only needs to import them:
from official.vision.classification.resnet.model import (
...
resnet18,
...
)
Provide pre-trained weights#
Identify the URL of the pre-trained weights by adding the hub.pretrained decorator to the entry point:
@hub.pretrained("https://url/to/pretrained_resnet18.pkl")
def resnet18(pretrained=False, **kwargs):
    """Resnet18 model"""
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
When the decorated function is called with the argument pretrained=True, it will automatically download the pre-trained weights and fill the returned model with them. The pre-trained weights may live in the Git repository itself, but for open-source projects on GitHub/GitLab you need to weigh the overall size of the weights against users’ download conditions. Judge according to the actual situation: either publish the pre-trained weights together with the model, or host them elsewhere (such as a network disk, OSS, etc.).
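How such a decorator might work can be sketched in plain Python. Everything below is a simplified stand-in: the URL-keyed dictionary replaces a real download, and a plain dict replaces an M.Module; the real hub.pretrained decorator downloads the .pkl file from the URL and fills the model's state dict.

```python
import functools

# Simulated weight storage keyed by URL.  (Illustration only: the real
# decorator downloads the .pkl file and calls load_state_dict on the model.)
_FAKE_WEIGHT_STORE = {
    "https://url/to/pretrained_resnet18.pkl": {"conv1.weight": [0.1, 0.2]},
}

def pretrained(url):
    """A sketch of a hub.pretrained-style decorator (not MegEngine's code)."""
    def decorator(build_model):
        @functools.wraps(build_model)
        def wrapper(pretrained=False, **kwargs):
            model = build_model(pretrained=pretrained, **kwargs)
            if pretrained:
                # MegEngine would download the file here and fill the
                # real M.Module via model.load_state_dict(...).
                model["state_dict"] = _FAKE_WEIGHT_STORE[url]
            return model
        return wrapper
    return decorator

@pretrained("https://url/to/pretrained_resnet18.pkl")
def resnet18(pretrained=False, **kwargs):
    """ResNet-18 model (a plain dict stands in for an M.Module here)."""
    return {"arch": "resnet18", "state_dict": None, **kwargs}

print(resnet18()["state_dict"])                 # None: weights not loaded
print(resnet18(pretrained=True)["state_dict"])  # the "downloaded" weights
```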
Load the pre-trained model#
The pre-trained model entry points provided in hubconf.py in a specified GitHub repository can be listed through the hub.list interface. For example, run the following command to find all pre-trained models published in the MegEngine/Models repository on GitHub:
>>> megengine.hub.list("megengine/models")
['ATSS',
'BasicBlock',
# ...
'resnet18',
# ...
'wwm_cased_L_24_H_1024_A_16',
'wwm_uncased_L_24_H_1024_A_16']
Assuming what we need is the resnet18 pre-trained model, the hub.help interface can be used to view the docstring of the corresponding entry point:
>>> megengine.hub.help("megengine/models", "resnet18")
'ResNet-18 model...'
A single call to the hub.load interface completes the loading of the corresponding pre-trained model:
>>> model = megengine.hub.load('megengine/models', 'resnet18', pretrained=True)
>>> model.eval()
Warning
Before inference, remember to call model.eval()
to switch the model to evaluation mode.
Note
By default, files such as hubconf.py will be pulled from the master branch of the corresponding GitHub repository.

- A branch name (or tag name) can be specified in the form megengine/models:dev;
- A specific Git server can be used by setting the git_host parameter;
- A specific commit can be used by setting the commit parameter;
- The protocol used to fetch the code repository can be chosen by setting the protocol parameter.

Normally no additional setup is required, and a public GitHub repository can be cloned with the HTTPS protocol. If you configure these specifically (such as using an internal Git server), make sure you have access to the corresponding code repository.
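The repository string format described above (owner/repo, optionally followed by a branch or tag name) could be parsed by a small helper like this. The function is hypothetical, shown only to illustrate the format; it is not MegEngine's actual parsing code.

```python
def parse_repo_spec(spec, default_branch="master"):
    # Hypothetical helper illustrating the "owner/repo[:branch-or-tag]"
    # string format accepted by hub.list / hub.load.
    repo, _, branch = spec.partition(":")
    owner, _, name = repo.partition("/")
    return owner, name, branch or default_branch

print(parse_repo_spec("megengine/models"))      # ('megengine', 'models', 'master')
print(parse_repo_spec("megengine/models:dev"))  # ('megengine', 'models', 'dev')
```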