megengine.functional.nn.ctc_loss

ctc_loss(pred, pred_lengths, label, label_lengths, blank=0, reduction='mean')[源代码]

计算 Connectionist Temporal Classification loss 。

参数
  • pred (Tensor) – 概率张量,其尺寸为 (T, N, C),其中 T 是 input 长度,N 是 batch 个数,C 是类别数量(包括 blank)。

  • pred_lengths (Tensor) – pred 中每个序列的点数,尺寸为 (N, )。

  • label (Tensor) – groundtruth 标签,包含每个序列的每个点的 groundtruth 的位置,blank 不应包含在其中。尺寸是 (N, S) 或者 sum(label_lengths))。

  • label_lengths (Tensor) – groundtruth 的每个序列的点数,尺寸是 (N, )。

  • blank (int) – blank 的个数,默认值为 0。

  • reduction (str) – 计算输出的模式:none | mean | sum。默认值为:mean

返回类型

Tensor

返回

损失值。

实际案例

from megengine import tensor
import megengine.functional as F

pred = tensor([[[0.0614, 0.9386],[0.8812, 0.1188]],[[0.699, 0.301 ],[0.2572, 0.7428]]])
pred_length = tensor([2,2])
label = tensor([1,1])
label_lengths = tensor([1,1])
loss = F.nn.ctc_loss(pred, pred_length, label, label_lengths)
print(loss.numpy())

输出:

0.1504417