# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import hashlib
import os
import sys
import types
from typing import Any, List
from urllib.parse import urlparse
from megengine.utils.http_download import download_from_url
from ..distributed import is_distributed
from ..logger import get_logger
from ..serialization import load as _mge_load_serialized
from .const import (
DEFAULT_CACHE_DIR,
DEFAULT_GIT_HOST,
DEFAULT_PROTOCOL,
ENV_MGE_HOME,
ENV_XDG_CACHE_HOME,
HTTP_READ_TIMEOUT,
HUBCONF,
HUBDEPENDENCY,
)
from .exceptions import InvalidProtocol
from .fetcher import GitHTTPSFetcher, GitSSHFetcher
from .tools import cd, check_module_exists, load_module
# Logger obtained for this module's name.
logger = get_logger(__name__)
# Maps each accepted ``protocol`` value to the fetcher class used to obtain a repo.
# ``_get_repo`` rejects any protocol that is not a key of this mapping.
PROTOCOLS = {
"HTTPS": GitHTTPSFetcher,
"SSH": GitSSHFetcher,
}
def _get_megengine_home() -> str:
    r"""Returns the MegEngine home directory.

    MGE_HOME setting complies with the XDG Base Directory Specification.
    The directory is resolved in the following order:

    1. the value of the environment variable named by ``ENV_MGE_HOME``;
    2. ``megengine`` under the directory given by the environment variable
       named by ``ENV_XDG_CACHE_HOME``;
    3. ``megengine`` under ``DEFAULT_CACHE_DIR``.

    An environment variable that is set but empty is treated as unset, as the
    XDG specification requires; otherwise an empty value would silently produce
    a path relative to the current working directory.

    Returns:
        the resolved directory, with a leading ``~`` expanded.
    """
    megengine_home = os.getenv(ENV_MGE_HOME)
    if not megengine_home:
        # ``or`` (rather than a getenv default) also covers the "set but empty" case.
        xdg_cache_home = os.getenv(ENV_XDG_CACHE_HOME) or DEFAULT_CACHE_DIR
        megengine_home = os.path.join(xdg_cache_home, "megengine")
    return os.path.expanduser(megengine_home)
def _get_repo(
    git_host: str,
    repo_info: str,
    use_cache: bool = False,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> str:
    r"""Fetches a repo into the hub cache directory and returns where it landed.

    Args:
        git_host: host address of git repo, forwarded to the fetcher.
        repo_info: repo description string, forwarded to the fetcher.
        use_cache: forwarded to the fetcher.
        commit: forwarded to the fetcher.
        protocol: key of ``PROTOCOLS`` selecting which fetcher class is used.

    Returns:
        the hub cache directory joined with the directory reported by the fetcher.

    Raises:
        InvalidProtocol: if ``protocol`` is not a key of ``PROTOCOLS``.
    """
    if protocol not in PROTOCOLS:
        supported = ", ".join(PROTOCOLS.keys())
        raise InvalidProtocol(
            "Invalid protocol, the value should be one of {}.".format(supported)
        )

    hub_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
    # NOTE(review): ``cd`` presumably switches the working directory for the
    # duration of the fetch; confirm against ``.tools.cd``.
    with cd(hub_dir):
        fetcher_cls = PROTOCOLS[protocol]
        fetched_dir = fetcher_cls.fetch(git_host, repo_info, use_cache, commit)
        return os.path.join(hub_dir, fetched_dir)
def _check_dependencies(module: types.ModuleType) -> None:
    r"""Ensures every dependency declared by a hub module can be found.

    The declaration is read from the module attribute named by ``HUBDEPENDENCY``;
    a missing or empty declaration means there is nothing to check.

    Args:
        module: the loaded hubconf module.

    Raises:
        RuntimeError: if ``check_module_exists`` reports any declared dependency
            as unavailable; the message lists all of them.
    """
    declared = getattr(module, HUBDEPENDENCY, None)
    if not declared:
        return

    unavailable = [dep for dep in declared if not check_module_exists(dep)]
    if unavailable:
        raise RuntimeError("Missing dependencies: {}".format(", ".join(unavailable)))
def _init_hub(
    repo_info: str,
    git_host: str,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
):
    r"""Imports hubmodule like python import.

    Args:
        repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
            tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
        git_host: host address of git repo. Eg: github.com
        use_cache: whether to use locally cached code or completely re-fetch.
        commit: commit id on github or gitlab.
        protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
            The value should be one of HTTPS, SSH.

    Returns:
        a python module.
    """
    cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
    os.makedirs(cache_dir, exist_ok=True)
    absolute_repo_dir = _get_repo(
        git_host, repo_info, use_cache=use_cache, commit=commit, protocol=protocol
    )
    # The repo directory is put first on sys.path only while hubconf is loaded,
    # so that imports made by hubconf can resolve against the repo.
    sys.path.insert(0, absolute_repo_dir)
    try:
        hubmodule = load_module(HUBCONF, os.path.join(absolute_repo_dir, HUBCONF))
    finally:
        # Always undo the temporary entry, even if loading hubconf raises;
        # otherwise a failed load would leave sys.path permanently modified.
        sys.path.remove(absolute_repo_dir)
    return hubmodule
# Public entry point for ``_init_hub``. ``functools.wraps`` copies the wrapped
# function's metadata (including its docstring) onto this wrapper, so the
# documentation lives on ``_init_hub`` rather than here.
@functools.wraps(_init_hub)
def import_module(*args, **kwargs):
    return _init_hub(*args, **kwargs)
def list(
    repo_info: str,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> List[str]:
    r"""Lists all entrypoints available in repo hubconf.

    An entrypoint is any attribute of the hubconf module that is callable and
    whose name does not start with a double underscore.

    Args:
        repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
            tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
        git_host: host address of git repo. Eg: github.com
        use_cache: whether to use locally cached code or completely re-fetch.
        commit: commit id on github or gitlab.
        protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
            The value should be one of HTTPS, SSH.

    Returns:
        all entrypoint names of the model.
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
    return [
        name
        for name in dir(hubmodule)
        if not name.startswith("__") and callable(getattr(hubmodule, name))
    ]
def load(
    repo_info: str,
    entry: str,
    *args,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
    **kwargs
) -> Any:
    r"""Loads model from github or gitlab repo, with pretrained weights.

    Args:
        repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
            tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
        entry: an entrypoint defined in hubconf.
        *args: positional arguments forwarded to the entrypoint.
        git_host: host address of git repo. Eg: github.com
        use_cache: whether to use locally cached code or completely re-fetch.
        commit: commit id on github or gitlab.
        protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
            The value should be one of HTTPS, SSH.
        **kwargs: keyword arguments forwarded to the entrypoint.

    Returns:
        a single model with corresponding pretrained weights.

    Raises:
        RuntimeError: if ``entry`` is not a callable attribute of hubconf, or if
            a dependency declared by hubconf is missing.
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)

    entry_func = getattr(hubmodule, entry, None)
    if not callable(entry_func):
        raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))

    # Dependencies are verified only after the entrypoint is known to exist,
    # and before it is invoked.
    _check_dependencies(hubmodule)

    return entry_func(*args, **kwargs)
def help(
    repo_info: str,
    entry: str,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> str:
    r"""This function returns docstring of entrypoint ``entry`` by following steps:

    1. Pull the repo code specified by git and repo_info.
    2. Load the entry defined in repo's hubconf.py
    3. Return docstring of function entry.

    Args:
        repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
            tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
        entry: an entrypoint defined in hubconf.py
        git_host: host address of git repo. Eg: github.com
        use_cache: whether to use locally cached code or completely re-fetch.
        commit: commit id on github or gitlab.
        protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
            The value should be one of HTTPS, SSH.

    Returns:
        docstring of entrypoint ``entry``; this is the entrypoint's ``__doc__``
        attribute, so it is ``None`` when the entrypoint has no docstring.

    Raises:
        RuntimeError: if ``entry`` is not a callable attribute of hubconf.
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)

    entry_func = getattr(hubmodule, entry, None)
    if not callable(entry_func):
        raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))

    return entry_func.__doc__
def load_serialized_obj_from_url(url: str, model_dir=None) -> Any:
    """Loads MegEngine serialized object from the given URL.

    If the object is already present in ``model_dir``, it's deserialized and
    returned. If no ``model_dir`` is specified, it will be ``MGE_HOME/serialized``.

    The cached file is named ``<digest>_<basename>``, where ``<basename>`` is the
    last component of the URL path and ``<digest>`` is the first 6 hex characters
    of the SHA-256 of the full URL.

    Args:
        url: url to serialized object.
        model_dir: dir to cache target serialized file.

    Returns:
        loaded object.
    """
    if model_dir is None:
        model_dir = os.path.join(_get_megengine_home(), "serialized")
    os.makedirs(model_dir, exist_ok=True)

    basename = os.path.basename(urlparse(url).path)

    # use hash as prefix to avoid filename conflict from different urls
    digest = hashlib.sha256(url.encode()).hexdigest()[:6]
    cached_file = os.path.join(model_dir, digest + "_" + basename)
    logger.info(
        "load_serialized_obj_from_url: download to or using cached %s", cached_file
    )

    # Any existing file at the cache path is reused as-is; it is not re-validated.
    if not os.path.exists(cached_file):
        if is_distributed():
            logger.warning(
                "Downloading serialized object in DISTRIBUTED mode\n"
                "    File may be downloaded multiple times. We recommend\n"
                "    users to download in single process first."
            )
        download_from_url(url, cached_file, HTTP_READ_TIMEOUT)

    # NOTE(review): the deserializer is presumably pickle-based, which would make
    # loading content from an untrusted URL unsafe - confirm and document for callers.
    return _mge_load_serialized(cached_file)
class pretrained:
    r"""Decorator which helps to download pretrained weights from the given url.

    For example, we can decorate a resnet18 function as follows

    .. code-block::

        @hub.pretrained("https://url/to/pretrained_resnet18.pkl")
        def resnet18(**kwargs):
            ...

    Returns:
        When decorated function is called with ``pretrained=True``, MegEngine will automatically
        download and fill the returned model with pretrained weights.
    """

    def __init__(self, url):
        # URL the weights are loaded from when ``pretrained`` is truthy.
        self.url = url

    def __call__(self, func):
        @functools.wraps(func)
        def pretrained_model_func(
            pretrained=False, **kwargs
        ):  # pylint: disable=redefined-outer-name
            # ``pretrained`` is consumed here; only the remaining keyword
            # arguments reach the decorated function.
            model = func(**kwargs)
            if pretrained:
                weights = load_serialized_obj_from_url(self.url)
                model.load_state_dict(weights)
            return model

        return pretrained_model_func
# Public API of this module. Note that ``list`` and ``help`` share their names
# with Python builtins, so they shadow those builtins inside this module.
__all__ = [
"list",
"load",
"help",
"load_serialized_obj_from_url",
"pretrained",
"import_module",
]