快捷方式

CEWithChunkedOutputLoss

class torchtune.modules.loss.CEWithChunkedOutputLoss(num_output_chunks: int = 8, ignore_index: int = - 100)[source]

具有分块输出的交叉熵损失,通过一次仅向上转换一个块来节省内存。

每当模型使用 bf16 训练时,在运行 CE 之前,我们必须将其向上转换为 fp32,以获得更好的准确性和稳定性。当发生向上转换时,内存使用量会增加一倍。像 llama3 这样的模型具有很大的词汇量,因此具有形状为 (bsz, num_tokens, vocab_size) 的大型输出张量。如果我们按 token 级别进行分块,您仍然可以正常计算交叉熵,但一次只向上转换一个块可以节省大量内存。

CE 和向上转换必须编译在一起才能获得更好的性能。当使用此类时,我们建议仅对方法 compute_cross_entropy 使用 torch.compile()。如果您编译整个类,则无法实现分块带来的收益。

有关更多详细信息,请参阅:https://github.com/pytorch/torchtune/pull/1390

compute_cross_entropy(logits: Tensor, labels: Tensor, normalize: bool = True) Tensor[source]

将 logits 向上转换为 fp32 并计算交叉熵损失。

forward(logits: List[Tensor], labels: Tensor) Tensor[source]
参数:
  • logits (List[torch.Tensor]) – 长度为 self.num_output_chunks 的分块 logits 列表,其中每个块的形状为 (batch_size, num_tokens / num_output_chunks, vocab_size)

  • labels (torch.Tensor) – 形状为 (batch_size, num_tokens) 的真实标签。

返回:

形状为 (1,) 的交叉熵损失。

返回类型:

torch.Tensor

示例

>>> loss_fn = ChunkedCrossEntropyLoss()
>>>
>>> h = torch.tensor([bsz, num_tokens, dim])
>>> output_chunks = [model.output(chunk) for chunk in h.chunk(num_chunks, dim=1)]
>>>
>>> labels = torch.tensor([bsz, num_tokens])
>>> loss = loss_fn(output_chunks, labels)

文档

访问 PyTorch 的综合开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的问题解答

查看资源