Use Hub to publish and load pre-trained models

With the help of the hub module provided by MegEngine, researchers can easily:

  • Publish your own open source pretrained models on GitHub or GitLab by adding a hubconf.py file;

  • Load pre-trained models released by other researchers through the hub.load interface, which makes it easier to reproduce research results;

  • Use the loaded pre-trained model as a starting point for fine-tuning in transfer learning, or directly for prediction.

This section will use the ResNet series model as an example to show the model publishing and loading process.

Note

The “pre-trained model” here includes: 1. the definition of the model; 2. the pre-trained weights.

Compared with using Module.load_state_dict and megengine.load to deserialize and load the model’s state dictionary, megengine.hub.load also loads the model definition for the user beforehand (according to hubconf.py). The loaded model can then be used in Export serialized model file (Dump) for high-performance inference deployment scenarios.

Note

The related functions of the Hub can also work with an internal Git server; the corresponding parameters need to be configured when calling the related interfaces.

See also

The official MegEngine website provides a Model Center <https://megengine.org.cn/model-hub> section, which is based on the leading deep learning algorithms of the Megvii Research Institute and provides pre-trained models for multiple business scenarios. In fact, the hubconf.py configuration has already been added to the official model library, Models. If you wish to publish your research models to the official MegEngine Model Center, please refer to the README file of the Hub repository.

Publish pretrained models

In the hubconf.py file, at least one entry point (Entrypoint) needs to be provided, in the form of:

def entrypoint_name(*args, **kwargs):
    """Returns a model."""
    ...
  • When called, an entry point usually returns a model (M.Module), or any other object you want to load through the Hub;

  • The *args and **kwargs arguments will be passed on to the real callable when the model is loaded;

  • The entry point’s docstring is displayed when the hub.help interface is called.
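The behavior described above can be sketched in plain Python. Note this is an illustrative sketch, not MegEngine's actual implementation: an entry point is just a callable, a help utility surfaces its docstring, and a loader forwards the user's arguments to it (a dict stands in for the real M.Module here).

```python
def resnet18(*args, **kwargs):
    """ResNet-18 model; extra arguments are forwarded to the constructor."""
    # In a real hubconf.py this would return an M.Module instance;
    # a plain dict stands in for the model in this sketch.
    return {"arch": "resnet18", "args": args, "kwargs": kwargs}

def hub_help(entrypoint):
    """Mimic hub.help: return the entry point's docstring."""
    return entrypoint.__doc__

def hub_load(entrypoint, *args, **kwargs):
    """Mimic hub.load: forward *args / **kwargs to the entry point."""
    return entrypoint(*args, **kwargs)

print(hub_help(resnet18))
model = hub_load(resnet18, num_classes=10)
print(model["kwargs"])  # {'num_classes': 10}
```

The real hub.load additionally resolves the repository, fetches hubconf.py, and looks the entry point up by name; the argument forwarding shown here is the part the entry-point author relies on.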

Provide an entry point

Take the official ResNet model as an example: the model definition file is official/vision/classification/resnet/model.py.

Models
├── official/vision/classification/resnet
│   └── model.py
└── hubconf.py

We can implement an entry point in hubconf.py like this:

from official.vision.classification.resnet.model import BasicBlock, Bottleneck, ResNet

def resnet18(**kwargs):
    """Resnet18 model"""
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)

However, because ResNet’s model.py already defines resnet18, resnet34, resnet50, resnet101, resnet152 and other common network structures in the Hub style, the actual hubconf.py only needs to import them:

from official.vision.classification.resnet.model import (
     ...
     resnet18,
     ...
 )

Provide pretrained weights

Identify the URL address of the pretrained weights by adding the hub.pretrained decorator to the entry point:

@hub.pretrained("https://url/to/pretrained_resnet18.pkl")
def resnet18(pretrained=False, **kwargs):
    """Resnet18 model"""
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
  • When the decorated function is called with pretrained=True, it will automatically download the pretrained weights and fill them into the returned model;

  • Pre-trained weights can live in the Git repository itself. For open source projects on GitHub/GitLab, consider the overall size of the weights and users’ download conditions: depending on the actual situation, either publish the pre-trained weights together with the model, or host them elsewhere (such as a network disk, OSS, etc.).
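The mechanism can be sketched as an ordinary Python decorator. This is an illustrative sketch, not MegEngine's implementation: fetch_weights is a stand-in for the actual download-and-deserialize step, and a dict stands in for the M.Module.

```python
import functools

def fetch_weights(url):
    """Stand-in for downloading and deserializing a weights file."""
    return {"source": url}  # pretend this is the loaded state dict

def pretrained(url):
    """Sketch of a hub.pretrained-style decorator."""
    def decorator(entrypoint):
        @functools.wraps(entrypoint)
        def wrapper(pretrained=False, **kwargs):
            model = entrypoint(pretrained=pretrained, **kwargs)
            if pretrained:
                # A real implementation would fill the model's parameters,
                # e.g. via model.load_state_dict(...)
                model["weights"] = fetch_weights(url)
            return model
        return wrapper
    return decorator

@pretrained("https://url/to/pretrained_resnet18.pkl")
def resnet18(pretrained=False, **kwargs):
    """ResNet-18 model (dict stands in for the M.Module)."""
    return {"arch": "resnet18", **kwargs}

print(resnet18(pretrained=True)["weights"])
```

Calling resnet18() without pretrained=True returns the freshly initialized model untouched, which is why the entry point itself can ignore the pretrained argument.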

Load the pre-trained model

The pre-trained model entries provided in hubconf.py in a specified GitHub repository can be listed through the hub.list interface.

For example, run the following command to find all pre-trained models published in the MegEngine/Models repository on GitHub:

>>> megengine.hub.list("megengine/models")
['ATSS',
 'BasicBlock',
 # ...
 'resnet18',
 # ...
 'wwm_cased_L_24_H_1024_A_16',
 'wwm_uncased_L_24_H_1024_A_16']

Assuming what we need is the resnet18 pre-trained model, we can use the hub.help interface to view the docstring of the corresponding entry point:

>>> megengine.hub.help("megengine/models", "resnet18")
'ResNet-18 model...'

Then a single call to the hub.load interface completes the loading of the corresponding pre-trained model:

>>> model = megengine.hub.load('megengine/models', 'resnet18', pretrained=True)
>>> model.eval()

Warning

Before inference, remember to call model.eval() to switch the model to evaluation mode.

Note

By default, files such as hubconf.py will be pulled from the master branch of the corresponding GitHub repository.

  • A branch name (or tag name) can be specified in the form megengine/models:dev;

  • You can choose to use the specified Git server by setting the git_host parameter;

  • You can choose to use the specified commit position by setting the commit parameter;

  • By setting the protocol parameter, you can choose the protocol used to obtain the code repository.

Normally, no additional setup is required, and the repository can be cloned from a public GitHub repository over the HTTPS protocol. If you configure things differently (such as using an internal Git server), make sure you have access to the corresponding code repository.
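How these parameters fit together can be sketched as follows. This is a simplified illustration, not MegEngine's actual resolution logic: the parameter names mirror the interface described above, but resolve_repo and its parsing are assumptions made for this example.

```python
def resolve_repo(repo_info, git_host="github.com", protocol="https"):
    """Split 'owner/repo[:branch]' and compose a clone URL (illustrative only)."""
    if ":" in repo_info:
        path, branch = repo_info.split(":", 1)
    else:
        path, branch = repo_info, "master"  # master is the default branch
    url = f"{protocol}://{git_host}/{path}.git"
    return url, branch

print(resolve_repo("megengine/models"))
# e.g. a dev branch on an internal Git server:
print(resolve_repo("megengine/models:dev", git_host="git.example.com"))
```

The point of the sketch is that the branch rides along in the repo_info string, while git_host and protocol only change where and how the repository is fetched.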