带分块输出损失的交叉熵¶
- class torchtune.modules.loss.CEWithChunkedOutputLoss(num_output_chunks: int = 8, ignore_index: int = - 100)[源代码]¶
具有分块输出的交叉熵,通过一次仅上采一个分块来节省内存。
当模型使用 bf16 训练时,在运行 CE 之前,我们必须将其上采到 fp32,以获得更好的准确性和稳定性。当上采发生时,内存使用量会翻倍。像 llama3 这样的模型具有很大的词汇量,因此具有形状为
(bsz, num_tokens, vocab_size)
的大型输出张量。如果我们对标记级别进行分块,则仍然可以正常计算交叉熵,但一次仅上采一个分块可以节省大量内存。为了获得更好的性能,CE 和上采必须一起编译。使用此类时,我们建议仅对
compute_cross_entropy
方法使用torch.compile()
。如果编译整个类,则不会实现分块带来的优势。有关更多详细信息,请参考:https://github.com/pytorch/torchtune/pull/1390
- forward(logits: List[Tensor], labels: Tensor) Tensor [源代码]¶
- 参数:
logits (List[torch.Tensor]) – 长度为
self.num_output_chunks
的分块 logits 列表,其中每个分块的形状为(batch_size, num_tokens / num_output_chunks, vocab_size)
。labels (torch.Tensor) – 形状为
(batch_size, num_tokens)
的真实标签。
- 返回值:
形状为 (1,) 的交叉熵损失。
- 返回类型:
示例
>>> loss_fn = ChunkedCrossEntropyLoss() >>> >>> h = torch.tensor([bsz, num_tokens, dim]) >>> output_chunks = [model.output(chunk) for chunk in h.chunk(num_chunks, dim=1)] >>> >>> labels = torch.tensor([bsz, num_tokens]) >>> loss = loss_fn(output_chunks, labels)