megengine.functional.nn.binary_cross_entropy

binary_cross_entropy(pred, label, with_logits=True, reduction='mean')[源代码]

计算 binary cross entropy loss(默认使用 logits)。

参数
  • pred (Tensor) – (N,*),其中 * 指任何附加的维度。

  • label (Tensor) – (N,*),与输入的形状相同。

  • with_logits (bool) – 布尔值,是否先应用 sigmoid。默认:True

  • reduction (str) – 对输出应用的规约操作: ‘none’ | ‘mean’ | ‘sum’ 。

返回类型

Tensor

返回

损失值。

实际案例

默认情况下( with_logits 为 True), pred 被认为是 logits,类别概率由 softmax 给出。它的数值稳定性优于依次调用 sigmoidbinary_cross_entropy

>>> pred = Tensor([0.9, 0.7, 0.3])
>>> label = Tensor([1., 1., 1.])
>>> F.nn.binary_cross_entropy(pred, label)
Tensor(0.4328984, device=xpux:0)
>>> F.nn.binary_cross_entropy(pred, label, reduction="none")
Tensor([0.3412 0.4032 0.5544], device=xpux:0)

如果 pred 是概率,将 with_logits 设置为 False:

>>> pred = Tensor([0.9, 0.7, 0.3])
>>> label = Tensor([1., 1., 1.])
>>> F.nn.binary_cross_entropy(pred, label, with_logits=False)
Tensor(0.5553361, device=xpux:0)
>>> F.nn.binary_cross_entropy(pred, label, with_logits=False, reduction="none")
Tensor([0.1054 0.3567 1.204 ], device=xpux:0)