megengine.autodiff.GradManager.attach

GradManager.attach(tensors, callbacks=None)[源代码]

指示 GradManager 跟踪张量上的操作,以便那些张量上的梯度,可以在之后进行计算。

attach 也接受一个回调函数的列表,在 backward 的过程中,这些回调函数会被以 Tensor 和其梯度作为参数调用 。回调函数的签名应该如下:

def callback(tensor: Tensor, grad: Tensor) -> Tensor:
    ...
    # returned grad is passed to subsequent callbacks
    # and finally accumulated to the .grad attribute of tensor
    return grad

多次 attach 调用的 Tensor 列表如果有重叠,那么这些 Tensor 对应的回调函数列表会被拼接,此操作对每个 Tensor 独立作用。例如,

gm.attach([x, y], callbacks=[f])
gm.attach([y], callbacks=[g])

等价于

gm.attach([x], callbacks=[f])
gm.attach([y], callbacks=[f, g])

调用 attach 之后,不仅会在当此求导中生效,也会同一个 GradManager 之后的所有求导中生效。在用同一个 GradManager 进行多次求导时,用相同的参数反复调用 attach 很可能是一个错误,这会导致回调函数的列表的长度一直增长。

注解

在重复使用同一个 GradManager 的同时,您可能会希望对一些临时的 Tensor 求导,例如对神经网络的输入求导。考虑到这种用法,GradManager 会对 attached Tensor 持有弱引用。在大多数时候,这已经可以避免资源泄漏。但是,仍然有少数情况需要您注意:

  • 回调函数不应该持有被 attached Tensor 的强引用,无论是直接地还是间接地。任何强引用,包括来自回调函数的强引用,都会导致 attached Tensor 无法被垃圾回收(即使运行可回收引用循环的完整垃圾回收!),直到 GradManager 对象本身被垃圾回收为止。

还需注意的一点是 GradManager 如果正在进行求导,可能会持有对被 attached Tensor 的强引用。本注解仅针对将一个 GradManager 用于多次求导可能引发的资源泄漏,并不涉及在进行一次求导时资源是否被第一时间释放的问题。

参数
  • tensors (Iterable[Tensor]) – 需要跟踪的 Tensor 或者 Tensor 列表

  • callbacks – 回调函数或回调函数的列表