megengine.functional.nn.cross_entropy¶
- cross_entropy(pred, label, axis=1, with_logits=True, label_smooth=0, reduction='mean')[源代码]¶
计算 multi-class cross entropy loss(默认使用 logits)。
当使用标签平滑 (label smoothing) 时,标签的分布情况如下:
\[y^{LS}_{k}=y_{k}\left(1-\alpha\right)+\alpha/K\]在上述公式中,\(y^{LS}\) 是平滑后的标签分布,\(y\) 是原有的数据分布。 k 是下标,表明第几个标签. \(\alpha\) 是
label_smooth
平滑系数,\(K\) 是标签的个数。- 参数
- 返回类型
- 返回
损失值。
实际案例
默认情况下(
with_logits
为 True),pred
被认为是 logits,类别概率由 softmax 给出。它的数值稳定性优于依次调用softmax
和binary_cross_entropy
。>>> pred = Tensor([[0., 1.], [0.3, 0.7], [0.7, 0.3]]) >>> label = Tensor([1., 1., 1.]) >>> F.nn.cross_entropy(pred, label) Tensor(0.57976407, device=xpux:0) >>> F.nn.cross_entropy(pred, label, reduction="none") Tensor([0.3133 0.513 0.913 ], device=xpux:0)
如果
pred
是概率,将with_logits
设置为 False:>>> pred = Tensor([[0., 1.], [0.3, 0.7], [0.7, 0.3]]) >>> label = Tensor([1., 1., 1.]) >>> F.nn.cross_entropy(pred, label, with_logits=False) Tensor(0.5202159, device=xpux:0) >>> F.nn.cross_entropy(pred, label, with_logits=False, reduction="none") Tensor([0. 0.3567 1.204 ], device=xpux:0)