Source code for megengine.core.autodiff.grad

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import heapq
import itertools
import typing
import weakref

import numpy as np

import megengine as mge

from .._imperative_rt import core2, ops
from ..ops.builtin import Elemwise, OpDef, RemoteSend
from ..ops.special import Const

_grad_count = 0
_grad_manager_dict = weakref.WeakValueDictionary()

def get_grad_managers():
    return [_grad_manager_dict[key] for key in _grad_manager_dict]
class GradKey(core2.GradKey):
    def __init__(self, name=None):
        if name:
            self.name = name

    def backward(self, ys, dys):
        return core2.backward(self, ys, dys)
class Grad:
    def __init__(self, name=None):
        global _grad_count
        if name is None:
            name = "grad_%d" % _grad_count
            _grad_count += 1
        self._refkeeper = []
        self._impl = GradKey(name)
        _grad_manager_dict[self._name] = self

    @property
    def _priority(self):
        return self._impl.priority

    @_priority.setter
    def _priority(self, priority):
        self._impl.priority = priority

    @property
    def _name(self):
        return self._impl.name

    def _is_attached_to(self, tensor):
        return self._impl.is_attached_to(tensor)

    def wrt(self, *tensors, callback=None):
        for x in tensors:
            self._impl.attach(x, callback)
        return self

    def __call__(self, ys, dys):
        from collections.abc import Sequence

        if not isinstance(ys, Sequence):
            ys = [ys]
        if not isinstance(dys, Sequence):
            dys = [dys]
        self._impl.backward(ys, dys)
        self._refkeeper = None

    def __enter__(self):
        return self

    def __exit__(self, _1, _2, _3):
        self._refkeeper = None
        del self._impl
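# Usage sketch (illustrative, not part of the original module): ``Grad`` is
# typically driven as a context manager. The tensor names below are examples
# only, and the ``wrt`` callback is assumed to receive the computed gradient
# as its single argument.
#
#     import numpy as np
#     import megengine as mge
#     from megengine.core.autodiff.grad import Grad
#
#     x = mge.tensor(np.ones(10, dtype="float32"))
#     grads = []
#     with Grad() as grad:
#         grad.wrt(x, callback=grads.append)  # track x; store dx once computed
#         y = x * x
#         grad(y, mge.tensor(np.ones(10, dtype="float32")))  # seed dy with ones
#     # grads[0] now holds dy/dx = 2 * x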
class Function(ops.PyOpBase):
    r"""Defines a block of operations with customizable differentiation.

    The computation should be defined in the ``forward`` method, with gradient
    computation defined in the ``backward`` method.

    Each instance of ``Function`` should be used only once during forwarding.

    Examples:

        .. code-block::

            class Sigmoid(Function):
                def forward(self, x):
                    y = 1 / (1 + F.exp(-x))
                    self.y = y
                    return y

                def backward(self, dy):
                    y = self.y
                    return dy * y * (1 - y)
    """
    def forward(self, *args, **kwargs):
        r"""Applies operations to ``inputs`` and returns results. It must be overridden by all subclasses.

        Args:
            input: input tensors.

        Returns:
            a tuple of Tensor or a single Tensor.

        Note:
            * This method should return a tuple of Tensor or a single Tensor representing the output
              of the function.
            * positional arguments should all be Tensor.
        """
        raise NotImplementedError
    def backward(self, *output_grads):
        r"""Computes the gradient of the forward function. It must be overridden by all subclasses.

        Args:
            output_grads: gradients of outputs that are returned by :meth:`forward`.

        Note:
            * In case some tensors of the outputs are not related to the loss function, the corresponding
              values in ``output_grads`` would be ``None``.
            * This method should return a tuple containing the gradients of all inputs, in the same order
              as the ``inputs`` argument of :meth:`forward`. A ``Tensor`` could be returned
              instead if there is only one input. If users want to stop the propagation of some gradients,
              the corresponding returned values should be set to ``None``.
        """
        raise NotImplementedError
    def _default_rule(self, *args):
        ret = self.forward(*args)
        self.__single_output = isinstance(ret, core2.Tensor)
        return ret

    def _grad_rule(self, *args):
        return self._default_rule(*args), self.backward

    def __call__(self, *args):
        ret = core2.apply(self, *args)
        if self.__single_output:
            (ret,) = ret
        return ret

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)
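# Usage sketch (illustrative, not part of the original module): a complete
# ``Function`` subclass, filling out the ``Sigmoid`` example from the class
# docstring above. Assumes MegEngine is installed.
#
#     import megengine as mge
#     import megengine.functional as F
#     from megengine.core.autodiff.grad import Function
#
#     class Sigmoid(Function):
#         def forward(self, x):
#             y = 1 / (1 + F.exp(-x))
#             self.y = y                # stash the activation for backward
#             return y
#
#         def backward(self, dy):
#             y = self.y
#             return dy * y * (1 - y)   # sigmoid'(x) = y * (1 - y)
#
#     x = mge.tensor([0.0, 1.0])
#     y = Sigmoid()(x)  # each instance may be used for a single forward only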