disable_kv_cache¶
- torchtune.modules.common_utils.disable_kv_cache(model: Module) Generator[None, None, None] [源]¶
这个上下文管理器暂时禁用给定模型上的 KV-cache,该模型必须已经设置了 KV-cache。在此上下文管理器中使用模型进行的所有前向传播都不会使用 KV-cache。
进入上下文管理器时,KV-cache 将被禁用;退出时,KV-cache 将被重新启用,且不会被修改。
这在需要交替使用 KV-cache 和不使用 KV-cache 的模型调用场景中非常有用,无需每次都额外开销删除和重新设置缓存。
示例
>>> from torchtune.models.llama3_2 import llama3_2_1b >>> from torchtune.modules import disable_kv_cache >>> import torch >>> model = llama3_2_1b() >>> # setup caches >>> model.setup_caches(batch_size=1, >>> dtype=torch.float32, >>> decoder_max_seq_len=1024) >>> print(model.caches_are_setup()) True >>> print(model.caches_are_enabled()) True >>> print(model.layers[0].attn.kv_cache) KVCache() >>> # now temporarily disable caches >>> with disable_kv_cache(model): >>> print(model.caches_are_setup()) True >>> print(model.caches_are_enabled()) False >>> print(model.layers[0].attn.kv_cache) KVCache() >>> # caches are now re-enabled, and their state is untouched >>> print(model.caches_are_setup()) True >>> print(model.caches_are_enabled()) True >>> print(model.layers[0].attn.kv_cache) KVCache()
- 参数:
model (nn.Module) – 要禁用 KV-cache 的模型。
- 返回值:
None – 将控制权返回给调用方,同时在给定模型上禁用 KV-cache。
- 引发:
ValueError – 如果模型没有设置缓存。请先使用
setup_caches()
设置缓存。