快捷方式

KVCache

class torchtune.modules.KVCache(batch_size: int, max_seq_len: int, num_heads: int, head_dim: int, dtype: dtype)[源代码]

包含用于在推理期间缓存过去键和值的 kv 缓存的独立 nn.Module

参数:
  • batch_size (int) – 模型将使用的批次大小

  • max_seq_len (int) – 模型将使用的最大序列长度

  • num_heads (int) – 头的数量。我们使用 num_heads 而不是 num_kv_heads,因为缓存是在我们将键和值张量扩展为与查询张量具有相同形状后创建的。有关更多详细信息,请参阅 attention.py

  • head_dim (int) – 每个注意力头的嵌入维度

  • dtype (torch.dpython:type) – 缓存的数据类型

reset() None[源代码]

将缓存重置为零。

update(k_val: Tensor, v_val: Tensor) Tuple[Tensor, Tensor][源代码]

使用新的 k_valv_val 更新 KV 缓存并返回更新后的缓存。

注意

更新 KV 缓存时,假设后续更新应更新连续序列位置中的键值位置。如果要更新已经填充的缓存值,请使用 .reset(),它会将缓存重置到第零个位置。

示例

>>> cache = KVCache(batch_size=2, max_seq_len=16, num_heads=4, head_dim=32, dtype=torch.bfloat16)
>>> keys, values = torch.ones((2, 4, 8, 32)), torch.ones((2, 4, 8, 32))
>>> cache.update(keys, values)
>>> # now positions 0 through 7 are filled
>>> cache.size
>>> 8
>>> keys, values = torch.ones((2, 4, 1, 32)), torch.ones((2, 4, 1, 32))
>>> cache.update(keys, values)
>>> # this will fill at position 8
>>> cache.size
>>> 9
参数:
  • k_val (torch.Tensor) – 形状为 [B, H, S, D] 的当前键张量

  • v_val (torch.Tensor) – 形状为 [B, H, S, D] 的当前值张量

返回值:

分别更新的键和值缓存张量。

返回类型:

Tuple[torch.Tensor, torch.Tensor]

引发:
  • ValueError – 如果 k_val 的序列长度超过最大缓存序列长度。

  • ValueError – 如果新键(或值)张量的批次大小大于缓存设置期间使用的批次大小。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源