CEWithChunkedOutputLoss¶
- class torchtune.modules.loss.CEWithChunkedOutputLoss(num_output_chunks: int = 8, ignore_index: int = - 100)[source]¶
带有分块输出的交叉熵损失,通过每次仅向上转换一个分块来节省内存。
当模型使用 bf16 训练时,在运行 CE 之前,我们需要将其向上转换为 fp32 以获得更好的准确性和稳定性。向上转换发生时,内存使用量会翻倍。像 llama3 这样的模型词汇量很大,因此具有形状为
(bsz, num_tokens, vocab_size)
的大型输出张量。如果我们对 token 进行分块,仍然可以正常计算交叉熵,但每次只向上转换一个分块可以节省大量内存。CE 和向上转换必须一起编译以获得更好的性能。使用此类时,我们建议仅对方法
torch.compile()
进行编译。如果您编译整个类,则不会实现分块带来的收益。更多详细信息,请参考:https://github.com/pytorch/torchtune/pull/1390
- compute_cross_entropy(logits: Tensor, labels: Tensor, normalize: bool = True) Tensor [source]¶
将 logits 向上转换为 fp32 并计算交叉熵损失。
- forward(logits: List[Tensor], labels: Tensor) Tensor [source]¶
- 参数:
logits (List[torch.Tensor]) – 分块 logits 的列表,长度为
self.num_output_chunks
,其中每个分块的形状为(batch_size, num_tokens / num_output_chunks, vocab_size)
。labels (torch.Tensor) – 形状为
(batch_size, num_tokens)
的真实标签。
- 返回:
形状为 (1,) 的交叉熵损失。
- 返回类型:
示例
>>> loss_fn = ChunkedCrossEntropyLoss() >>> >>> h = torch.tensor([bsz, num_tokens, dim]) >>> output_chunks = [model.output(chunk) for chunk in h.chunk(num_chunks, dim=1)] >>> >>> labels = torch.tensor([bsz, num_tokens]) >>> loss = loss_fn(output_chunks, labels)