megengine.functional.nn.ctc_loss¶
- ctc_loss(pred, pred_lengths, label, label_lengths, blank=0, reduction='mean')[源代码]¶
计算 Connectionist Temporal Classification loss 。
- 参数
pred (
Tensor
) – 概率张量,其尺寸为 (T, N, C),其中 T 是 input 长度,N 是 batch 个数,C 是类别数量(包括 blank)。pred_lengths (
Tensor
) –pred
中每个序列的点数,尺寸为 (N, )。label (
Tensor
) – groundtruth 标签,包含每个序列的每个点的 groundtruth 的位置,blank 不应包含在其中。尺寸是 (N, S) 或者 sum(label_lengths))。label_lengths (
Tensor
) – groundtruth 的每个序列的点数,尺寸是 (N, )。blank (
int
) – blank 的个数,默认值为 0。reduction (
str
) – 计算输出的模式:none | mean | sum。默认值为:mean。
- 返回类型
- 返回
损失值。
实际案例
>>> pred = Tensor([[[0.0614, 0.9386],[0.8812, 0.1188]],[[0.699, 0.301 ],[0.2572, 0.7428]]]) >>> pred_lengths = Tensor([2, 2]) >>> label = Tensor([1, 1]) >>> label_lengths = Tensor([1, 1]) >>> F.nn.ctc_loss(pred, pred_lengths, label, label_lengths) Tensor(0.1504417, device=xpux:0)