KVCache¶
- class torchtune.modules.KVCache(batch_size: int, max_seq_len: int, num_kv_heads: int, head_dim: int, dtype: dtype)[source]¶
一个独立的
nn.Module
,包含一个 KV 缓存,用于在推理期间缓存过去的键和值。- 参数:
- update(k_val: Tensor, v_val: Tensor) Tuple[Tensor, Tensor] [source]¶
使用新的
k_val
,v_val
更新 KV 缓存并返回更新后的缓存。注意
更新 KV 缓存时,假定后续更新应更新连续序列位置中的键值位置。如果您希望更新已填充的缓存值,请使用
.reset()
,它会将缓存重置到第零个位置。示例
>>> cache = KVCache(batch_size=2, max_seq_len=16, num_kv_heads=4, head_dim=32, dtype=torch.bfloat16) >>> keys, values = torch.ones((2, 4, 8, 32)), torch.ones((2, 4, 8, 32)) >>> cache.update(keys, values) >>> # now positions 0 through 7 are filled >>> cache.size >>> 8 >>> keys, values = torch.ones((2, 4, 1, 32)), torch.ones((2, 4, 1, 32)) >>> cache.update(keys, values) >>> # this will fill at position 8 >>> cache.size >>> 9
- 参数:
k_val (torch.Tensor) – 当前键张量,形状为 [B, H, S, D]
v_val (torch.Tensor) – 当前值张量,形状为 [B, H, S, D]
- 返回:
分别返回更新后的键和值缓存张量。
- 返回类型:
Tuple[torch.Tensor, torch.Tensor]
- 抛出:
ValueError – 如果新键(或值)张量的批处理大小大于缓存设置期间使用的批处理大小。
注意
- 如果
k_val
的序列长度超过最大缓存序列长度,此函数将抛出AssertionError
。 is longer than the maximum cache sequence length.