ForwardKLWithChunkedOutputLoss¶
- class torchtune.modules.loss.ForwardKLWithChunkedOutputLoss(num_output_chunks: int = 8, ignore_index: int = - 100)[source]¶
具有分块输出的前向 KL,通过一次仅提升一个分块来节省内存。
由于模型使用 bf16 训练,在计算 KL 散度之前,我们必须将其提升到 fp32 以获得更好的准确性和稳定性。当提升发生时,内存使用量会加倍。像 llama3 这样的模型具有较大的词汇量,因此具有较大的输出结果 (bsz, num_tokens, vocab_size)。如果我们在 token 级别进行分块,则仍然可以正常计算交叉熵,但一次仅提升一个分块可以节省大量的内存。
- 参数:
- forward(student_logits: List[Tensor], teacher_logits: List[Tensor], labels: Tensor) Tensor [source]¶
- 参数:
student_logits (List[torch.Tensor]) – 来自学生模型的分块 logits 列表,长度为
self.num_output_chunks
,其中每个分块的形状为 (batch_size, num_tokens / num_output_chunks, vocab_size)。teacher_logits (List[torch.Tensor]) – 来自教师模型的分块 logits 列表,长度为
self.num_output_chunks
,其中每个分块的形状为 (batch_size, num_tokens / num_output_chunks, vocab_size)。labels (torch.Tensor) – 形状为 (batch_size, num_tokens) 的真实标签。
- 返回值:
形状为 (1,) 的 KL 散度损失。
- 返回类型:
示例
>>> loss_fn = ForwardKLWithChunkedOutputLoss() >>> >>> h = torch.tensor([bsz, num_tokens, dim]) >>> output_chunks = [model.output(chunk) for chunk in h.chunk(num_chunks, dim=1)] >>> teacher_chunks = [teacher_model.output(chunk) for chunk in h.chunk(num_chunks, dim=1)] >>> labels = torch.tensor([bsz, num_tokens]) >>> loss = loss_fn(output_chunks, teacher_chunks, labels)