快捷方式

CTCLoss

torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)[源代码][源代码]

连接主义时间分类损失 (Connectionist Temporal Classification loss)。

计算连续(未分段)时间序列与目标序列之间的损失。CTCLoss 对输入与目标序列所有可能的对齐方式的概率求和,从而生成一个相对于每个输入节点都可微分的损失值。输入与目标的对齐被假定为“多对一”关系,这将目标序列的长度限制为必须 \leq 输入长度。

参数
  • blank (int, optional) – 空格标签。默认值 00

  • reduction (str, optional) – 指定要应用于输出的归约方式:'none' | 'mean' | 'sum''none':不应用任何归约;'mean':输出损失将除以目标长度,然后取批次平均值;'sum':输出损失将被求和。默认值:'mean'

  • zero_infinity (bool, optional) – 是否将无限损失及其相关梯度清零。默认值:False 当输入序列过短无法与目标序列对齐时,主要会出现无限损失。

形状
  • Log_probs:形状为 (T,N,C)(T, N, C)(T,C)(T, C) 的 Tensor,其中 T=输入长度T = \text{输入长度}N=批次大小N = \text{批次大小},而 C=类别数量(包括空格)C = \text{类别数量(包括空格)}。输出的对数化概率(例如,使用 torch.nn.functional.log_softmax() 获得)。

  • Targets:形状为 (N,S)(N, S)(sum(target_lengths))(\operatorname{sum}(\text{target\_lengths})) 的 Tensor,其中 N=批次大小N = \text{批次大小};如果形状是 (N,S)(N, S),则 S=最大目标长度S = \text{最大目标长度}。它表示目标序列。目标序列中的每个元素都是一个类别索引。并且目标索引不能是空格标签(默认值 = 0)。在 (N,S)(N, S) 形式中,目标序列将被填充到最长序列的长度,并堆叠在一起。在 (sum(target_lengths))(\operatorname{sum}(\text{target\_lengths})) 形式中,假定目标序列未填充并在一维中连接。

  • Input_lengths:形状为 (N)(N)()() 的 Tuple 或 Tensor,其中 N=批次大小N = \text{批次大小}。它表示输入的长度(每个长度必须 T\leq T)。并且为每个序列指定了长度,以便在假设序列已填充到相等长度的情况下实现掩码。

  • Target_lengths:形状为 (N)(N)()() 的 Tuple 或 Tensor,其中 N=批次大小N = \text{批次大小}。它表示目标的长度。为每个序列指定了长度,以便在假设序列已填充到相等长度的情况下实现掩码。如果目标形状是 (N,S)(N,S),则 target_lengths 实际上是每个目标序列的停止索引 sns_n,使得对于批次中的每个目标,有 target_n = targets[n,0:s_n]。每个长度必须 S\leq S。如果目标序列作为一维 Tensor 给出,该 Tensor 是各个目标序列的拼接,则 target_lengths 的总和必须等于该 Tensor 的总长度。

  • 输出:如果 reduction'mean'(默认值)或 'sum',则为标量。如果 reduction'none',则如果输入是批次处理的,形状为 (N)(N);如果输入未进行批次处理,形状为 ()(),其中 N=批次大小N = \text{批次大小}

示例

>>> # Target are to be padded
>>> T = 50      # Input sequence length
>>> C = 20      # Number of classes (including blank)
>>> N = 16      # Batch size
>>> S = 30      # Target sequence length of longest target in batch (padding length)
>>> S_min = 10  # Minimum target length, for demonstration purposes
>>>
>>> # Initialize random batch of input vectors, for *size = (T,N,C)
>>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
>>>
>>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
>>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
>>>
>>>
>>> # Target are to be un-padded
>>> T = 50      # Input sequence length
>>> C = 20      # Number of classes (including blank)
>>> N = 16      # Batch size
>>>
>>> # Initialize random batch of input vectors, for *size = (T,N,C)
>>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
>>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)
>>> target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long)
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
>>>
>>>
>>> # Target are to be un-padded and unbatched (effectively N=1)
>>> T = 50      # Input sequence length
>>> C = 20      # Number of classes (including blank)
>>>
>>> # Initialize random batch of input vectors, for *size = (T,C)
>>> input = torch.randn(T, C).log_softmax(1).detach().requires_grad_()
>>> input_lengths = torch.tensor(T, dtype=torch.long)
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long)
>>> target = torch.randint(low=1, high=C, size=(target_lengths,), dtype=torch.long)
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
参考

A. Graves 等人:连接主义时间分类:使用循环神经网络标记未分段序列数据 (https://www.cs.toronto.edu/~graves/icml_2006.pdf)

注意

为了使用 CuDNN,必须满足以下条件:targets 必须采用拼接格式,所有 input_lengths 必须等于 Tblank=0blank=0target_lengths 256\leq 256,整型参数的 dtype 必须是 torch.int32

常规实现使用(在 PyTorch 中更常见)的 torch.long dtype。

注意

在使用 CUDA 后端和 CuDNN 的某些情况下,此算子可能会选择非确定性算法以提高性能。如果这是不可取的,您可以通过设置 torch.backends.cudnn.deterministic = True 来尝试使操作具有确定性(可能会牺牲性能)。背景信息请参阅可复现性相关的注意事项。

文档

查阅 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获取问题解答

查看资源