torch.export¶
警告
此功能是正在积极开发中的原型,将来 WILL BE BREAKING CHANGES。
概述¶
torch.export.export()
接受任意 Python 可调用对象(torch.nn.Module
、函数或方法)并生成一个跟踪的图,该图仅以 Ahead-of-Time (AOT) 方式表示函数的张量计算,随后可以使用不同的输出执行或序列化。
import torch
from torch.export import export
class Mod(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
a = torch.sin(x)
b = torch.cos(y)
return a + b
example_args = (torch.randn(10, 10), torch.randn(10, 10))
exported_program: torch.export.ExportedProgram = export(
Mod(), args=example_args
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[10, 10], arg1_1: f32[10, 10]):
# code: a = torch.sin(x)
sin: f32[10, 10] = torch.ops.aten.sin.default(arg0_1);
# code: b = torch.cos(y)
cos: f32[10, 10] = torch.ops.aten.cos.default(arg1_1);
# code: return a + b
add: f32[10, 10] = torch.ops.aten.add.Tensor(sin, cos);
return (add,)
Graph signature: ExportGraphSignature(
parameters=[],
buffers=[],
user_inputs=['arg0_1', 'arg1_1'],
user_outputs=['add'],
inputs_to_parameters={},
inputs_to_buffers={},
buffers_to_mutate={},
backward_signature=None,
assertion_dep_token=None,
)
Range constraints: {}
torch.export
生成一个干净的中间表示 (IR),具有以下不变式。可以在 此处 找到有关 IR 的更多规范。
健全性: 保证它是原始程序的健全表示,并保持与原始程序相同的调用约定。
规范化: 图中没有 Python 语义。原始程序中的子模块被内联以形成一个完全扁平化的计算图。
图属性: 该图纯粹是函数式的,这意味着它不包含具有副作用的运算,例如变异或别名。它不会变异任何中间值、参数或缓冲区。
元数据: 该图包含在跟踪过程中捕获的元数据,例如用户代码的堆栈跟踪。
在幕后,torch.export
利用以下最新技术
TorchDynamo (torch._dynamo) 是一个内部 API,它使用称为 Frame Evaluation API 的 CPython 功能来安全地跟踪 PyTorch 图。这提供了一种极大地改进的图捕获体验,所需的重写次数要少得多才能完全跟踪 PyTorch 代码。
AOT Autograd 提供了一个功能化的 PyTorch 图,并确保该图被分解/降低到 ATen 运算符集。
Torch FX (torch.fx) 是该图的底层表示,允许灵活的基于 Python 的转换。
现有框架¶
torch.compile()
也利用与 torch.export
相同的 PT2 堆栈,但略有不同
JIT 与 AOT:
torch.compile()
是一个 JIT 编译器,它不打算用于生成部署之外的编译工件。部分与完整图捕获: 当
torch.compile()
遇到模型中无法跟踪的部分时,它将“图中断”并回退到在 eager Python 运行时中运行程序。相比之下,torch.export
旨在获取 PyTorch 模型的完整图表示,因此当遇到无法跟踪的内容时,它会报错。由于torch.export
生成与任何 Python 功能或运行时分离的完整图,因此该图可以保存在不同的环境和语言中并运行。可用性权衡: 由于
torch.compile()
能够在遇到无法跟踪的内容时回退到 Python 运行时,因此它更加灵活。相反,torch.export
将要求用户提供更多信息或重写代码以使其可跟踪。
与 torch.fx.symbolic_trace()
相比,torch.export
使用 TorchDynamo 进行跟踪,TorchDynamo 在 Python 字节码级别上运行,使其能够跟踪不受 Python 运算符重载支持的任意 Python 结构。此外,torch.export
精细地跟踪张量元数据,以便针对张量形状等条件不会导致跟踪失败。通常,torch.export
预计可以在更多用户程序上运行,并生成更低级别的图(在 torch.ops.aten
运算符级别)。请注意,用户仍然可以使用 torch.fx.symbolic_trace()
作为 torch.export
之前的预处理步骤。
与 torch.jit.script()
相比,torch.export
不会捕获 Python 控制流或数据结构,但它支持比 TorchScript 更多的 Python 语言特性(因为它更容易对 Python 字节码进行全面覆盖)。生成的图形更简单,只有直线控制流(除了显式控制流运算符)。
与 torch.jit.trace()
相比,torch.export
是健壮的:它能够追踪对大小执行整数计算的代码,并记录所有必要的旁路条件以证明特定追踪对其他输入有效。
导出 PyTorch 模型¶
示例¶
主要入口点是通过 torch.export.export()
,它接受一个可调用对象(torch.nn.Module
、函数或方法)和样本输入,并将计算图捕获到一个 torch.export.ExportedProgram
中。示例
import torch
from torch.export import export
# Simple module for demonstration
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels=3, out_channels=16, kernel_size=3, padding=1
)
self.relu = torch.nn.ReLU()
self.maxpool = torch.nn.MaxPool2d(kernel_size=3)
def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
a = self.conv(x)
a.add_(constant)
return self.maxpool(self.relu(a))
example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}
exported_program: torch.export.ExportedProgram = export(
M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[16, 3, 3, 3], arg1_1: f32[16], arg2_1: f32[1, 3, 256, 256], arg3_1: f32[1, 16, 256, 256]):
# code: a = self.conv(x)
convolution: f32[1, 16, 256, 256] = torch.ops.aten.convolution.default(
arg2_1, arg0_1, arg1_1, [1, 1], [1, 1], [1, 1], False, [0, 0], 1
);
# code: a.add_(constant)
add: f32[1, 16, 256, 256] = torch.ops.aten.add.Tensor(convolution, arg3_1);
# code: return self.maxpool(self.relu(a))
relu: f32[1, 16, 256, 256] = torch.ops.aten.relu.default(add);
max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices.default(
relu, [3, 3], [3, 3]
);
getitem: f32[1, 16, 85, 85] = max_pool2d_with_indices[0];
return (getitem,)
Graph signature: ExportGraphSignature(
parameters=['L__self___conv.weight', 'L__self___conv.bias'],
buffers=[],
user_inputs=['arg2_1', 'arg3_1'],
user_outputs=['getitem'],
inputs_to_parameters={
'arg0_1': 'L__self___conv.weight',
'arg1_1': 'L__self___conv.bias',
},
inputs_to_buffers={},
buffers_to_mutate={},
backward_signature=None,
assertion_dep_token=None,
)
Range constraints: {}
检查 ExportedProgram
,我们可以注意到以下几点
The
torch.fx.Graph
包含原始程序的计算图,以及原始代码的记录,便于调试。该图仅包含 此处 找到的
torch.ops.aten
运算符和自定义运算符,并且是完全可用的,没有任何诸如torch.add_
之类的就地运算符。参数(卷积的权重和偏差)被提升为图的输入,导致图中没有
get_attr
节点,这些节点以前存在于torch.fx.symbolic_trace()
的结果中。The
torch.export.ExportGraphSignature
模拟输入和输出签名,并指定哪些输入是参数。图中每个节点生成的张量的结果形状和数据类型都已记录。例如,
convolution
节点将生成一个数据类型为torch.float32
,形状为 (1, 16, 256, 256) 的张量。
非严格导出¶
在 PyTorch 2.3 中,我们引入了名为 **非严格模式** 的新追踪模式。它仍在进行硬化,因此如果您遇到任何问题,请在 Github 上提交它们,并加上标签“oncall: export”。
在 *非严格模式* 中,我们使用 Python 解释器追踪程序。您的代码将与在急切模式下完全一样执行;唯一的区别是所有张量对象都将被 ProxyTensors 替换,这些 ProxyTensors 将记录所有操作到一个图中。
在 *严格* 模式下(当前是默认模式),我们首先使用 TorchDynamo(一个字节码分析引擎)追踪程序。TorchDynamo 不会真正执行您的 Python 代码。相反,它会对它进行符号分析,并根据结果构建一个图。这种分析允许 torch.export 提供更强的安全保证,但并非所有 Python 代码都受支持。
您可能想要使用非严格模式的情况的一个例子是,如果您遇到了不支持的 TorchDynamo 功能,而该功能可能不容易解决,并且您知道 Python 代码对于计算并不完全必要。例如
import contextlib
import torch
class ContextManager():
def __init__(self):
self.count = 0
def __enter__(self):
self.count += 1
def __exit__(self, exc_type, exc_value, traceback):
self.count -= 1
class M(torch.nn.Module):
def forward(self, x):
with ContextManager():
return x.sin() + x.cos()
export(M(), (torch.ones(3, 3),), strict=False) # Non-strict traces successfully
export(M(), (torch.ones(3, 3),)) # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager
在此示例中,使用非严格模式的第一个调用(通过 strict=False
标志)成功追踪,而使用严格模式的第二个调用(默认)导致失败,TorchDynamo 无法支持上下文管理器。一种选择是重写代码(参见 torch.export 的局限性),但由于上下文管理器不会影响模型中的张量计算,我们可以使用非严格模式的结果。
表达动态性¶
默认情况下,torch.export
将追踪程序,假设所有输入形状都是 **静态** 的,并将导出的程序专门化到这些维度。但是,某些维度(例如批处理维度)可以是动态的,并在每次运行时有所不同。这些维度必须使用 torch.export.Dim()
API 来创建它们,并通过 dynamic_shapes
参数传递给 torch.export.export()
。示例
import torch
from torch.export import Dim, export
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.branch1 = torch.nn.Sequential(
torch.nn.Linear(64, 32), torch.nn.ReLU()
)
self.branch2 = torch.nn.Sequential(
torch.nn.Linear(128, 64), torch.nn.ReLU()
)
self.buffer = torch.ones(32)
def forward(self, x1, x2):
out1 = self.branch1(x1)
out2 = self.branch2(x2)
return (out1 + self.buffer, out2)
example_args = (torch.randn(32, 64), torch.randn(32, 128))
# Create a dynamic batch size
batch = Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
exported_program: torch.export.ExportedProgram = export(
M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[32, 64], arg1_1: f32[32], arg2_1: f32[64, 128], arg3_1: f32[64], arg4_1: f32[32], arg5_1: f32[s0, 64], arg6_1: f32[s0, 128]):
# code: out1 = self.branch1(x1)
permute: f32[64, 32] = torch.ops.aten.permute.default(arg0_1, [1, 0]);
addmm: f32[s0, 32] = torch.ops.aten.addmm.default(arg1_1, arg5_1, permute);
relu: f32[s0, 32] = torch.ops.aten.relu.default(addmm);
# code: out2 = self.branch2(x2)
permute_1: f32[128, 64] = torch.ops.aten.permute.default(arg2_1, [1, 0]);
addmm_1: f32[s0, 64] = torch.ops.aten.addmm.default(arg3_1, arg6_1, permute_1);
relu_1: f32[s0, 64] = torch.ops.aten.relu.default(addmm_1); addmm_1 = None
# code: return (out1 + self.buffer, out2)
add: f32[s0, 32] = torch.ops.aten.add.Tensor(relu, arg4_1);
return (add, relu_1)
Graph signature: ExportGraphSignature(
parameters=[
'branch1.0.weight',
'branch1.0.bias',
'branch2.0.weight',
'branch2.0.bias',
],
buffers=['L__self___buffer'],
user_inputs=['arg5_1', 'arg6_1'],
user_outputs=['add', 'relu_1'],
inputs_to_parameters={
'arg0_1': 'branch1.0.weight',
'arg1_1': 'branch1.0.bias',
'arg2_1': 'branch2.0.weight',
'arg3_1': 'branch2.0.bias',
},
inputs_to_buffers={'arg4_1': 'L__self___buffer'},
buffers_to_mutate={},
backward_signature=None,
assertion_dep_token=None,
)
Range constraints: {s0: RangeConstraint(min_val=2, max_val=9223372036854775806)}
还有一些其他需要注意的地方
通过
torch.export.Dim()
API 和dynamic_shapes
参数,我们指定每个输入的第一个维度为动态。查看输入arg5_1
和arg6_1
,它们的符号形状为 (s0, 64) 和 (s0, 128),而不是我们作为示例输入传递的 (32, 64) 和 (32, 128) 形状的张量。s0
是一个表示该维度可以是一系列值的符号。exported_program.range_constraints
描述了图中出现的每个符号的范围。在本例中,我们看到s0
的范围为 [2, inf]。由于一些难以在这里解释的技术原因,它们被认为不等于 0 或 1。这不是错误,也不一定意味着导出的程序将无法用于维度 0 或 1。有关此主题的深入讨论,请参阅 0/1 特化问题。
我们还可以指定输入形状之间更具表现力的关系,例如一对形状可能相差一个,一个形状可能是另一个形状的两倍,或者一个形状是偶数。示例
class M(torch.nn.Module):
def forward(self, x, y):
return x + y[1:]
x, y = torch.randn(5), torch.randn(6)
dimx = torch.export.Dim("dimx", min=3, max=6)
dimy = dimx + 1
exported_program = torch.export.export(
M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}),
)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: "f32[s0]", arg1_1: "f32[s0 + 1]"):
# code: return x + y[1:]
slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(arg1_1, 0, 1, 9223372036854775807); arg1_1 = None
add: "f32[s0]" = torch.ops.aten.add.Tensor(arg0_1, slice_1); arg0_1 = slice_1 = None
return (add,)
Graph signature: ExportGraphSignature(
input_specs=[
InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg0_1'), target=None, persistent=None),
InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg1_1'), target=None, persistent=None)
],
output_specs=[
OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)]
)
Range constraints: {s0: ValueRanges(lower=3, upper=6, is_bool=False), s0 + 1: ValueRanges(lower=4, upper=7, is_bool=False)}
一些需要注意的地方
通过为第一个输入指定
{0: dimx}
,我们看到第一个输入的最终形状现在是动态的,为[s0]
。现在,通过为第二个输入指定{0: dimy}
,我们看到第二个输入的最终形状也是动态的。但是,因为我们表达了dimy = dimx + 1
,而不是arg1_1
的形状包含一个新的符号,我们看到它现在正使用arg0_1
中使用的相同符号s0
来表示。我们可以看到dimy = dimx + 1
的关系通过s0 + 1
显示出来。查看范围约束,我们看到
s0
的范围为 [3, 6],这是最初指定的,我们可以看到s0 + 1
的求解范围为 [4, 7]。
序列化¶
要保存 ExportedProgram
,用户可以使用 torch.export.save()
和 torch.export.load()
API。一种约定是使用 .pt2
文件扩展名来保存 ExportedProgram
。
示例
import torch
import io
class MyModule(torch.nn.Module):
def forward(self, x):
return x + 10
exported_program = torch.export.export(MyModule(), torch.randn(5))
torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')
特化¶
理解 torch.export
行为的一个关键概念是 *静态* 值和 *动态* 值之间的区别。
A *dynamic* value 是一个可以在每次运行时变化的值。这些值的行为类似于 Python 函数的普通参数——您可以为参数传递不同的值,并期望您的函数能正常工作。张量 *数据* 被视为动态的。
A *static* value 是一个在导出时固定的值,在导出程序执行期间无法更改。当在追踪期间遇到该值时,导出器将把它视为一个常量,并将其硬编码到图中。
当执行一个操作(例如 x + y
)且所有输入都是静态的时,操作的输出将直接被硬编码到图中,而操作本身不会显示出来(即,它将被常量折叠)。
当一个值被硬编码到图中时,我们说该图已被 *专门化* 到该值。
以下值是静态的
输入张量形状¶
默认情况下,torch.export
将追踪程序,专门化到输入张量的形状,除非通过 dynamic_shapes
参数向 torch.export
指定了一个维度为动态。这意味着如果存在与形状相关的控制流,torch.export
将专门化到使用给定样本输入所采取的分支。例如
import torch
from torch.export import export
class Mod(torch.nn.Module):
def forward(self, x):
if x.shape[0] > 5:
return x + 1
else:
return x - 1
example_inputs = (torch.rand(10, 2),)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[10, 2]):
add: f32[10, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
return (add,)
条件 (x.shape[0] > 5
) 不出现在 ExportedProgram
中,因为示例输入的静态形状为 (10, 2)。由于 torch.export
针对输入的静态形状进行专门化,因此 else 分支 (x - 1
) 将永远无法到达。为了保留基于张量形状的动态分支行为,需要使用 torch.export.Dim()
来指定输入张量 (x.shape[0]
) 的维度为动态,并且需要 重写 源代码。
请注意,作为模块状态一部分的张量(例如参数和缓冲区)始终具有静态形状。
Python 基本类型¶
torch.export
也针对 Python 基本类型进行专门化,例如 int
、float
、bool
和 str
。但是它们确实具有动态变体,例如 SymInt
、SymFloat
和 SymBool
。
例如
import torch
from torch.export import export
class Mod(torch.nn.Module):
def forward(self, x: torch.Tensor, const: int, times: int):
for i in range(times):
x = x + const
return x
example_inputs = (torch.rand(2, 2), 1, 3)
exported_program = export(Mod(), example_inputs)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[2, 2], arg1_1, arg2_1):
add: f32[2, 2] = torch.ops.aten.add.Tensor(arg0_1, 1);
add_1: f32[2, 2] = torch.ops.aten.add.Tensor(add, 1);
add_2: f32[2, 2] = torch.ops.aten.add.Tensor(add_1, 1);
return (add_2,)
由于整数是专门化的,因此 torch.ops.aten.add.Tensor
操作都使用硬编码的常量 1
计算,而不是 arg1_1
。如果用户在运行时为 arg1_1
传递了与导出时使用的 1 不同的值(例如 2),这将导致错误。此外,在 for
循环中使用的 times
迭代器也通过 3 个重复的 torch.ops.aten.add.Tensor
调用“内联”到图中,并且输入 arg2_1
从未使用。
Python 容器¶
Python 容器 (List
、Dict
、NamedTuple
等) 被认为具有静态结构。
torch.export 的局限性¶
图中断¶
由于 torch.export
是从 PyTorch 程序中捕获计算图的一次性过程,因此它最终可能会遇到程序中无法追踪的部分,因为几乎不可能支持追踪所有 PyTorch 和 Python 功能。在 torch.compile
的情况下,不支持的操作将导致“图中断”,并且不支持的操作将使用默认的 Python 评估运行。相反,torch.export
将要求用户提供额外的信息或重写其代码的一部分,使其可追踪。由于追踪是基于 TorchDynamo(在 Python 字节码级别进行评估),因此与以前的追踪框架相比,需要重写的代码将少得多。
遇到图中断时,ExportDB 是一个很好的资源,用于了解支持和不支持的程序类型,以及重写程序以使其可追踪的方法。
可以使用 非严格导出 来避免处理这些图中断。
数据/形状相关的控制流¶
当形状未被专门化时,也可能在数据相关的控制流 (if x.shape[0] > 2
) 上遇到图中断,因为追踪编译器不可能处理这种情况,而不会为组合爆炸数量的路径生成代码。在这种情况下,用户需要使用特殊的控制流运算符重写代码。目前,我们支持 torch.cond 来表达类似 if-else 的控制流(更多内容即将推出!)。
运算符缺少假/元/抽象内核¶
追踪时,所有运算符都需要 FakeTensor 内核(也称为元内核、抽象实现)。它用于推断此运算符的输入/输出形状。
有关详细信息,请参阅 torch.library.register_fake()
。
如果您的模型不幸使用了尚未实现 FakeTensor 内核的 ATen 运算符,请提交问题。
API 参考¶
- torch.export.export(mod, args, kwargs=None, *, dynamic_shapes=None, strict=True, preserve_module_call_signature=())[source]¶
export()
使用示例输入接收任意 Python 可调用对象(nn.Module、函数或方法),并以提前编译 (AOT) 方式生成一个追踪图,该图仅表示函数的张量计算,该图随后可以使用不同的输入执行或序列化。健全性保证
追踪时,
export()
会注意到用户程序和底层 PyTorch 运算符内核所做的与形状相关的假设。仅当这些假设成立时,输出ExportedProgram
才被视为有效。追踪对输入张量的形状(而不是值)做出假设。为了使
export()
成功,必须在图捕获时验证这些假设。具体来说对输入张量的静态形状的假设会自动验证,无需额外努力。
对输入张量的动态形状的假设需要使用
Dim()
API 来构造动态维度,并通过dynamic_shapes
参数将它们与示例输入相关联。
如果任何假设无法验证,将引发致命错误。在这种情况下,错误消息将包含对验证假设所需的规范建议修复。例如,
export()
可能会建议对动态维度定义进行以下修复,例如,说出现在与输入x
相关的形状中的dim0_x
,之前定义为Dim("dim0_x")
dim = Dim("dim0_x", max=5)
此示例意味着生成的代码要求输入
x
的维度 0 小于或等于 5 才能有效。您可以检查对动态维度定义的建议修复,然后将它们逐字复制到代码中,而无需更改export()
调用的dynamic_shapes
参数。- 参数
mod (Module) – 我们将追踪此模块的 forward 方法。
dynamic_shapes (Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]]) –
一个可选参数,其类型应为以下之一:1) 一个字典,用于将
f
的参数名称映射到其动态形状规范;2) 一个元组,用于为每个输入按原始顺序指定动态形状规范。如果您正在为关键字参数指定动态性,则需要按原始函数签名中定义的顺序传递它们。张量参数的动态形状可以指定为以下之一:(1) 一个字典,从动态维度索引到
Dim()
类型,其中不需要在此字典中包含静态维度索引,但当它们存在时,应将其映射到 None;或 (2) 一个Dim()
类型的元组/列表或 None,其中Dim()
类型对应于动态维度,静态维度由 None 表示。字典或元组/张量列表的参数通过使用包含规范的映射或序列递归指定。strict (bool) – 启用时(默认),导出函数将通过 TorchDynamo 追踪程序,以确保生成的图形的健壮性。否则,导出的程序将不会验证嵌入到图形中的隐式假设,并可能导致原始模型和导出模型之间的行为差异。当用户需要解决追踪器中的错误,或者只是希望在模型中逐步启用安全性时,这很有用。请注意,这不会影响生成的 IR 规范不同,并且模型将以相同的方式序列化,而不管此处传递什么值。警告:此选项处于实验阶段,使用此选项需谨慎。
- 返回值
一个包含已追踪可调用的
ExportedProgram
。- 返回类型
可接受的输入/输出类型
可接受的输入类型(对于
args
和kwargs
)和输出包括原始类型,即
torch.Tensor
,int
,float
,bool
和str
。数据类,但必须先通过调用
register_dataclass()
来注册它们。(嵌套) 数据结构,包含
dict
,list
,tuple
,namedtuple
和OrderedDict
,其中包含所有上述类型。
- torch.export.save(ep, f, *, extra_files=None, opset_version=None)[source]¶
警告
正在积极开发中,保存的文件可能无法在较新的 PyTorch 版本中使用。
将
ExportedProgram
保存到文件类对象。然后可以使用 Python APItorch.export.load
加载它。- 参数
ep (ExportedProgram) – 要保存的导出程序。
f (Union[str, os.PathLike, io.BytesIO) – 文件类对象(必须实现 write 和 flush)或包含文件名的字符串。
extra_files (Optional[Dict[str, Any]]) – 从文件名到内容的映射,这些内容将作为 f 的一部分存储。
opset_version (Optional[Dict[str, int]]) – opset 名称到该 opset 版本的映射。
示例
import torch import io class MyModule(torch.nn.Module): def forward(self, x): return x + 10 ep = torch.export.export(MyModule(), (torch.randn(5),)) # Save to file torch.export.save(ep, 'exported_program.pt2') # Save to io.BytesIO buffer buffer = io.BytesIO() torch.export.save(ep, buffer) # Save with extra files extra_files = {'foo.txt': b'bar'.decode('utf-8')} torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files)
- torch.export.load(f, *, extra_files=None, expected_opset_version=None)[source]¶
警告
正在积极开发中,保存的文件可能无法在较新的 PyTorch 版本中使用。
加载之前使用
torch.export.save
保存的ExportedProgram
。- 参数
ep (ExportedProgram) – 要保存的导出程序。
f (Union[str, os.PathLike, io.BytesIO) – 文件类对象(必须实现 write 和 flush)或包含文件名的字符串。
extra_files (Optional[Dict[str, Any]]) – 此映射中给定的额外文件名将被加载,它们的内容将存储在提供的映射中。
expected_opset_version (Optional[Dict[str, int]]) – opset 名称到预期 opset 版本的映射。
- 返回值
一个
ExportedProgram
对象。- 返回类型
示例
import torch import io # Load ExportedProgram from file ep = torch.export.load('exported_program.pt2') # Load ExportedProgram from io.BytesIO object with open('exported_program.pt2', 'rb') as f: buffer = io.BytesIO(f.read()) buffer.seek(0) ep = torch.export.load(buffer) # Load with extra files. extra_files = {'foo.txt': ''} # values will be replaced with data ep = torch.export.load('exported_program.pt2', extra_files=extra_files) print(extra_files['foo.txt']) print(ep(torch.randn(5)))
- torch.export.register_dataclass(cls, *, serialized_type_name=None)[source]¶
将数据类注册为
torch.export.export()
的有效输入/输出类型。- 参数
示例
@dataclass class InputDataClass: feature: torch.Tensor bias: int class OutputDataClass: res: torch.Tensor torch.export.register_dataclass(InputDataClass) torch.export.register_dataclass(OutputDataClass) def fn(o: InputDataClass) -> torch.Tensor: res = res=o.feature + o.bias return OutputDataClass(res=res) ep = torch.export.export(fn, (InputDataClass(torch.ones(2, 2), 1), )) print(ep)
- torch.export.dynamic_shapes.Dim(name, *, min=None, max=None)[source]¶
Dim()
构造一个类似于具有范围的命名符号整数的类型。它可以用来描述动态张量维度的多个可能值。请注意,同一个张量的不同动态维度,或不同张量的不同动态维度,可以用同一个类型来描述。
- class torch.export.dynamic_shapes.ShapesCollection[source]¶
用于 dynamic_shapes 的构建器。用于为出现在输入中的张量分配动态形状规范。
- 示例:
args = ({“x”: tensor_x, “others”: [tensor_y, tensor_z]})
dim = torch.export.Dim(…) dynamic_shapes = torch.export.ShapesCollection() dynamic_shapes[tensor_x] = (dim, dim + 1, 8) dynamic_shapes[tensor_y] = {0: dim * 2} # 这等效于以下内容(现在自动生成): # dynamic_shapes = {“x”: (dim, dim + 1, 8), “others”: [{0: dim * 2}, None]}
torch.export(…, args, dynamic_shapes=dynamic_shapes)
- torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(msg, dynamic_shapes)[source]¶
用于处理导出动态形状的建议修复,或自动动态形状。根据约束冲突错误消息和原始动态形状,完善给定的动态形状规范。
对于大多数情况,行为是直接的 - 例如,对于专门化或细化 Dim 范围的建议修复,或建议派生关系的修复,新的动态形状规范将相应更新。
例如建议修复
dim = Dim(‘dim’, min=3, max=6) -> 这只是细化了 dim 的范围 dim = 4 -> 这专门化到一个常数 dy = dx + 1 -> dy 被指定为一个独立的 dim,但实际上与 dx 绑定了这个关系
然而,与派生 dims 相关的建议修复可能会更加复杂。例如,如果为根 dim 提供了一个建议修复,则新的派生 dim 值将根据根来评估。
例如 dx = Dim(‘dx’) dy = dx + 2 dynamic_shapes = {“x”: (dx,), “y”: (dy,)}
建议修复
dx = 4 # 专门化将导致 dy 也专门化 = 6 dx = Dim(‘dx’, max=6) # dy 现在有 max = 8
派生 dims 的建议修复也可以用于表达可除性约束。这涉及创建不绑定到特定输入形状的新根 dims。在这种情况下,根 dims 不会直接出现在新规范中,而是作为其中一个 dims 的根。
例如建议修复
_dx = Dim(‘_dx’, max=1024) # 这不会出现在返回结果中,但 dx 会 dx = 4*_dx # dx 现在可以被 4 整除,最大值为 4096
- class torch.export.ExportedProgram(root, graph, graph_signature, state_dict, range_constraints, module_call_graph, example_inputs=None, constants=None, *, verifiers=None)[source]¶
来自
export()
的程序包。它包含一个torch.fx.Graph
,它表示张量计算,一个 state_dict 包含所有提升参数和缓冲区的张量值,以及各种元数据。您可以像调用原始可调用对象(由
export()
跟踪)一样调用 ExportedProgram,并具有相同的调用约定。要对图执行转换,请使用
.module
属性访问torch.fx.GraphModule
。然后,您可以使用 FX 转换 来重写图。之后,您只需再次使用export()
即可构建正确的 ExportedProgram。- run_decompositions(decomp_table=None, _preserve_ops=())[source]¶
在导出的程序上运行一组分解,并返回一个新的导出程序。默认情况下,我们将运行 Core ATen 分解,以获取 Core ATen 运算符集 中的运算符。
目前,我们不分解联合图。
- 返回类型
- class torch.export.ExportBackwardSignature(gradients_to_parameters: Dict[str, str], gradients_to_user_inputs: Dict[str, str], loss_output: str)[source]¶
- class torch.export.ExportGraphSignature(input_specs, output_specs)[source]¶
ExportGraphSignature
对导出图的输入/输出签名进行建模,导出图是一个具有更强不变性保证的 fx.Graph。导出图是函数式的,不会通过
getattr
节点访问图中像参数或缓冲区这样的“状态”。相反,export()
保证参数、缓冲区和常量张量会被从图中提升为输入。类似地,对缓冲区的任何修改也不会包含在图中,而是将已修改缓冲区的更新值建模为导出图的附加输出。所有输入和输出的排序是
Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] Outputs = [*mutated_inputs, *flattened_user_outputs]
例如,如果导出以下模块
class CustomModule(nn.Module): def __init__(self) -> None: super(CustomModule, self).__init__() # Define a parameter self.my_parameter = nn.Parameter(torch.tensor(2.0)) # Define two buffers self.register_buffer('my_buffer1', torch.tensor(3.0)) self.register_buffer('my_buffer2', torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 # Mutate one of the buffers (e.g., increment it by 1) self.my_buffer2.add_(1.0) # In-place addition return output
生成的图将是
graph(): %arg0_1 := placeholder[target=arg0_1] %arg1_1 := placeholder[target=arg1_1] %arg2_1 := placeholder[target=arg2_1] %arg3_1 := placeholder[target=arg3_1] %arg4_1 := placeholder[target=arg4_1] %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) return (add_tensor_2, add_tensor_1)
生成的 ExportGraphSignature 将是
ExportGraphSignature( input_specs=[ InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None) ], output_specs=[ OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None) ] )
- class torch.export.ModuleCallSignature(inputs: List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], outputs: List[Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument]], in_spec: torch.utils._pytree.TreeSpec, out_spec: torch.utils._pytree.TreeSpec)[source]¶
- class torch.export.ModuleCallEntry(fqn: str, signature: Optional[torch.export.exported_program.ModuleCallSignature] = None)[source]¶
- class torch.export.graph_signature.InputSpec(kind: torch.export.graph_signature.InputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str], persistent: Optional[bool] = None)[source]¶
- class torch.export.graph_signature.OutputSpec(kind: torch.export.graph_signature.OutputKind, arg: Union[torch.export.graph_signature.TensorArgument, torch.export.graph_signature.SymIntArgument, torch.export.graph_signature.ConstantArgument, torch.export.graph_signature.CustomObjArgument, torch.export.graph_signature.TokenArgument], target: Optional[str])[source]¶
- class torch.export.graph_signature.ExportGraphSignature(input_specs, output_specs)[source]¶
ExportGraphSignature
对导出图的输入/输出签名进行建模,导出图是一个具有更强不变性保证的 fx.Graph。导出图是函数式的,不会通过
getattr
节点访问图中像参数或缓冲区这样的“状态”。相反,export()
保证参数、缓冲区和常量张量会被从图中提升为输入。类似地,对缓冲区的任何修改也不会包含在图中,而是将已修改缓冲区的更新值建模为导出图的附加输出。所有输入和输出的排序是
Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] Outputs = [*mutated_inputs, *flattened_user_outputs]
例如,如果导出以下模块
class CustomModule(nn.Module): def __init__(self) -> None: super(CustomModule, self).__init__() # Define a parameter self.my_parameter = nn.Parameter(torch.tensor(2.0)) # Define two buffers self.register_buffer('my_buffer1', torch.tensor(3.0)) self.register_buffer('my_buffer2', torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 # Mutate one of the buffers (e.g., increment it by 1) self.my_buffer2.add_(1.0) # In-place addition return output
生成的图将是
graph(): %arg0_1 := placeholder[target=arg0_1] %arg1_1 := placeholder[target=arg1_1] %arg2_1 := placeholder[target=arg2_1] %arg3_1 := placeholder[target=arg3_1] %arg4_1 := placeholder[target=arg4_1] %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {}) %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {}) %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {}) %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {}) %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {}) return (add_tensor_2, add_tensor_1)
生成的 ExportGraphSignature 将是
ExportGraphSignature( input_specs=[ InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None) ], output_specs=[ OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None) ] )
- class torch.export.graph_signature.CustomObjArgument(name: str, class_fqn: str, fake_val: Optional[torch._library.fake_class_registry.FakeScriptObject] = None)[source]¶
- class torch.export.unflatten.InterpreterModule(graph)[source]¶
一个使用 torch.fx.Interpreter 进行执行的模块,而不是 GraphModule 通常使用的代码生成。这提供了更好的堆栈跟踪信息,并使调试执行变得更加容易。
- torch.export.unflatten.unflatten(module, flat_args_adapter=None)[source]¶
展开一个 ExportedProgram,生成一个与原始急切模块具有相同模块层次结构的模块。如果您尝试将
torch.export
与另一个期望模块层次结构而不是torch.export
通常生成的扁平图形的系统一起使用,这将非常有用。注意
展开的模块的 args/kwargs 不一定与急切模块匹配,因此进行模块交换(例如
self.submod = new_mod
)不一定有效。如果您需要交换模块,则需要设置torch.export.export()
的preserve_module_call_signature
参数。- 参数
module (ExportedProgram) – 要展开的 ExportedProgram。
flat_args_adapter (Optional[FlatArgsAdapter]) – 如果输入 TreeSpec 与导出的模块不匹配,则调整扁平参数。
- 返回值
UnflattenedModule
的实例,它与导出之前的原始急切模块具有相同的模块层次结构。- 返回类型
UnflattenedModule
- torch.export.passes.move_to_device_pass(ep, location)[source]¶
将导出的程序移动到给定的设备。
- 参数
ep (ExportedProgram) – 要移动的导出程序。
location (Union[torch.device, str, Dict[str, str]]) – 将导出程序移动到的设备。如果为字符串,则将其解释为设备名称。如果为字典,则将其解释为从现有设备到目标设备的映射
- 返回值
已移动的导出程序。
- 返回类型