快捷方式

CEWithChunkedOutputLoss

class torchtune.modules.loss.CEWithChunkedOutputLoss(num_output_chunks: int = 8, ignore_index: int = - 100)[source]

带有分块输出的交叉熵损失,通过每次仅向上转换一个分块来节省内存。

当模型使用 bf16 训练时,在运行 CE 之前,我们需要将其向上转换为 fp32 以获得更好的准确性和稳定性。向上转换发生时,内存使用量会翻倍。像 llama3 这样的模型词汇量很大,因此具有形状为 (bsz, num_tokens, vocab_size) 的大型输出张量。如果我们对 token 进行分块,仍然可以正常计算交叉熵,但每次只向上转换一个分块可以节省大量内存。

CE 和向上转换必须一起编译以获得更好的性能。使用此类时,我们建议仅对方法 torch.compile() 进行编译。如果您编译整个类,则不会实现分块带来的收益。

更多详细信息,请参考:https://github.com/pytorch/torchtune/pull/1390

compute_cross_entropy(logits: Tensor, labels: Tensor, normalize: bool = True) Tensor[source]

将 logits 向上转换为 fp32 并计算交叉熵损失。

forward(logits: List[Tensor], labels: Tensor) Tensor[source]
参数
  • logits (List[torch.Tensor]) – 分块 logits 的列表,长度为 self.num_output_chunks,其中每个分块的形状为 (batch_size, num_tokens / num_output_chunks, vocab_size)

  • labels (torch.Tensor) – 形状为 (batch_size, num_tokens) 的真实标签。

返回

形状为 (1,) 的交叉熵损失。

返回类型

torch.Tensor

示例

>>> loss_fn = ChunkedCrossEntropyLoss()
>>>
>>> h = torch.tensor([bsz, num_tokens, dim])
>>> output_chunks = [model.output(chunk) for chunk in h.chunk(num_chunks, dim=1)]
>>>
>>> labels = torch.tensor([bsz, num_tokens])
>>> loss = loss_fn(output_chunks, labels)

文档

获取 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获取问题解答

查看资源