# -*- coding: utf-8 -*-
import functools
import hashlib
import os
import sys
import types
from typing import Any, List
from urllib.parse import urlparse
from megengine.utils.http_download import download_from_url
from ..distributed import is_distributed
from ..logger import get_logger
from ..serialization import load as _mge_load_serialized
from .const import (
DEFAULT_CACHE_DIR,
DEFAULT_GIT_HOST,
DEFAULT_PROTOCOL,
ENV_MGE_HOME,
ENV_XDG_CACHE_HOME,
HUBCONF,
HUBDEPENDENCY,
)
from .exceptions import InvalidProtocol
from .fetcher import GitHTTPSFetcher, GitSSHFetcher
from .tools import cd, check_module_exists, load_module
logger = get_logger(__name__)

# Accepted values of the ``protocol`` argument, mapped to the fetcher class
# whose ``fetch`` method ``_get_repo`` uses to obtain the repository.
PROTOCOLS = dict(HTTPS=GitHTTPSFetcher, SSH=GitSSHFetcher)
def _get_megengine_home() -> str:
    r"""Return the MegEngine home directory (XDG Base Directory Specification style).

    The environment variable named by ``ENV_MGE_HOME`` wins when set. Otherwise
    the result is ``<cache>/megengine``, where ``<cache>`` is the value of the
    variable named by ``ENV_XDG_CACHE_HOME`` or, failing that, ``DEFAULT_CACHE_DIR``.
    A leading ``~`` in the chosen path is expanded.
    """
    xdg_cache_root = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    fallback_home = os.path.join(xdg_cache_root, "megengine")
    chosen_home = os.getenv(ENV_MGE_HOME, fallback_home)
    return os.path.expanduser(chosen_home)
def _get_repo(
    git_host: str,
    repo_info: str,
    use_cache: bool = False,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> str:
    r"""Fetch a hub repository into the MegEngine hub cache and return its path.

    Args:
        git_host: host address of the git repo, e.g. ``github.com``.
        repo_info: repository spec forwarded unchanged to the fetcher.
        use_cache: forwarded to the fetcher (whether locally cached code may be reused).
        commit: optional commit id forwarded to the fetcher.
        protocol: key into ``PROTOCOLS`` selecting the fetcher class.

    Returns:
        Path of the fetched repo: ``<MGE home>/hub`` joined with the directory
        name returned by the fetcher.

    Raises:
        InvalidProtocol: if ``protocol`` is not a key of ``PROTOCOLS``.
    """
    if protocol not in PROTOCOLS:
        raise InvalidProtocol(
            "Invalid protocol {!r}, the value should be one of {}.".format(
                protocol, ", ".join(PROTOCOLS.keys())
            )
        )
    fetcher = PROTOCOLS[protocol]
    cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
    # ``cd`` presumably chdirs into ``cache_dir``; make sure the directory exists
    # here instead of relying on every caller to have created it beforehand.
    os.makedirs(cache_dir, exist_ok=True)
    with cd(cache_dir):
        # The fetcher runs with ``cache_dir`` as the working directory, so the
        # returned ``repo_dir`` is joined onto it below.
        repo_dir = fetcher.fetch(git_host, repo_info, use_cache, commit)
    return os.path.join(cache_dir, repo_dir)
def _check_dependencies(module: types.ModuleType) -> None:
    """Ensure every dependency declared by a hubconf module can be imported.

    The dependency list is read from the attribute named ``HUBDEPENDENCY``; a
    missing or empty attribute means there is nothing to check.

    Raises:
        RuntimeError: listing every declared dependency for which
            ``check_module_exists`` returns a falsy value.
    """
    declared = getattr(module, HUBDEPENDENCY, None)
    if not declared:
        return
    absent = [dep for dep in declared if not check_module_exists(dep)]
    if absent:
        raise RuntimeError("Missing dependencies: {}".format(", ".join(absent)))
def _init_hub(
    repo_info: str,
    git_host: str,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
):
    r"""Imports hubmodule like python import.

    Args:
        repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
            tag/branch. The default branch is ``master`` if not specified. E.g.: ``"brain_sdk/MegBrain[:hub]"``
        git_host: host address of git repo. E.g.: github.com
        use_cache: whether to use locally cached code or completely re-fetch.
        commit: commit id on github or gitlab.
        protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
            The value should be one of HTTPS, SSH.

    Returns:
        a python module.
    """
    cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
    os.makedirs(cache_dir, exist_ok=True)
    absolute_repo_dir = _get_repo(
        git_host, repo_info, use_cache=use_cache, commit=commit, protocol=protocol
    )
    # The repo dir is on sys.path only while hubconf is being loaded (presumably
    # so hubconf can import sibling modules of the repo). Restore sys.path even
    # if loading fails, so a broken hubconf does not leak the entry.
    sys.path.insert(0, absolute_repo_dir)
    try:
        hubmodule = load_module(HUBCONF, os.path.join(absolute_repo_dir, HUBCONF))
    finally:
        try:
            sys.path.remove(absolute_repo_dir)
        except ValueError:
            # The hubconf code already removed the entry itself; nothing to undo.
            pass
    return hubmodule
@functools.wraps(_init_hub)
def import_module(*call_args, **call_kwargs):
    # Public alias of ``_init_hub``. ``functools.wraps`` copies its metadata
    # (``__doc__``, ``__name__``, ``__wrapped__``, ...) onto this wrapper, so
    # introspection shows the documentation of ``_init_hub``.
    return _init_hub(*call_args, **call_kwargs)
def list(
    repo_info: str,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> List[str]:
    r"""Lists all entrypoints available in repo hubconf.

    An entrypoint is any callable attribute of the loaded hubconf module whose
    name does not start with a double underscore.

    Args:
        repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
            tag/branch. The default branch is ``master`` if not specified. E.g.: ``"brain_sdk/MegBrain[:hub]"``
        git_host: host address of git repo. E.g.: github.com
        use_cache: whether to use locally cached code or completely re-fetch.
        commit: commit id on github or gitlab.
        protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
            The value should be one of HTTPS, SSH.

    Returns:
        all entrypoint names of the model, in ``dir()`` order.
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
    entrypoints = []
    for attr_name in dir(hubmodule):
        if attr_name.startswith("__"):
            # Skip dunder attributes such as __builtins__ / __loader__.
            continue
        if callable(getattr(hubmodule, attr_name)):
            entrypoints.append(attr_name)
    return entrypoints
def load(
    repo_info: str,
    entry: str,
    *args,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
    **kwargs
) -> Any:
    r"""Loads model from github or gitlab repo, with pretrained weights.

    Args:
        repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
            tag/branch. The default branch is ``master`` if not specified. E.g.: ``"brain_sdk/MegBrain[:hub]"``
        entry: an entrypoint defined in hubconf.
        *args: positional arguments forwarded to the entrypoint.
        git_host: host address of git repo. E.g.: github.com
        use_cache: whether to use locally cached code or completely re-fetch.
        commit: commit id on github or gitlab.
        protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
            The value should be one of HTTPS, SSH.
        **kwargs: keyword arguments forwarded to the entrypoint.

    Returns:
        whatever the entrypoint returns, e.g. a single model with corresponding pretrained weights.

    Raises:
        RuntimeError: if ``entry`` is not a callable in hubconf, or if a declared
            hub dependency is missing.
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
    entrypoint = getattr(hubmodule, entry, None)
    if not callable(entrypoint):
        raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))
    # Verify declared dependencies only once we know the entry exists.
    _check_dependencies(hubmodule)
    return entrypoint(*args, **kwargs)
def help(
    repo_info: str,
    entry: str,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> str:
    r"""Returns the docstring of entrypoint ``entry``.

    Steps:

    1. Pull the repo code specified by git and repo_info.
    2. Load the entry defined in repo's hubconf.py.
    3. Return the docstring of that entry.

    Args:
        repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
            tag/branch. The default branch is ``master`` if not specified. E.g.: ``"brain_sdk/MegBrain[:hub]"``
        entry: an entrypoint defined in hubconf.py
        git_host: host address of git repo. E.g.: github.com
        use_cache: whether to use locally cached code or completely re-fetch.
        commit: commit id on github or gitlab.
        protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
            The value should be one of HTTPS, SSH.

    Returns:
        docstring of entrypoint ``entry`` (``None`` if it has none).

    Raises:
        RuntimeError: if ``entry`` is not a callable in hubconf.
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
    entrypoint = getattr(hubmodule, entry, None)
    if not callable(entrypoint):
        raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))
    return entrypoint.__doc__
def load_serialized_obj_from_url(url: str, model_dir=None) -> Any:
    """Loads MegEngine serialized object from the given URL.

    If the object is already present in ``model_dir``, it's deserialized and
    returned. If no ``model_dir`` is specified, it will be ``MGE_HOME/serialized``.
    A leading ``~`` in ``model_dir`` is expanded to the user's home directory.

    The cache file name is ``<6 hex chars of sha256(url)>_<basename of URL path>``,
    so different URLs that share a basename do not overwrite each other.

    Args:
        url: url to serialized object.
        model_dir: dir to cache target serialized file.

    Returns:
        loaded object.
    """
    if model_dir is None:
        model_dir = os.path.join(_get_megengine_home(), "serialized")
    # Expand "~" so directory creation, the cache-existence check and the final
    # load all use the same real path; otherwise os.makedirs would create a
    # literal "~" directory and the cache would never be hit.
    model_dir = os.path.expanduser(model_dir)
    os.makedirs(model_dir, exist_ok=True)

    # use hash as prefix to avoid filename conflict from different urls
    basename = os.path.basename(urlparse(url).path)
    digest = hashlib.sha256(url.encode()).hexdigest()[:6]
    cached_file = os.path.join(model_dir, digest + "_" + basename)
    logger.info(
        "load_serialized_obj_from_url: download to or using cached %s", cached_file
    )
    if not os.path.exists(cached_file):
        if is_distributed():
            logger.warning(
                "Downloading serialized object in DISTRIBUTED mode\n"
                "    File may be downloaded multiple times. We recommend\n"
                "    users to download in single process first."
            )
        download_from_url(url, cached_file)
    # NOTE(review): MegEngine serialization is presumably pickle-based, so only
    # load from URLs you trust — confirm against megengine.serialization.load.
    return _mge_load_serialized(cached_file)
class pretrained:
    r"""Decorator which helps to download pretrained weights from the given url.

    The weights are obtained through ``load_serialized_obj_from_url(url)`` and
    loaded into the model with ``model.load_state_dict``.

    For example, we can decorate a resnet18 function as follows

    .. code-block:: python

        @hub.pretrained("https://url/to/pretrained_resnet18.pkl")
        def resnet18(**kwargs):
            ...

    Returns:
        When decorated function is called with ``pretrained=True``, MegEngine will automatically
        download and fill the returned model with pretrained weights.
    """

    def __init__(self, url):
        # URL of the serialized state dict applied when ``pretrained=True``.
        self.url = url

    def __call__(self, func):
        @functools.wraps(func)
        def pretrained_model_func(
            pretrained=False, *args, **kwargs
        ):  # pylint: disable=redefined-outer-name,keyword-arg-before-vararg
            # ``pretrained`` stays the first positional parameter (as before);
            # any further positional arguments are now forwarded to ``func``
            # instead of raising TypeError.
            model = func(*args, **kwargs)
            if pretrained:
                weights = load_serialized_obj_from_url(self.url)
                model.load_state_dict(weights)
            return model

        return pretrained_model_func
# Public API of this module. Note that ``list`` and ``help`` are module-level
# functions here and shadow the builtins of the same name within this module.
__all__ = [
    "list",
    "load",
    "help",
    "load_serialized_obj_from_url",
    "pretrained",
    "import_module",
]