megengine.hub.hub source code

# -*- coding: utf-8 -*-
import functools
import hashlib
import os
import sys
import types
from typing import Any, List
from urllib.parse import urlparse

from megengine.utils.http_download import download_from_url

from ..distributed import is_distributed
from ..logger import get_logger
from ..serialization import load as _mge_load_serialized
from .const import (
    DEFAULT_CACHE_DIR,
    DEFAULT_GIT_HOST,
    DEFAULT_PROTOCOL,
    ENV_MGE_HOME,
    ENV_XDG_CACHE_HOME,
    HUBCONF,
    HUBDEPENDENCY,
)
from .exceptions import InvalidProtocol
from .fetcher import GitHTTPSFetcher, GitSSHFetcher
from .tools import cd, check_module_exists, load_module

logger = get_logger(__name__)


# Registry of supported transport protocols: each key is a value callers may pass
# as ``protocol``; each value is the fetcher class whose ``fetch`` is invoked to
# obtain the repo.
PROTOCOLS = dict(
    HTTPS=GitHTTPSFetcher,
    SSH=GitSSHFetcher,
)


def _get_megengine_home() -> str:
    r"""Return the MegEngine home directory (XDG Base Directory compliant).

    Resolution order:

    1. the value of the ``ENV_MGE_HOME`` environment variable, when it is set;
    2. otherwise ``<cache>/megengine``, where ``<cache>`` is the value of the
       ``ENV_XDG_CACHE_HOME`` environment variable, or ``DEFAULT_CACHE_DIR``
       when that variable is unset.

    A leading ``~`` in the resulting path is expanded to the user's home.
    """
    cache_root = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    fallback_home = os.path.join(cache_root, "megengine")
    chosen_home = os.getenv(ENV_MGE_HOME, fallback_home)
    return os.path.expanduser(chosen_home)


def _get_repo(
    git_host: str,
    repo_info: str,
    use_cache: bool = False,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> str:
    r"""Fetch a git repo into the hub cache and return its local path.

    Args:
        git_host: host address of the git repo, e.g. ``github.com``.
        repo_info: ``"repo_owner/repo_name[:tag_name/:branch_name]"``.
        use_cache: forwarded to the fetcher's ``fetch``.
        commit: optional commit id, forwarded to the fetcher's ``fetch``.
        protocol: a key of ``PROTOCOLS`` selecting the fetcher class.

    Returns:
        ``<megengine home>/hub`` joined with the directory name returned by
        the fetcher.

    Raises:
        InvalidProtocol: if ``protocol`` is not a key of ``PROTOCOLS``.
    """
    if protocol in PROTOCOLS:
        fetcher_cls = PROTOCOLS[protocol]
    else:
        supported = ", ".join(PROTOCOLS.keys())
        raise InvalidProtocol(
            "Invalid protocol, the value should be one of {}.".format(supported)
        )

    hub_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
    # The fetcher runs with the hub cache as its working directory (presumably
    # it clones relative to the cwd, and ``cd`` restores the old cwd on exit).
    with cd(hub_dir):
        fetched_dir = fetcher_cls.fetch(git_host, repo_info, use_cache, commit)
    return os.path.join(hub_dir, fetched_dir)


def _check_dependencies(module: types.ModuleType) -> None:
    r"""Verify that every dependency declared by a hub module is available.

    The dependency names are read from the attribute of ``module`` named by
    ``HUBDEPENDENCY``. When that attribute is absent or empty there is
    nothing to check.

    Raises:
        RuntimeError: naming every dependency for which
            ``check_module_exists`` returns a falsy value.
    """
    required = getattr(module, HUBDEPENDENCY, None)
    if not required:
        # Attribute missing (getattr default) or explicitly empty/None.
        return

    missing = [name for name in required if not check_module_exists(name)]
    if missing:
        raise RuntimeError("Missing dependencies: {}".format(", ".join(missing)))


def _init_hub(
    repo_info: str,
    git_host: str,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
):
    r"""Imports hubmodule like python import.

    The repo is fetched into ``<megengine home>/hub`` via ``_get_repo``. Its
    ``HUBCONF`` file is then loaded as a module, with the repo directory
    temporarily prepended to ``sys.path``.

    Args:
        repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
            tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
        git_host: host address of git repo. Eg: github.com
        use_cache: whether to use locally cached code or completely re-fetch.
        commit: commit id on github or gitlab.
        protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
            The value should be one of HTTPS, SSH.

    Returns:
        a python module.
    """
    cache_dir = os.path.expanduser(os.path.join(_get_megengine_home(), "hub"))
    # ``_get_repo`` changes into this directory, so it must exist beforehand.
    os.makedirs(cache_dir, exist_ok=True)
    absolute_repo_dir = _get_repo(
        git_host, repo_info, use_cache=use_cache, commit=commit, protocol=protocol
    )
    # Expose the repo on sys.path only while hubconf executes (presumably so
    # that hubconf can import sibling modules living inside the repo).
    sys.path.insert(0, absolute_repo_dir)
    try:
        hubmodule = load_module(HUBCONF, os.path.join(absolute_repo_dir, HUBCONF))
    finally:
        # Always undo the insert, even if hubconf fails to load; otherwise the
        # repo directory would stay on sys.path for the rest of the process.
        sys.path.remove(absolute_repo_dir)

    return hubmodule


@functools.wraps(_init_hub)
def import_module(*args, **kwargs):
    # Public alias of ``_init_hub``; all arguments are forwarded unchanged.
    # ``functools.wraps`` copies ``_init_hub``'s docstring and sets
    # ``__wrapped__``, so ``inspect.signature`` reports its real parameters.
    # It also copies ``__name__``/``__qualname__`` ("_init_hub").
    return _init_hub(*args, **kwargs)
def list(
    repo_info: str,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> List[str]:
    r"""Lists all entrypoints available in repo hubconf.

    Args:
        repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
            tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
        git_host: host address of git repo. Eg: github.com
        use_cache: whether to use locally cached code or completely re-fetch.
        commit: commit id on github or gitlab.
        protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
            The value should be one of HTTPS, SSH.

    Returns:
        all entrypoint names of the model.
    """
    # NOTE: this public function shadows the builtin ``list`` throughout this
    # module (kept for API compatibility), so the builtin cannot be called by
    # that name here.
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)
    # An entrypoint is any callable attribute of hubconf whose name does not
    # start with a double underscore.
    return [
        name
        for name in dir(hubmodule)
        if not name.startswith("__") and callable(getattr(hubmodule, name))
    ]
def load(
    repo_info: str,
    entry: str,
    *args,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
    **kwargs
) -> Any:
    r"""Loads model from github or gitlab repo, with pretrained weights.

    Args:
        repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
            tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
        entry: an entrypoint defined in hubconf.
        *args: positional arguments forwarded to the entrypoint.
        git_host: host address of git repo. Eg: github.com
        use_cache: whether to use locally cached code or completely re-fetch.
        commit: commit id on github or gitlab.
        protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
            The value should be one of HTTPS, SSH.
        **kwargs: keyword arguments forwarded to the entrypoint.

    Returns:
        a single model with corresponding pretrained weights.

    Raises:
        RuntimeError: if ``entry`` is not a callable attribute of hubconf, or
            if hubconf declares dependencies that are not available.
    """
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)

    # Single lookup: a missing attribute yields None, which is not callable.
    entrypoint = getattr(hubmodule, entry, None)
    if not callable(entrypoint):
        raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))

    # Dependencies are only checked once the entrypoint is known to exist.
    _check_dependencies(hubmodule)
    return entrypoint(*args, **kwargs)
def help(
    repo_info: str,
    entry: str,
    git_host: str = DEFAULT_GIT_HOST,
    use_cache: bool = True,
    commit: str = None,
    protocol: str = DEFAULT_PROTOCOL,
) -> str:
    r"""This function returns docstring of entrypoint ``entry`` by following steps:

    1. Pull the repo code specified by git and repo_info.
    2. Load the entry defined in repo's hubconf.py
    3. Return docstring of function entry.

    Args:
        repo_info: a string with format ``"repo_owner/repo_name[:tag_name/:branch_name]"`` with an optional
            tag/branch. The default branch is ``master`` if not specified. Eg: ``"brain_sdk/MegBrain[:hub]"``
        entry: an entrypoint defined in hubconf.py
        git_host: host address of git repo. Eg: github.com
        use_cache: whether to use locally cached code or completely re-fetch.
        commit: commit id on github or gitlab.
        protocol: which protocol to use to get the repo, and HTTPS protocol only supports public repo on github.
            The value should be one of HTTPS, SSH.

    Returns:
        docstring of entrypoint ``entry`` (``None`` if it has no docstring).

    Raises:
        RuntimeError: if ``entry`` is not a callable attribute of hubconf.
    """
    # NOTE: this public function shadows the builtin ``help`` in this module.
    hubmodule = _init_hub(repo_info, git_host, use_cache, commit, protocol)

    # Single lookup: a missing attribute yields None, which is not callable.
    entrypoint = getattr(hubmodule, entry, None)
    if not callable(entrypoint):
        raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry))

    return entrypoint.__doc__
def load_serialized_obj_from_url(url: str, model_dir=None) -> Any:
    """Loads MegEngine serialized object from the given URL.

    If the object is already present in ``model_dir``, it's deserialized and
    returned. If no ``model_dir`` is specified, it will be ``MGE_HOME/serialized``.

    Args:
        url: url to serialized object.
        model_dir: dir to cache target serialized file. A leading ``~`` is
            expanded, and the directory is created if it does not exist.

    Returns:
        loaded object.
    """
    if model_dir is None:
        model_dir = os.path.join(_get_megengine_home(), "serialized")
    # Expand ``~`` like the rest of this module does for its cache paths;
    # otherwise a literal "~" directory would be created under the cwd.
    model_dir = os.path.expanduser(model_dir)
    os.makedirs(model_dir, exist_ok=True)

    # Use a short hash of the full URL as a filename prefix, so files with the
    # same basename served from different URLs do not collide in the cache.
    basename = os.path.basename(urlparse(url).path)
    digest = hashlib.sha256(url.encode()).hexdigest()[:6]
    cached_file = os.path.join(model_dir, digest + "_" + basename)

    logger.info(
        "load_serialized_obj_from_url: download to or using cached %s", cached_file
    )
    if not os.path.exists(cached_file):
        if is_distributed():
            logger.warning(
                "Downloading serialized object in DISTRIBUTED mode\n"
                " File may be downloaded multiple times. We recommend\n"
                " users to download in single process first."
            )
        download_from_url(url, cached_file)

    # NOTE(review): the MegEngine serialization loader is presumably
    # pickle-based; only load from trusted URLs. Confirm against
    # ``megengine.serialization.load``.
    return _mge_load_serialized(cached_file)
class pretrained:
    r"""Decorator which helps to download pretrained weights from the given url.

    The weights are fetched through ``load_serialized_obj_from_url``. Supported
    URL schemes are those handled by its downloader (documented upstream as
    fs, s3 and http(s)).

    For example, we can decorate a resnet18 function as follows

    .. code-block::

        @hub.pretrained("https://url/to/pretrained_resnet18.pkl")
        def resnet18(**kwargs):
            ...

    Returns:
        When decorated function is called with ``pretrained=True``, MegEngine
        will automatically download and fill the returned model with
        pretrained weights. ``pretrained`` may be given positionally or by
        keyword; every other argument must be a keyword argument and is
        forwarded to the decorated function.
    """

    def __init__(self, url):
        # URL of the serialized weights, fetched lazily on ``pretrained=True``.
        self.url = url

    def __call__(self, func):
        @functools.wraps(func)
        def pretrained_model_func(
            pretrained=False, **kwargs
        ):  # pylint: disable=redefined-outer-name
            model = func(**kwargs)
            if pretrained:
                # Weights are downloaded (or read from cache) only on demand.
                weights = load_serialized_obj_from_url(self.url)
                model.load_state_dict(weights)
            return model

        return pretrained_model_func
# Public API exported by ``megengine.hub.hub``. Note that ``list`` and ``help``
# shadow the builtins of the same name inside this module.
__all__ = [
    "list",
    "load",
    "help",
    "load_serialized_obj_from_url",
    "pretrained",
    "import_module",
]