• 文档 >
  • 导出 tensordict 模块
快捷方式

导出 tensordict 模块

作者: Vincent Moens

前提条件

建议先阅读 TensorDictModule 教程,以便充分理解本教程。

一旦使用 tensordict.nn 编写模块后,通常需要隔离计算图并导出该图。 这样做的目的可能是为了在硬件(例如,机器人、无人机、边缘设备)上执行模型,或者完全消除对 tensordict 的依赖。

PyTorch 提供了多种导出模块的方法,包括 onnxtorch.export,两者都与 tensordict 兼容。

在本简短教程中,我们将了解如何使用 torch.export 来隔离模型的计算图。 torch.onnx 的支持遵循相同的逻辑。

主要学习内容

  • 在没有 TensorDict 输入的情况下执行 tensordict.nn 模块;

  • 选择模型的输出;

  • 处理随机模型;

  • 使用 torch.export 导出此类模型;

  • 将模型保存到文件;

  • 隔离 pytorch 模型;

import time

import torch
from tensordict.nn import (
    InteractionType,
    NormalParamExtractor,
    ProbabilisticTensorDictModule as Prob,
    set_interaction_type,
    TensorDictModule as Mod,
    TensorDictSequential as Seq,
)
from torch import distributions as dists, nn

设计模型

在许多应用中,使用随机模型非常有用,即输出变量不是确定性定义的模型,而是根据参数分布进行采样的模型。 例如,生成式 AI 模型在提供相同输入时通常会生成不同的输出,因为它们根据分布对输出进行采样,而分布的参数由输入定义。

tensordict 库通过 ProbabilisticTensorDictModule 类来处理这个问题。 这个原语是使用分布类(在本例中为 Normal)和将在执行时用于构建该分布的输入键的指示符来构建的。

因此,我们正在构建的网络将是三个主要组件的组合

  • 将输入映射到潜在参数的网络;

  • 一个 tensordict.nn.NormalParamExtractor 模块,将输入拆分为位置 “loc”“scale” 参数,以便传递给 Normal 分布;

  • 分布构造器模块。

model = Seq(
    # 1. A small network for embedding
    Mod(nn.Linear(3, 4), in_keys=["x"], out_keys=["hidden"]),
    Mod(nn.ReLU(), in_keys=["hidden"], out_keys=["hidden"]),
    Mod(nn.Linear(4, 4), in_keys=["hidden"], out_keys=["latent"]),
    # 2. Extracting params
    Mod(NormalParamExtractor(), in_keys=["latent"], out_keys=["loc", "scale"]),
    # 3. Probabilistic module
    Prob(
        in_keys=["loc", "scale"],
        out_keys=["sample"],
        distribution_class=dists.Normal,
    ),
)

让我们运行这个模型,看看输出是什么样子

x = torch.randn(1, 3)
print(model(x=x))
(tensor([[0.0000, 0.2604, 0.0000, 0.0000]], grad_fn=<ReluBackward0>), tensor([[-0.1580, -0.5222, -0.3319,  0.5519]], grad_fn=<AddmmBackward0>), tensor([[-0.1580, -0.5222]], grad_fn=<SplitBackward0>), tensor([[0.8046, 1.3804]], grad_fn=<ClampMinBackward0>), tensor([[-0.1580, -0.5222]], grad_fn=<SplitBackward0>))

正如预期的那样,使用张量输入运行模型会返回与模块的输出键一样多的张量! 对于大型模型,这可能会非常烦人且浪费。 稍后,我们将看到如何限制模型的输出数量以解决此问题。

torch.exportTensorDictModule 一起使用

现在我们已经成功构建了模型,我们希望将模型的计算图提取到一个独立于 tensordict 的单个对象中。 torch.export 是一个 PyTorch 模块,专门用于隔离模块的图并以标准化方式表示它。 它的主要入口点是 export(),它返回一个 ExportedProgram 对象。 反过来,此对象有几个我们将在下面探讨的感兴趣的属性:一个 graph_module,它表示由 export 捕获的 FX 图,一个 graph_signature,其中包含图的输入、输出等,最后是一个 module(),它返回一个可调用对象,该对象可以代替原始模块使用。

虽然我们的模块接受 args 和 kwargs,但我们将重点关注其与 kwargs 的用法,因为这更清晰。

from torch.export import export

model_export = export(model, args=(), kwargs={"x": x})

让我们看一下模块

print("module:", model_export.module())
module: GraphModule(
  (module): Module(
    (0): Module(
      (module): Module()
    )
    (2): Module(
      (module): Module()
    )
  )
)



def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
    module_0_module_weight = getattr(self.module, "0").module.weight
    module_0_module_bias = getattr(self.module, "0").module.bias
    module_2_module_weight = getattr(self.module, "2").module.weight
    module_2_module_bias = getattr(self.module, "2").module.bias
    linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias);  x = module_0_module_weight = module_0_module_bias = None
    relu = torch.ops.aten.relu.default(linear);  linear = None
    linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias);  module_2_module_weight = module_2_module_bias = None
    split = torch.ops.aten.split.Tensor(linear_1, 2, -1)
    getitem = split[0]
    getitem_1 = split[1];  split = None
    add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335);  getitem_1 = None
    softplus = torch.ops.aten.softplus.default(add);  add = None
    add_1 = torch.ops.aten.add.Tensor(softplus, 0.01);  softplus = None
    clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001);  add_1 = None
    broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]);  getitem = clamp_min = None
    getitem_2 = broadcast_tensors[0]
    getitem_3 = broadcast_tensors[1];  broadcast_tensors = None
    return pytree.tree_unflatten((relu, linear_1, getitem_2, getitem_3, getitem_2), self._out_spec)

# To see more debug info, please use `graph_module.print_readable()`

此模块可以像我们的原始模块一样运行(开销更低)

t0 = time.time()
model(x=x)
print(f"Time for TDModule: {(time.time()-t0)*1e6: 4.2f} micro-seconds")
exported = model_export.module()

# Exported version
t0 = time.time()
exported(x=x)
print(f"Time for exported module: {(time.time()-t0)*1e6: 4.2f} micro-seconds")
Time for TDModule:  469.45 micro-seconds
Time for exported module:  340.70 micro-seconds

以及 FX 图

print("fx graph:", model_export.graph_module.print_readable())
class GraphModule(torch.nn.Module):
    def forward(self, p_l__args___0_module_0_module_weight: "f32[4, 3]", p_l__args___0_module_0_module_bias: "f32[4]", p_l__args___0_module_2_module_weight: "f32[4, 4]", p_l__args___0_module_2_module_bias: "f32[4]", x: "f32[1, 3]"):
         # File: /pytorch/tensordict/tensordict/nn/common.py:1010 in _call_module, code: out = self.module(*tensors, **kwargs)
        linear: "f32[1, 4]" = torch.ops.aten.linear.default(x, p_l__args___0_module_0_module_weight, p_l__args___0_module_0_module_bias);  x = p_l__args___0_module_0_module_weight = p_l__args___0_module_0_module_bias = None
        relu: "f32[1, 4]" = torch.ops.aten.relu.default(linear);  linear = None
        linear_1: "f32[1, 4]" = torch.ops.aten.linear.default(relu, p_l__args___0_module_2_module_weight, p_l__args___0_module_2_module_bias);  p_l__args___0_module_2_module_weight = p_l__args___0_module_2_module_bias = None

         # File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:129 in forward, code: loc, scale = tensor.chunk(2, -1)
        split = torch.ops.aten.split.Tensor(linear_1, 2, -1)
        getitem: "f32[1, 2]" = split[0]
        getitem_1: "f32[1, 2]" = split[1];  split = None

         # File: /pytorch/tensordict/tensordict/nn/utils.py:68 in forward, code: return torch.nn.functional.softplus(x + self.bias) + self.min_val
        add: "f32[1, 2]" = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335);  getitem_1 = None
        softplus: "f32[1, 2]" = torch.ops.aten.softplus.default(add);  add = None
        add_1: "f32[1, 2]" = torch.ops.aten.add.Tensor(softplus, 0.01);  softplus = None

         # File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:130 in forward, code: scale = self.scale_mapping(scale).clamp_min(self.scale_lb)
        clamp_min: "f32[1, 2]" = torch.ops.aten.clamp_min.default(add_1, 0.0001);  add_1 = None

         # File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/distributions/utils.py:55 in broadcast_all, code: return torch.broadcast_tensors(*values)
        broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]);  getitem = clamp_min = None
        getitem_2: "f32[1, 2]" = broadcast_tensors[0]
        getitem_3: "f32[1, 2]" = broadcast_tensors[1];  broadcast_tensors = None
        return (relu, linear_1, getitem_2, getitem_3, getitem_2)

fx graph: class GraphModule(torch.nn.Module):
    def forward(self, p_l__args___0_module_0_module_weight: "f32[4, 3]", p_l__args___0_module_0_module_bias: "f32[4]", p_l__args___0_module_2_module_weight: "f32[4, 4]", p_l__args___0_module_2_module_bias: "f32[4]", x: "f32[1, 3]"):
         # File: /pytorch/tensordict/tensordict/nn/common.py:1010 in _call_module, code: out = self.module(*tensors, **kwargs)
        linear: "f32[1, 4]" = torch.ops.aten.linear.default(x, p_l__args___0_module_0_module_weight, p_l__args___0_module_0_module_bias);  x = p_l__args___0_module_0_module_weight = p_l__args___0_module_0_module_bias = None
        relu: "f32[1, 4]" = torch.ops.aten.relu.default(linear);  linear = None
        linear_1: "f32[1, 4]" = torch.ops.aten.linear.default(relu, p_l__args___0_module_2_module_weight, p_l__args___0_module_2_module_bias);  p_l__args___0_module_2_module_weight = p_l__args___0_module_2_module_bias = None

         # File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:129 in forward, code: loc, scale = tensor.chunk(2, -1)
        split = torch.ops.aten.split.Tensor(linear_1, 2, -1)
        getitem: "f32[1, 2]" = split[0]
        getitem_1: "f32[1, 2]" = split[1];  split = None

         # File: /pytorch/tensordict/tensordict/nn/utils.py:68 in forward, code: return torch.nn.functional.softplus(x + self.bias) + self.min_val
        add: "f32[1, 2]" = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335);  getitem_1 = None
        softplus: "f32[1, 2]" = torch.ops.aten.softplus.default(add);  add = None
        add_1: "f32[1, 2]" = torch.ops.aten.add.Tensor(softplus, 0.01);  softplus = None

         # File: /pytorch/tensordict/tensordict/nn/distributions/continuous.py:130 in forward, code: scale = self.scale_mapping(scale).clamp_min(self.scale_lb)
        clamp_min: "f32[1, 2]" = torch.ops.aten.clamp_min.default(add_1, 0.0001);  add_1 = None

         # File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/distributions/utils.py:55 in broadcast_all, code: return torch.broadcast_tensors(*values)
        broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]);  getitem = clamp_min = None
        getitem_2: "f32[1, 2]" = broadcast_tensors[0]
        getitem_3: "f32[1, 2]" = broadcast_tensors[1];  broadcast_tensors = None
        return (relu, linear_1, getitem_2, getitem_3, getitem_2)

使用嵌套键

嵌套键是 tensordict 库的核心功能,因此能够导出读取和写入嵌套条目的模块是一项需要支持的重要功能。 由于关键字参数必须是常规字符串,因此 dispatch 无法直接使用它们。 相反,dispatch 将解包用常规下划线 (“_”) 连接的嵌套键,如下例所示。

model_nested = Seq(
    Mod(lambda x: x + 1, in_keys=[("some", "key")], out_keys=["hidden"]),
    Mod(lambda x: x - 1, in_keys=["hidden"], out_keys=[("some", "output")]),
).select_out_keys(("some", "output"))

model_nested_export = export(model_nested, args=(), kwargs={"some_key": x})
print("exported module with nested input:", model_nested_export.module())
exported module with nested input: GraphModule()



def forward(self, some_key):
    some_key, = fx_pytree.tree_flatten_spec(([], {'some_key':some_key}), self._in_spec)
    add = torch.ops.aten.add.Tensor(some_key, 1);  some_key = None
    sub = torch.ops.aten.sub.Tensor(add, 1);  add = None
    return pytree.tree_unflatten((sub,), self._out_spec)

# To see more debug info, please use `graph_module.print_readable()`

请注意,module() 返回的可调用对象是一个纯 python 可调用对象,可以使用 compile() 对其进行编译。

保存导出的模块

torch.export 有自己的序列化协议,save()load()。 按照惯例,应使用 “.pt2” 扩展名

>>> torch.export.save(model_export, "model.pt2")

选择输出

回想一下,tensordict.nn 会将每个中间值保留在输出中,除非用户明确要求仅保留特定值。 在训练期间,这可能非常有用:可以轻松记录图的中间值,或将其用于其他目的(例如,基于其保存的参数重建分布,而不是保存 Distribution 对象本身)。 还可以认为,在训练期间,注册中间值对内存的影响可以忽略不计,因为它们是 torch.autograd 用于计算参数梯度的计算图的一部分。

但是,在推理期间,我们最有可能只对模型的最终样本感兴趣。 因为我们希望提取模型以用于独立于 tensordict 库的用途,所以隔离我们想要的唯一输出是有意义的。 为此,我们有几个选项

  1. 使用 selected_out_keys 关键字参数构建 TensorDictSequential(),这将导致在调用模块期间选择所需的条目;

  2. 使用 select_out_keys() 方法,该方法将就地修改 out_keys 属性(可以通过 reset_out_keys() 恢复)。

  3. 将现有实例包装在 TensorDictSequential() 中,它将过滤掉不需要的键

    >>> module_filtered = Seq(module, selected_out_keys=["sample"])
    

让我们在选择模型的输出键后测试模型。 当提供 x 输入时,我们希望我们的模型输出一个与分布样本相对应的单个张量

model.select_out_keys("sample")
print(model(x=x))
tensor([[-0.1580, -0.5222]], grad_fn=<SplitBackward0>)

我们看到输出现在是一个与分布样本相对应的单个张量。 我们可以由此创建一个新的导出图。 它的计算图应该被简化

model_export = export(model, args=(), kwargs={"x": x})
print("module:", model_export.module())
module: GraphModule(
  (module): Module(
    (0): Module(
      (module): Module()
    )
    (2): Module(
      (module): Module()
    )
  )
)



def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
    module_0_module_weight = getattr(self.module, "0").module.weight
    module_0_module_bias = getattr(self.module, "0").module.bias
    module_2_module_weight = getattr(self.module, "2").module.weight
    module_2_module_bias = getattr(self.module, "2").module.bias
    linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias);  x = module_0_module_weight = module_0_module_bias = None
    relu = torch.ops.aten.relu.default(linear);  linear = None
    linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias);  relu = module_2_module_weight = module_2_module_bias = None
    split = torch.ops.aten.split.Tensor(linear_1, 2, -1);  linear_1 = None
    getitem = split[0]
    getitem_1 = split[1];  split = None
    add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335);  getitem_1 = None
    softplus = torch.ops.aten.softplus.default(add);  add = None
    add_1 = torch.ops.aten.add.Tensor(softplus, 0.01);  softplus = None
    clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001);  add_1 = None
    broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]);  getitem = clamp_min = None
    getitem_2 = broadcast_tensors[0];  broadcast_tensors = None
    return pytree.tree_unflatten((getitem_2,), self._out_spec)

# To see more debug info, please use `graph_module.print_readable()`

控制采样策略

我们尚未讨论 ProbabilisticTensorDictModule 如何从分布中采样。 通过采样,我们指的是根据特定策略获取分布定义的空间内的值。 例如,人们可能希望在训练期间获得随机样本,但在推理时获得确定性样本(例如,均值或众数)。 为了解决这个问题,tensordict 利用 set_interaction_type 装饰器和上下文管理器,它们接受 InteractionType 枚举输入

>>> with set_interaction_type(InteractionType.MEAN):
...     output = module(input)  # takes the input of the distribution, if ProbabilisticTensorDictModule is invoked

默认的 InteractionTypeInteractionType.DETERMINISTIC,如果未直接实现,则对于实数域的分布是均值,对于离散域的分布是众数。 可以使用 ProbabilisticTensorDictModuledefault_interaction_type 关键字参数更改此默认值。

让我们回顾一下:为了控制我们网络的采样策略,我们可以在构造函数中定义默认采样策略,或者在运行时通过 set_interaction_type 上下文管理器覆盖它。

正如我们可以从以下示例中看到的那样,torch.export 正确响应了装饰器的用法:如果我们要求随机样本,则输出与我们要求均值时的输出不同

with set_interaction_type(InteractionType.RANDOM):
    model_export = export(model, args=(), kwargs={"x": x})
    print(model_export.module())

with set_interaction_type(InteractionType.MEAN):
    model_export = export(model, args=(), kwargs={"x": x})
    print(model_export.module())
GraphModule(
  (module): Module(
    (0): Module(
      (module): Module()
    )
    (2): Module(
      (module): Module()
    )
  )
)



def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
    module_0_module_weight = getattr(self.module, "0").module.weight
    module_0_module_bias = getattr(self.module, "0").module.bias
    module_2_module_weight = getattr(self.module, "2").module.weight
    module_2_module_bias = getattr(self.module, "2").module.bias
    linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias);  x = module_0_module_weight = module_0_module_bias = None
    relu = torch.ops.aten.relu.default(linear);  linear = None
    linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias);  relu = module_2_module_weight = module_2_module_bias = None
    split = torch.ops.aten.split.Tensor(linear_1, 2, -1);  linear_1 = None
    getitem = split[0]
    getitem_1 = split[1];  split = None
    add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335);  getitem_1 = None
    softplus = torch.ops.aten.softplus.default(add);  add = None
    add_1 = torch.ops.aten.add.Tensor(softplus, 0.01);  softplus = None
    clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001);  add_1 = None
    broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]);  getitem = clamp_min = None
    getitem_2 = broadcast_tensors[0]
    getitem_3 = broadcast_tensors[1];  broadcast_tensors = None
    empty = torch.ops.aten.empty.memory_format([1, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    normal_functional = torch.ops.aten.normal_functional.default(empty);  empty = None
    mul = torch.ops.aten.mul.Tensor(normal_functional, getitem_3);  normal_functional = getitem_3 = None
    add_2 = torch.ops.aten.add.Tensor(getitem_2, mul);  getitem_2 = mul = None
    return pytree.tree_unflatten((add_2,), self._out_spec)

# To see more debug info, please use `graph_module.print_readable()`
GraphModule(
  (module): Module(
    (0): Module(
      (module): Module()
    )
    (2): Module(
      (module): Module()
    )
  )
)



def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
    module_0_module_weight = getattr(self.module, "0").module.weight
    module_0_module_bias = getattr(self.module, "0").module.bias
    module_2_module_weight = getattr(self.module, "2").module.weight
    module_2_module_bias = getattr(self.module, "2").module.bias
    linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias);  x = module_0_module_weight = module_0_module_bias = None
    relu = torch.ops.aten.relu.default(linear);  linear = None
    linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias);  relu = module_2_module_weight = module_2_module_bias = None
    split = torch.ops.aten.split.Tensor(linear_1, 2, -1);  linear_1 = None
    getitem = split[0]
    getitem_1 = split[1];  split = None
    add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335);  getitem_1 = None
    softplus = torch.ops.aten.softplus.default(add);  add = None
    add_1 = torch.ops.aten.add.Tensor(softplus, 0.01);  softplus = None
    clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001);  add_1 = None
    broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]);  getitem = clamp_min = None
    getitem_2 = broadcast_tensors[0];  broadcast_tensors = None
    return pytree.tree_unflatten((getitem_2,), self._out_spec)

# To see more debug info, please use `graph_module.print_readable()`

这就是使用 torch.export 所需了解的全部内容。 请参阅 官方文档 以获取更多信息。

后续步骤和进一步阅读

  • 查看 torch.export 教程,可在此处 获取

  • ONNX 支持:查看 ONNX 教程 以了解有关此功能的更多信息。 导出到 ONNX 与此处解释的 torch.export 非常相似。

  • 为了在没有 python 环境的服务器上部署 PyTorch 代码,请查看 AOTInductor 文档。

脚本总运行时间: (0 分 1.695 秒)

由 Sphinx-Gallery 生成的图库

文档

访问 PyTorch 的综合开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源