快捷方式

WandaSparsifier

class torchao.sparsity.WandaSparsifier(sparsity_level: float = 0.5, semi_structured_block_size: Optional[int] = None)[source]

Wanda 稀疏器

Wanda(通过权重和激活进行剪枝),在 https://arxiv.org/abs/2306.11695 中提出,是一种激活感知剪枝方法。该稀疏器基于输入激活范数和权重幅度的乘积来移除权重。

此稀疏器由三个变量控制:1. sparsity_level 定义了被置零的稀疏块的数量;

参数:
  • sparsity_level – 目标稀疏度;

  • model – 要稀疏化的模型;

prepare(model: Module, config: List[Dict]) None[source]

通过添加参数化来准备模型。

注意

The model is modified inplace. If you need to preserve the original
model, use copy.deepcopy.
squash_mask(params_to_keep: Optional[Tuple[str, ...]] = None, params_to_keep_per_layer: Optional[Dict[str, Tuple[str, ...]]] = None, *args, **kwargs)[source]

将稀疏掩码压缩到适当的张量中。

如果设置了 params_to_keepparams_to_keep_per_layer 中的任何一个,则模块将附加一个 sparse_params 字典。

参数:
  • params_to_keep – 要保存在模块中的键的列表,或表示将保存稀疏参数的模块和键的字典

  • params_to_keep_per_layer – 用于指定应为特定层保存的参数的字典。字典中的键应为模块 fqn,而值应为字符串列表,其中包含要在 sparse_params 中保存的变量的名称

示例

>>> # xdoctest: +SKIP("locals are undefined")
>>> # Don't save any sparse params
>>> sparsifier.squash_mask()
>>> hasattr(model.submodule1, 'sparse_params')
False
>>> # Keep sparse params per layer
>>> sparsifier.squash_mask(
...     params_to_keep_per_layer={
...         'submodule1.linear1': ('foo', 'bar'),
...         'submodule2.linear42': ('baz',)
...     })
>>> print(model.submodule1.linear1.sparse_params)
{'foo': 42, 'bar': 24}
>>> print(model.submodule2.linear42.sparse_params)
{'baz': 0.1}
>>> # Keep sparse params for all layers
>>> sparsifier.squash_mask(params_to_keep=('foo', 'bar'))
>>> print(model.submodule1.linear1.sparse_params)
{'foo': 42, 'bar': 24}
>>> print(model.submodule2.linear42.sparse_params)
{'foo': 42, 'bar': 24}
>>> # Keep some sparse params for all layers, and specific ones for
>>> # some other layers
>>> sparsifier.squash_mask(
...     params_to_keep=('foo', 'bar'),
...     params_to_keep_per_layer={
...         'submodule2.linear42': ('baz',)
...     })
>>> print(model.submodule1.linear1.sparse_params)
{'foo': 42, 'bar': 24}
>>> print(model.submodule2.linear42.sparse_params)
{'foo': 42, 'bar': 24, 'baz': 0.1}
update_mask(module: Module, tensor_name: str, sparsity_level: float, **kwargs) None[source]

WandaSparsifier 的剪枝函数

首先在 act_per_input 变量中检索激活统计信息。然后计算 Wanda 剪枝指标。然后通过比较当前层整体的这个指标来剪枝权重矩阵。

文档

访问 PyTorch 的综合开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源