快捷方式

torch.export 编程模型

本文旨在解释 torch.export.export() 的行为和功能。它旨在帮助你理解 torch.export.export() 如何处理代码。

跟踪基础

torch.export.export() 通过在“示例”输入上跟踪模型的执行,并记录沿跟踪路径观察到的 PyTorch 操作和条件,来捕获表示模型的图。只要后续输入满足相同的条件,此图就可以在不同的输入上运行。

torch.export.export() 的基本输出是一个包含相关元数据的 PyTorch 操作的单一图。此输出的具体格式在 torch.export IR 规范 中介绍。

严格跟踪与非严格跟踪

torch.export.export() 提供了两种跟踪模式。

非严格模式下,我们使用标准的 Python 解释器跟踪程序。你的代码将完全按照 eager 模式执行;唯一的区别是所有 Tensor 都被替换为 伪 Tensor它们具有形状和其他形式的元数据,但没有数据,并封装在 Proxy 对象 中,这些对象将所有操作记录到一个图中。我们还捕获了 Tensor 形状条件 这些条件用于保证生成代码的正确性

严格模式下,我们首先使用 TorchDynamo(一个 Python 字节码分析引擎)跟踪程序。TorchDynamo 实际上并不执行你的 Python 代码。相反,它象征性地分析代码并根据结果构建图。一方面,这种分析允许 torch.export.export() 提供额外的 Python 级别安全性保证(除了像非严格模式那样捕获 Tensor 形状条件外)。另一方面,并非所有 Python 特性都受到此分析的支持。

虽然目前默认的跟踪模式是严格模式,但我们强烈建议使用非严格模式,它很快将成为默认模式。对于大多数模型,Tensor 形状条件足以保证健全性,而额外的 Python 级别安全性保证没有影响;同时,在 TorchDynamo 中遇到不支持的 Python 特性会带来不必要的风险。

在本文档的其余部分,我们假设在非严格模式下进行跟踪;特别是,我们假设所有 Python 特性都受到支持

值:静态值与动态值

理解 torch.export.export() 行为的关键概念是静态值和动态值之间的区别。

静态值

静态值是在导出时固定,且在导出程序的每次执行之间不能更改的值。在跟踪期间遇到该值时,我们将其视为常量并将其硬编码到图中。

当执行操作(例如 x + y)且所有输入均为静态时,操作的输出将直接硬编码到图中,且该操作不会出现在图中(即它被“常量折叠”)。

当值被硬编码到图中时,我们称该图已针对该值进行了特化。例如

import torch

class MyMod(torch.nn.Module):
    def forward(self, x, y):
        z = y + 7
        return x + z

m = torch.export.export(MyMod(), (torch.randn(1), 3))
print(m.graph_module.code)

"""
def forward(self, arg0_1, arg1_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 10);  arg0_1 = None
    return (add,)

"""

在这里,我们将 3 作为 y 的跟踪值;它被视为静态值并添加到 7 中,在图中固化了静态值 10

动态值

动态值是在每次运行时可以更改的值。它的行为就像“正常”函数参数一样:你可以传递不同的输入并期望函数执行正确的操作。

哪些值是静态的,哪些是动态的?

值是静态还是动态取决于其类型

  • 对于 Tensor

    • Tensor 数据被视为动态。

    • Tensor 形状可以被系统视为静态或动态。

      • 默认情况下,所有输入 Tensor 的形状被视为静态。用户可以通过为任何输入 Tensor 指定动态形状来覆盖此行为。

      • 作为模块状态一部分的 Tensor,即参数和缓冲区,始终具有静态形状。

    • 其他形式的 Tensor 元数据(例如 device, dtype)是静态的。

  • Python 基本类型int, float, bool, str, None)是静态的。

    • 某些基本类型有动态变体(SymInt, SymFloat, SymBool)。通常用户不需要处理它们。

  • 对于 Python 标准容器list, tuple, dict, namedtuple

    • 结构(即 listtuple 值的长度,以及 dictnamedtuple 值的键序列)是静态的。

    • 包含的元素递归应用这些规则(基本上是 PyTree 方案),叶子节点是 Tensor 或基本类型。

  • 其他(包括数据类)可以通过 PyTree 注册(见下文),并遵循与标准容器相同的规则。

输入类型

输入将被视为静态或动态,具体取决于它们的类型(如上所述)。

  • 静态输入将被硬编码到图中,并在运行时传递不同的值将导致错误。请记住,这些值主要是基本类型的值。

  • 动态输入行为类似于“正常”函数输入。请记住,这些值主要是 Tensor 类型的值。

默认情况下,程序可用的输入类型包括

  • Tensor

  • Python 基本类型(int, float, bool, str, None

  • Python 标准容器(list, tuple, dict, namedtuple

自定义输入类型

此外,你还可以定义自己的(自定义)类并将其用作输入类型,但这需要你将此类注册为 PyTree。

这是一个使用实用程序注册用作输入类型的数据类的示例。

@dataclass
class Input:
    f: torch.Tensor
    p: torch.Tensor

torch.export.register_dataclass(Input)

class M(torch.nn.Module):
    def forward(self, x: Input):
        return x.f + 1

torch.export.export(M(), (Input(f=torch.ones(10, 4), p=torch.zeros(10, 4)),))

可选输入类型

对于程序中未传入的可选输入,torch.export.export() 将特化为它们的默认值。因此,导出的程序将要求用户显式传入所有参数,并会丢失默认行为。例如

class M(torch.nn.Module):
    def forward(self, x, y=None):
        if y is not None:
            return y * x
        return x + x

# Optional input is passed in
ep = torch.export.export(M(), (torch.randn(3, 3), torch.randn(3, 3)))
print(ep)
"""
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]", y: "f32[3, 3]"):
            # File: /data/users/angelayi/pytorch/moo.py:15 in forward, code: return y * x
            mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(y, x);  y = x = None
            return (mul,)
"""

# Optional input is not passed in
ep = torch.export.export(M(), (torch.randn(3, 3),))
print(ep)
"""
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]", y):
            # File: /data/users/angelayi/pytorch/moo.py:16 in forward, code: return x + x
            add: "f32[3, 3]" = torch.ops.aten.add.Tensor(x, x);  x = None
            return (add,)
"""

控制流:静态与动态

PyTorch torch.export.export() 支持控制流。控制流的行为取决于你分支的值是静态还是动态。

静态控制流

Python 对静态值的控制流得到透明支持。(请记住,静态值包括静态形状,因此对静态形状的控制流也属于这种情况。)

如上所述,我们“固化”静态值,因此导出的图永远不会看到任何对静态值的控制流。

对于 if 语句,我们将继续跟踪导出时采用的分支。对于 forwhile 语句,我们将通过展开循环来继续跟踪。

动态控制流:依赖形状与依赖数据

当控制流中涉及的值是动态的时,它可能依赖于动态形状或动态数据。考虑到编译器跟踪时使用的是形状信息而不是数据,这些情况下对编程模型的影响是不同的。

动态形状依赖控制流

当控制流中涉及的值是动态形状时,在大多数情况下,我们在跟踪期间也会知道动态形状的具体值:有关编译器如何跟踪此信息的更多详细信息,请参阅下一节。

在这些情况下,我们称控制流是形状依赖的。我们使用动态形状的具体值来评估条件TrueFalse 并继续跟踪(如上所述),此外还会发出与刚刚评估的条件对应的守卫。

否则,控制流被视为数据依赖的。我们无法评估条件为 TrueFalse,因此无法继续跟踪,必须在导出时引发错误。请参阅下一节。

动态数据依赖控制流

数据依赖的动态值控制流受到支持,但你必须使用 PyTorch 的显式操作符之一来继续跟踪。使用 Python 控制流语句处理动态值是不允许的,因为编译器无法评估继续跟踪所需的条件,因此必须在导出时引发错误。

我们提供了操作符来表达动态值的通用条件和循环,例如 torch.condtorch.map。请注意,只有当你确实需要数据依赖的控制流时才需要使用这些操作符。

这是一个关于数据依赖条件 x.sum() > 0if 语句示例,其中 x 是一个输入 Tensor,使用 torch.cond 重写。现在,两个分支都被跟踪,而不是必须决定跟踪哪个分支。

class M_old(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x.sin()
        else:
            return x.cos()

class M_new(torch.nn.Module):
    def forward(self, x):
        return torch.cond(
            pred=x.sum() > 0,
            true_fn=lambda x: x.sin(),
            false_fn=lambda x: x.cos(),
            operands=(x,),
        )

数据依赖控制流的一个特殊情况是它涉及无支持的动态形状:通常是某些中间 Tensor 的形状,它依赖于输入数据而不是输入形状(因此不依赖于形状)。在这种情况下,你可以提供一个断言来决定条件是 True 还是 False,而不是使用控制流操作符。给定这样的断言,我们可以继续跟踪,并发出一个守卫,如上所述。

我们提供操作符来表达对动态形状的断言,例如 torch._check。请注意,只有当对数据依赖的动态形状存在控制流时才需要使用此操作符。

这是一个涉及数据依赖动态形状的条件 nz.shape[0] > 0if 语句示例,其中 nz 是调用 torch.nonzero() 的结果,这是一个输出形状依赖于输入数据(因此不依赖于形状)的操作符。在这种情况下,你可以使用 torch._check 添加断言来有效地决定跟踪哪个分支,而不是重写它。

class M_old(torch.nn.Module):
    def forward(self, x):
        nz = x.nonzero()
        if nz.shape[0] > 0:
            return x.sin()
        else:
            return x.cos()

class M_new(torch.nn.Module):
    def forward(self, x):
        nz = x.nonzero()
        torch._check(nz.shape[0] > 0)
        if nz.shape[0] > 0:
            return x.sin()
        else:
            return x.cos()

符号形状基础

在跟踪期间,动态 Tensor 形状及其条件被编码为“符号表达式”。(相比之下,静态 Tensor 形状及其条件仅仅是 intbool 值。)

一个符号就像一个变量;它描述一个动态 Tensor 形状。

随着跟踪的进行,中间 Tensor 的形状可以用更通用的表达式来描述,通常涉及整数算术运算符。这是因为对于大多数 PyTorch 操作符,输出 Tensor 的形状可以描述为输入 Tensor 形状的函数。例如,torch.cat() 的输出形状是其输入形状的总和。

此外,当我们在程序中遇到控制流时,我们会创建布尔表达式,通常涉及关系运算符,描述沿着跟踪路径的条件。这些表达式会被评估以决定跟踪程序中的哪条路径,并记录在形状环境中,以保证跟踪路径的正确性并评估随后创建的表达式。

接下来我们将简要介绍这些子系统。

PyTorch 操作符的伪实现

回想一下,在跟踪期间,我们使用伪 Tensor 执行程序,这些 Tensor 没有数据。通常我们无法使用伪 Tensor 调用 PyTorch 操作符的实际实现。因此,每个操作符都需要一个额外的伪(也称为“元”)实现,该实现输入和输出伪 Tensor,并在形状和伪 Tensor 携带的其他形式的元数据方面与实际实现的行为匹配。

例如,请注意 torch.index_select() 的伪实现如何使用输入形状计算输出形状(同时忽略输入数据并返回空的输出数据)。

def meta_index_select(self, dim, index):
    result_size = list(self.size())
    if self.dim() > 0:
        result_size[dim] = index.numel()
    return self.new_empty(result_size)

形状传播:有支持的动态形状与无支持的动态形状

形状通过 PyTorch 操作符的伪实现进行传播。

理解动态形状传播的关键概念是有支持的无支持的动态形状之间的区别:我们知道前者的具体值,但不知道后者的具体值。

形状的传播,包括跟踪有支持的和无支持的动态形状,按以下方式进行

  • 表示输入的 Tensor 的形状可以是静态或动态的。当为动态时,它们由符号描述;此外,由于我们知道用户在导出时提供的“真实”示例输入所给出的具体值,因此此类符号是有支持的

  • 操作符的输出形状由其伪实现计算,可以是静态或动态的。当为动态时,通常由符号表达式描述。此外

    • 如果输出形状仅依赖于输入形状,则当输入形状全部为静态或有支持的动态时,输出形状也是静态或有支持的动态。

    • 另一方面,如果输出形状依赖于输入数据,则它必然是动态的,而且,因为我们无法知道它的具体值,它是无支持的

控制流:守卫和断言

当遇到形状条件时,它要么仅涉及静态形状,在这种情况下它是一个 bool,要么涉及动态形状,在这种情况下它是一个符号布尔表达式。对于后者

  • 当条件仅涉及有支持的动态形状时,我们可以使用这些动态形状的具体值来评估条件为 TrueFalse。然后我们可以在形状环境中添加一个守卫,说明相应的符号布尔表达式为 TrueFalse,并继续跟踪。

  • 否则,条件涉及无支持的动态形状。通常,在没有额外信息的情况下,我们无法评估此类条件;因此我们无法继续跟踪,必须在导出时引发错误。用户需要使用显式的 PyTorch 操作符才能继续跟踪。此信息作为守卫添加到形状环境中,并且可能有助于评估其他随后遇到的条件为 TrueFalse

模型导出后,对有支持动态形状的任何守卫都可以理解为对输入动态形状的条件。这些条件会根据必须提供给导出的动态形状规范进行验证,该规范描述了示例输入以及所有未来输入为使生成代码正确必须满足的动态形状条件。更精确地说,动态形状规范在逻辑上必须蕴含生成的守卫,否则将在导出时引发错误(并给出动态形状规范的建议修复)。另一方面,当对有支持动态形状没有生成守卫时(特别是当所有形状都是静态时),无需为导出提供动态形状规范。通常,动态形状规范会转换为对生成代码输入的运行时断言。

最后,对无支持动态形状的任何守卫都会转换为“内联”运行时断言。这些断言会添加到生成代码中创建这些无支持动态形状的位置:通常是在数据依赖操作符调用之后。

允许的 PyTorch 操作符

允许使用所有 PyTorch 操作符。

自定义操作符

此外,你可以定义和使用自定义操作符。定义自定义操作符包括为其定义一个伪实现,就像任何其他 PyTorch 操作符一样(请参阅上一节)。

这是一个自定义 sin 操作符的示例,它封装了 NumPy,及其注册的(简单的)伪实现。

@torch.library.custom_op("mylib::sin", mutates_args=())
def sin(x: Tensor) -> Tensor:
    x_np = x.numpy()
    y_np = np.sin(x_np)
    return torch.from_numpy(y_np)

@torch.library.register_fake("mylib::sin")
def _(x: Tensor) -> Tensor:
    return torch.empty_like(x)

有时,你的自定义操作符的伪实现会涉及数据依赖的形状。这是一个自定义 nonzero 操作符的伪实现可能看起来的样子。

...

@torch.library.register_fake("mylib::custom_nonzero")
def _(x):
    nnz = torch.library.get_ctx().new_dynamic_size()
    shape = [nnz, x.dim()]
    return x.new_empty(shape, dtype=torch.int64)

模块状态:读取与更新

模块状态包括参数、缓冲区和普通属性。

  • 普通属性可以是任何类型。

  • 另一方面,参数和缓冲区始终是 Tensor。

模块状态可以是动态或静态的,具体取决于如上所述的类型。例如,self.training 是一个 bool,这意味着它是静态的;另一方面,任何参数或缓冲区都是动态的。

模块状态中包含的任何 Tensor 的形状不能是动态的,即这些形状在导出时固定,且在导出程序的每次执行之间不能更改。

访问规则

所有模块状态都必须初始化。在导出时访问尚未初始化的模块状态会导致错误发生。

总是允许读取模块状态.

更新模块状态是可能的,但必须遵循以下规则

  • 静态常规属性(例如,原始类型)可以更新。读取和更新可以自由交错进行,正如预期一样,任何读取都将始终看到最新更新的值。由于这些属性是静态的,我们也会将值嵌入其中,因此生成的代码将不会包含实际“获取”或“设置”这些属性的指令。

  • 动态常规属性(例如,Tensor 类型)无法更新。要更新它,必须在模块初始化期间将其注册为 buffer。

  • Buffer 可以更新,更新可以是原地更新(例如,self.buffer[:] = ...)或者非原地更新(例如,self.buffer = ...)。

  • Parameter 无法更新。通常 parameter 只在训练期间更新,而非推理期间。我们建议使用 torch.no_grad() 进行导出,以避免在导出时更新 parameter。

Functionalization 的影响

任何被读取和/或更新的动态模块状态会(相应地)作为生成代码的输入和/或输出被“提升”(lifted)。

导出的程序会与生成代码一起存储 parameter 和 buffer 的初始值以及其他 Tensor 属性的常量值。

文档

访问 PyTorch 的完整开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并解答你的问题

查看资源