LAMB

class LAMB(params, lr, betas=(0.9, 0.999), eps=1e-08, bias_correction=True, weight_decay=0.0, always_adapt=False)[源代码]

实现 LAMB 算法。

LAMB 提出于 “Large Batch Optimization for Deep Learning: Training BERT in 76 minutes”

参数
  • params (Union[Iterable[Parameter], dict]) – 可迭代对象,可以是一组待优化的参数,或定义几组参数的dict类型。

  • lr (float) – 学习率(learning rate)。

  • betas (Tuple[float, float]) – 用于计算梯度和其平方的滑动平均的系数。默认值: (0.9, 0.999)

  • eps (float) – 加到分母以提高数值稳定性的一项。默认值: 1e-8

  • bias_correction (bool) – 使用 1 - beta ** step 进行偏差修正。默认值: True

  • weight_decay (float) – 权重衰减(L2 惩罚项)。默认值: 0.0

  • always_adapt (bool) – 对 0.0 权重衰减参数应用自适应学习率