快捷方式

local_kv_cache

torchtune.modules.common_utils.local_kv_cache(model: Module, *, batch_size: int, device: device, dtype: dtype, encoder_max_seq_len: Optional[int] = None, decoder_max_seq_len: Optional[int] = None) Generator[None, None, None][source]

这个上下文管理器在给定的模型上临时启用 KV 缓存,该模型尚未设置 KV 缓存。在此上下文管理器中使用模型进行的所有前向传播都将使用 KV 缓存。

进入上下文管理器时,将使用给定的 batch_sizedtypemax_seq_len 设置 KV 缓存,并在退出时删除。

示例

>>> from torchtune.models.llama3_2 import llama3_2_1b
>>> from torchtune.modules import local_kv_cache
>>> import torch
>>> model = llama3_2_1b()
>>> print(model.caches_are_setup())
False
>>> print(model.caches_are_enabled())
False
>>> print(model.layers[0].attn.kv_cache)
None
>>> # entering cacheing mode
>>> with local_kv_cache(model,
>>>                     batch_size=1,
>>>                     device=torch.device("cpu"),
>>>                     dtype=torch.float32,
>>>                     decoder_max_seq_len=1024):
>>>     print(model.caches_are_setup())
True
>>>     print(model.caches_are_enabled())
True
>>>     print(model.layers[0].attn.kv_cache)
KVCache()
>>> # exited cacheing mode
>>> print(model.caches_are_setup())
False
>>> print(model.caches_are_enabled())
False
>>> print(model.layers[0].attn.kv_cache)
None
参数:
  • model (nn.Module) – 为其启用 KV 缓存的模型。

  • batch_size (int) – 缓存的批大小。

  • device (torch.device) – 设置缓存的设备。这应该与模型所在的设备相同。

  • dtype (torch.dpython:type) – 缓存的数据类型。

  • encoder_max_seq_len (Optional[int]) – 最大编码器缓存序列长度。

  • decoder_max_seq_len (Optional[int]) – 最大解码器缓存序列长度。

生成:

None – 返回控制权给调用者,并在给定模型上设置和启用 KV 缓存。

抛出异常:

ValueError – 如果模型已设置缓存。您可以使用 delete_kv_caches() 来删除现有缓存。

文档

访问全面的 PyTorch 开发者文档

查看文档

教程

获取适合初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源