megengine.functional.nn.binary_cross_entropy¶
- binary_cross_entropy(pred, label, with_logits=True, reduction='mean')[源代码]¶
计算 binary cross entropy loss(默认使用 logits)。
- 参数
- 返回类型
- 返回
损失值。
实际案例
默认情况下(
with_logits
为 True),pred
被认为是 logits,类别概率由 softmax 给出。它的数值稳定性优于依次调用sigmoid
和binary_cross_entropy
。>>> pred = Tensor([0.9, 0.7, 0.3]) >>> label = Tensor([1., 1., 1.]) >>> F.nn.binary_cross_entropy(pred, label) Tensor(0.4328984, device=xpux:0) >>> F.nn.binary_cross_entropy(pred, label, reduction="none") Tensor([0.3412 0.4032 0.5544], device=xpux:0)
如果
pred
是概率,将with_logits
设置为 False:>>> pred = Tensor([0.9, 0.7, 0.3]) >>> label = Tensor([1., 1., 1.]) >>> F.nn.binary_cross_entropy(pred, label, with_logits=False) Tensor(0.5553361, device=xpux:0) >>> F.nn.binary_cross_entropy(pred, label, with_logits=False, reduction="none") Tensor([0.1054 0.3567 1.204 ], device=xpux:0)