快捷方式

带分块输出损失的交叉熵

class torchtune.modules.loss.CEWithChunkedOutputLoss(num_output_chunks: int = 8, ignore_index: int = - 100)[源代码]

具有分块输出的交叉熵,通过一次仅上采一个分块来节省内存。

当模型使用 bf16 训练时,在运行 CE 之前,我们必须将其上采到 fp32,以获得更好的准确性和稳定性。当上采发生时,内存使用量会翻倍。像 llama3 这样的模型具有很大的词汇量,因此具有形状为 (bsz, num_tokens, vocab_size) 的大型输出张量。如果我们对标记级别进行分块,则仍然可以正常计算交叉熵,但一次仅上采一个分块可以节省大量内存。

为了获得更好的性能,CE 和上采必须一起编译。使用此类时,我们建议仅对 compute_cross_entropy 方法使用 torch.compile()。如果编译整个类,则不会实现分块带来的优势。

有关更多详细信息,请参考:https://github.com/pytorch/torchtune/pull/1390

compute_cross_entropy(logits: Tensor, labels: Tensor) Tensor[源代码]

将 logits 上采到 fp32 并计算交叉熵损失。

forward(logits: List[Tensor], labels: Tensor) Tensor[源代码]
参数:
  • logits (List[torch.Tensor]) – 长度为 self.num_output_chunks 的分块 logits 列表,其中每个分块的形状为 (batch_size, num_tokens / num_output_chunks, vocab_size)

  • labels (torch.Tensor) – 形状为 (batch_size, num_tokens) 的真实标签。

返回值:

形状为 (1,) 的交叉熵损失。

返回类型:

torch.Tensor

示例

>>> loss_fn = ChunkedCrossEntropyLoss()
>>>
>>> h = torch.tensor([bsz, num_tokens, dim])
>>> output_chunks = [model.output(chunk) for chunk in h.chunk(num_chunks, dim=1)]
>>>
>>> labels = torch.tensor([bsz, num_tokens])
>>> loss = loss_fn(output_chunks, labels)

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源