快捷方式

torch.nn.functional.ctc_loss

torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False)[source]

应用连接时序分类损失。

有关详细信息,请参阅 CTCLoss

注意

在某些情况下,当在 CUDA 设备上提供张量并使用 CuDNN 时,此运算符可能会选择非确定性算法以提高性能。如果这是不可取的,您可以尝试通过设置 torch.backends.cudnn.deterministic = True 来使操作确定性(可能以性能为代价)。有关更多信息,请参阅 可重复性

注意

当在 CUDA 设备上提供张量时,此操作可能会产生非确定性梯度。有关更多信息,请参阅 可重复性

参数
  • log_probs (张量) – (T,N,C)(T, N, C)(T,C)(T, C),其中 C = 字母表中字符的数量,包括空白T = 输入长度,以及 N = 批次大小。输出的对数概率(例如,使用 torch.nn.functional.log_softmax() 获得)。

  • targets (张量) – (N,S)(N, S)(sum(target_lengths))。目标不能是空白。在第二种形式中,假定目标是连接的。

  • input_lengths (张量) – (N)(N)()()。输入的长度(每个都必须小于等于 TT)。

  • target_lengths (张量) – (N)(N)()()。目标的长度

  • blank (int, 可选) – 空白标签。默认值为 00.

  • reduction (str, 可选) – 指定要应用于输出的归约:'none' | 'mean' | 'sum''none':不应用任何归约,'mean':输出损失将除以目标长度,然后取批次的平均值,'sum':输出将求和。默认值:'mean'

  • zero_infinity (bool, 可选) – 是否将无限损失及其关联的梯度置零。默认值:False。无限损失主要发生在输入过短而无法与目标对齐时。

返回类型

张量

示例

>>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
>>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
>>> input_lengths = torch.full((16,), 50, dtype=torch.long)
>>> target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
>>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)
>>> loss.backward()

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源