local_kv_cache¶
- torchtune.modules.common_utils.local_kv_cache(model: Module, *, batch_size: int, device: device, dtype: dtype, encoder_max_seq_len: Optional[int] = None, decoder_max_seq_len: Optional[int] = None) Generator[None, None, None][source]¶
- 这个上下文管理器在给定的模型上临时启用 KV 缓存,该模型尚未设置 KV 缓存。在此上下文管理器中使用模型进行的所有前向传播都将使用 KV 缓存。 - 进入上下文管理器时,将使用给定的 - batch_size、- dtype和- max_seq_len设置 KV 缓存,并在退出时删除。- 示例 - >>> from torchtune.models.llama3_2 import llama3_2_1b >>> from torchtune.modules import local_kv_cache >>> import torch >>> model = llama3_2_1b() >>> print(model.caches_are_setup()) False >>> print(model.caches_are_enabled()) False >>> print(model.layers[0].attn.kv_cache) None >>> # entering cacheing mode >>> with local_kv_cache(model, >>> batch_size=1, >>> device=torch.device("cpu"), >>> dtype=torch.float32, >>> decoder_max_seq_len=1024): >>> print(model.caches_are_setup()) True >>> print(model.caches_are_enabled()) True >>> print(model.layers[0].attn.kv_cache) KVCache() >>> # exited cacheing mode >>> print(model.caches_are_setup()) False >>> print(model.caches_are_enabled()) False >>> print(model.layers[0].attn.kv_cache) None - 参数:
- model (nn.Module) – 为其启用 KV 缓存的模型。 
- batch_size (int) – 缓存的批大小。 
- device (torch.device) – 设置缓存的设备。这应该与模型所在的设备相同。 
- dtype (torch.dpython:type) – 缓存的数据类型。 
- encoder_max_seq_len (Optional[int]) – 最大编码器缓存序列长度。 
- decoder_max_seq_len (Optional[int]) – 最大解码器缓存序列长度。 
 
- 生成:
- None – 返回控制权给调用者,并在给定模型上设置和启用 KV 缓存。 
- 抛出异常:
- ValueError – 如果模型已设置缓存。您可以使用 - delete_kv_caches()来删除现有缓存。