# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import heapq
import itertools
import typing
import weakref
import numpy as np
import megengine as mge
from .._imperative_rt import core2, ops
from ..ops.builtin import Elemwise, OpDef, RemoteSend
from ..ops.special import Const
_grad_count = 0
_grad_manager_dict = weakref.WeakValueDictionary()
def get_grad_managers():
return [_grad_manager_dict[key] for key in _grad_manager_dict]
class GradKey(core2.GradKey):
def __init__(self, name=None):
if name:
self.name = name
    def backward(self, ys, dys):
return core2.backward(self, ys, dys)
class Grad:
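    r"""Records computation on attached tensors and runs back propagation.

    Tensors of interest are attached with :meth:`wrt`; calling the instance
    with the outputs ``ys`` and the corresponding output gradients ``dys``
    triggers back propagation.

    A minimal usage sketch (assuming ``x`` and ``dy`` are MegEngine tensors,
    and that the ``callback`` passed to :meth:`wrt` receives the computed
    gradient of the attached tensor):

    .. code-block::

        grads = {}
        with Grad() as grad:
            grad.wrt(x, callback=lambda g: grads.update(dx=g))
            y = x * x
            grad(y, dy)
    """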
def __init__(self, name=None):
global _grad_count
if name is None:
name = "grad_%d" % _grad_count
_grad_count += 1
self._refkeeper = []
self._impl = GradKey(name)
_grad_manager_dict[self._name] = self
@property
def _priority(self):
return self._impl.priority
@_priority.setter
def _priority(self, priority):
self._impl.priority = priority
@property
def _name(self):
return self._impl.name
def _is_attached_to(self, tensor):
return self._impl.is_attached_to(tensor)
    def wrt(self, *tensors, callback=None):
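        r"""Declares the tensors whose gradients should be computed during
        back propagation, optionally registering a ``callback`` used to
        receive each computed gradient.  Returns ``self`` so that calls can
        be chained."""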
for x in tensors:
self._impl.attach(x, callback)
return self
def __call__(self, ys, dys):
from collections.abc import Sequence
if not isinstance(ys, Sequence):
ys = [ys]
if not isinstance(dys, Sequence):
dys = [dys]
self._impl.backward(ys, dys)
self._refkeeper = None
def __enter__(self):
return self
def __exit__(self, _1, _2, _3):
self._refkeeper = None
del self._impl
class Function(ops.PyOpBase):
r"""Defines a block of operations with customizable differentiation.
The computation should be defined in ``forward`` method, with gradient
computation defined in ``backward`` method.
Each instance of ``Function`` should be used only once during forwardding.
Examples:
.. code-block::
class Sigmoid(Function):
def forward(self, x):
y = 1 / (1 + F.exp(-x))
self.y = y
return y
def backward(self, dy):
y = self.y
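
        The example assumes ``F`` refers to ``megengine.functional``.  Once
        defined, the customized function is applied by calling an instance
        like an ordinary callable:

        .. code-block::

            output = Sigmoid()(input)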
"""
    def forward(self, *args, **kwargs):
r"""Applies operations to ``inputs`` and returns results. It must be overriden by all subclasses.
Args:
input: input tensors.
Returns:
a tuple of Tensor or a single Tensor.
Note:
* This method should return a tuple of Tensor or a single Tensor representing the output
of the function.
            * Positional arguments should all be Tensors.
"""
raise NotImplementedError
    def backward(self, *output_grads):
r"""Compute the gradient of the forward function. It must be overriden by all subclasses.
Args:
output_grads: gradients of outputs that are returned by :meth:`forward`.
Note:
            * If some of the outputs are not related to the loss function, the corresponding
              values in ``output_grads`` will be ``None``.
            * This method should return a tuple containing the gradients of all inputs, in the same order
              as the ``inputs`` argument of :meth:`forward`. A single ``Tensor`` may be returned
              instead if there is only one input. To stop the propagation of some gradients,
              set the corresponding returned values to ``None``.
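
        Examples:

            A hypothetical sketch of a function whose second input needs no
            gradient, so ``None`` is returned in its place:

            .. code-block::

                class StopGradMul(Function):
                    def forward(self, x, w):
                        self.w = w
                        return x * w

                    def backward(self, dy):
                        # gradient flows back to ``x`` only; ``None`` stops
                        # propagation towards ``w``
                        return dy * self.w, None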
"""
raise NotImplementedError
    def _default_rule(self, *args):
        # Run the user-defined forward and remember whether it returned a
        # single Tensor, so that __call__ can unwrap the result accordingly.
        ret = self.forward(*args)
        self.__single_output = isinstance(ret, core2.Tensor)
        return ret
    def _grad_rule(self, *args):
        # Forward outputs paired with the backward method, so the autodiff
        # engine can use the latter as the gradient rule for this op.
        return self._default_rule(*args), self.backward
    def __call__(self, *args):
        # Dispatch through the imperative runtime; unwrap the single-element
        # tuple if forward returned a single Tensor.
        ret = core2.apply(self, *args)
        if self.__single_output:
            (ret,) = ret
        return ret
def __getstate__(self):
return self.__dict__
def __setstate__(self, state):
self.__dict__.update(state)