快捷方式

ForwardKLWithChunkedOutputLoss

class torchtune.modules.loss.ForwardKLWithChunkedOutputLoss(num_output_chunks: int = 8, ignore_index: int = - 100)[source]

带有分块输出的前向 KL 损失,通过每次只向上转换一个块来节省内存。

由于模型使用 bf16 训练,在计算 KL 散度之前,我们必须将其向上转换为 fp32,以获得更好的精度和稳定性。当向上转换发生时,内存使用量会加倍。像 llama3 这样的模型具有很大的词汇表大小,因此输出结果也很大(bsz,num_tokens,vocab_size)。如果我们在 token 级别进行分块,您仍然可以正常计算交叉熵,但每次只向上转换一个块可以节省大量内存。

参数:
  • num_output_chunks (int) – 将输出分成的块数。每个块的形状为 (batch_size, num_tokens / num_output_chunks, vocab_size)。默认值: 8

  • ignore_index (int) – 指定一个被忽略且不参与输入梯度的目标值。损失将除以未被忽略的目标数量。默认值: -100

forward(student_logits: List[Tensor], teacher_logits: List[Tensor], labels: Tensor) Tensor[source]
参数:
  • student_logits (List[torch.Tensor]) – 来自学生模型的 chunked logit 列表,长度为 self.num_output_chunks,其中每个 chunk 的形状为 (batch_size, num_tokens / num_output_chunks, vocab_size)。

  • teacher_logits (List[torch.Tensor]) – 来自教师模型的 chunked logit 列表,长度为 self.num_output_chunks,其中每个 chunk 的形状为 (batch_size, num_tokens / num_output_chunks, vocab_size)。

  • labels (torch.Tensor) – 真实标签,形状为 (batch_size, num_tokens)。

返回值:

形状为 (1,) 的 KL 散度损失。

返回类型:

torch.Tensor

示例

>>> loss_fn = ForwardKLWithChunkedOutputLoss()
>>>
>>> h = torch.tensor([bsz, num_tokens, dim])
>>> output_chunks = [model.output(chunk) for chunk in h.chunk(num_chunks, dim=1)]
>>> teacher_chunks = [teacher_model.output(chunk) for chunk in h.chunk(num_chunks, dim=1)]
>>> labels = torch.tensor([bsz, num_tokens])
>>> loss = loss_fn(output_chunks, teacher_chunks, labels)

文档

查阅 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

寻找开发资源并获得问题解答

查看资源