torch.fx¶
概述¶
FX 是一个工具包,供开发者用来转换 nn.Module 实例。 FX 由三个主要组件组成:一个**符号追踪器**、一个**中间表示**和 **Python 代码生成**。 下面演示了这些组件的实际应用
import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
module = MyModule()
from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%param : [num_users=1] = get_attr[target=param]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
%linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
return clamp
"""
# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
param = self.param
add = x + param; x = param = None
linear = self.linear(add); add = None
clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
return clamp
"""
**符号追踪器**执行 Python 代码的“符号执行”。 它通过代码馈送假值,称为 Proxies。 这些 Proxies 上的操作被记录下来。 有关符号追踪的更多信息,请参见 symbolic_trace() 和 Tracer 文档。
**中间表示**是记录符号追踪期间的操作的容器。 它由一个节点列表组成,这些节点表示函数输入、调用点(到函数、方法或 torch.nn.Module 实例)和返回值。 有关 IR 的更多信息,请参见 Graph 的文档。 IR 是应用转换的格式。
**Python 代码生成**使 FX 成为 Python 到 Python(或模块到模块)的转换工具包。 对于每个 Graph IR,我们可以创建有效的 Python 代码,以匹配 Graph 的语义。 此功能封装在 GraphModule 中,它是一个 torch.nn.Module 实例,其中包含一个 Graph 以及从 Graph 生成的 forward 方法。
总而言之,这个组件管道(符号追踪 -> 中间表示 -> 转换 -> Python 代码生成)构成了 FX 的 Python 到 Python 转换管道。 此外,这些组件可以单独使用。 例如,可以单独使用符号追踪来捕获代码的一种形式以进行分析(而不是转换)的目的。 代码生成可用于以编程方式生成模型,例如从配置文件中生成。 FX 有很多用途!
可以在 examples 存储库中找到一些示例转换。
编写转换¶
什么是 FX 转换? 从本质上讲,它是一个如下所示的函数。
import torch
import torch.fx
def transform(m: nn.Module,
tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
# Step 1: Acquire a Graph representing the code in `m`
# NOTE: torch.fx.symbolic_trace is a wrapper around a call to
# fx.Tracer.trace and constructing a GraphModule. We'll
# split that out in our transform to allow the caller to
# customize tracing behavior.
graph : torch.fx.Graph = tracer_class().trace(m)
# Step 2: Modify this Graph or create a new one
graph = ...
# Step 3: Construct a Module to return
return torch.fx.GraphModule(m, graph)
您的转换将会接收一个 torch.nn.Module,从中获取一个 Graph,进行一些修改,然后返回一个新的 torch.nn.Module。您应该将您的 FX 转换返回的 torch.nn.Module 视为与常规 torch.nn.Module 相同 – 您可以将其传递给另一个 FX 转换,可以将其传递给 TorchScript,或者您可以运行它。确保您的 FX 转换的输入和输出都是一个 torch.nn.Module 将允许组合。
注意
也可以修改现有的 GraphModule 而不是创建一个新的,如下所示
import torch
import torch.fx
def transform(m : nn.Module) -> nn.Module:
gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)
# Modify gm.graph
# <...>
# Recompile the forward() method of `gm` from its Graph
gm.recompile()
return gm
请注意,您必须调用 GraphModule.recompile() 以使 GraphModule 上生成的 forward() 方法与修改后的 Graph 同步。
鉴于您已传入一个 torch.nn.Module,该模块已被跟踪到 Graph 中,现在您可以采用两种主要方法来构建新的 Graph。
图的快速入门¶
有关图的语义的完整处理可以在 Graph 文档中找到,但我们将在此处介绍基础知识。 Graph 是一种数据结构,表示 GraphModule 上的一个方法。 此所需的信息是
该方法的输入是什么?
该方法内部运行的操作是什么?
该方法的输出(即返回值)是什么?
这三个概念都用 Node 实例表示。 让我们用一个简短的例子来看看这意味着什么
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return torch.topk(torch.sum(
self.linear(x + self.linear.weight).relu(), dim=-1), 3)
m = MyModule()
gm = torch.fx.symbolic_trace(m)
gm.graph.print_tabular()
这里我们定义一个模块 MyModule 用于演示目的,实例化它,对其进行符号追踪,然后调用 Graph.print_tabular() 方法打印出一个表格,显示此 Graph 的节点
操作码
名称
目标
参数
关键字参数
占位符
x
x
()
{}
get_attr
linear_weight
linear.weight
()
{}
call_function
add_1
<built-in function add>
(x, linear_weight)
{}
call_module
linear_1
linear
(add_1,)
{}
call_method
relu_1
relu
(linear_1,)
{}
call_function
sum_1
<built-in method sum …>
(relu_1,)
{‘dim’: -1}
call_function
topk_1
<built-in method topk …>
(sum_1, 3)
{}
输出
输出
输出
(topk_1,)
{}
我们可以使用这些信息来回答上面提出的问题。
该方法的输入是什么? 在 FX 中,方法输入通过特殊的
placeholder节点指定。 在这种情况下,我们有一个单独的placeholder节点,其target为x,这意味着我们有一个名为 x 的(非自)参数。该方法中的操作是什么?
get_attr、call_function、call_module和call_method节点表示该方法中的操作。 有关所有这些的语义的完整处理可以在Node文档中找到。该方法的返回值是什么?
Graph中的返回值由一个特殊的output节点指定。
鉴于我们现在了解了代码在 FX 中如何表示的基础知识,我们现在可以探索如何编辑 Graph。
图操作¶
直接图操作¶
构建这个新 Graph 的一种方法是直接操作旧的图。 为了帮助实现这一点,我们可以简单地获取从符号追踪获得的 Graph 并修改它。 例如,假设我们希望用 torch.mul() 调用替换 torch.add() 调用。
import torch
import torch.fx
# Sample module
class M(torch.nn.Module):
def forward(self, x, y):
return torch.add(x, y)
def transform(m: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
graph : fx.Graph = tracer_class().trace(m)
# FX represents its Graph as an ordered list of
# nodes, so we can iterate through them.
for node in graph.nodes:
# Checks if we're calling a function (i.e:
# torch.add)
if node.op == 'call_function':
# The target attribute is the function
# that call_function calls.
if node.target == torch.add:
node.target = torch.mul
graph.lint() # Does some checks to make sure the
# Graph is well-formed.
return fx.GraphModule(m, graph)
我们还可以进行更复杂的 Graph 重写,例如删除或附加节点。 为了帮助进行这些转换,FX 提供了用于转换图的实用函数,可以在 Graph 文档中找到。 下面可以找到使用这些 API 附加 torch.relu() 调用的示例。
# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
# Insert a new `call_function` node calling `torch.relu`
new_node = traced.graph.call_function(
torch.relu, args=(node,))
# We want all places that used the value of `node` to
# now use that value after the `relu` call we've added.
# We use the `replace_all_uses_with` API to do this.
node.replace_all_uses_with(new_node)
对于仅包含替换的简单转换,您还可以使用 子图重写器。
使用 replace_pattern() 进行子图重写¶
FX 还提供了在直接图操作之上的另一层自动化。 replace_pattern() API 本质上是一个用于编辑 Graphs 的“查找/替换”工具。 它允许您指定一个 pattern 和 replacement 函数,它将跟踪这些函数,在 pattern 图中查找操作组的实例,并用 replacement 图的副本替换这些实例。 这可以帮助大大自动化繁琐的图操作代码,因为转换变得更加复杂,这些代码可能会变得笨拙。
代理/重新跟踪¶
另一种操作 Graph 的方法是重用符号追踪中使用的 Proxy 机制。 例如,假设我们想编写一个转换,将 PyTorch 函数分解成更小的操作。 它会将每个 F.relu(x) 调用转换为 (x > 0) * x。 一种可能性是执行必要的图重写,以在 F.relu 之后插入比较和乘法,然后清理原始的 F.relu。 但是,我们可以通过使用 Proxy 对象自动将操作记录到 Graph 中来自动化此过程。
要使用此方法,我们编写想要作为常规 PyTorch 代码插入的操作,并使用 Proxy 对象作为参数来调用该代码。 这些 Proxy 对象将捕获对它们执行的操作,并将它们附加到 Graph。
# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):
return (x > 0) * x
decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition
def decompose(model: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
"""
Decompose `model` into smaller constituent operations.
Currently,this only supports decomposing ReLU into its
mathematical definition: (x > 0) * x
"""
graph : fx.Graph = tracer_class().trace(model)
new_graph = fx.Graph()
env = {}
tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
for node in graph.nodes:
if node.op == 'call_function' and node.target in decomposition_rules:
# By wrapping the arguments with proxies,
# we can dispatch to the appropriate
# decomposition rule and implicitly add it
# to the Graph by symbolically tracing it.
proxy_args = [
fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
output_proxy = decomposition_rules[node.target](*proxy_args)
# Operations on `Proxy` always yield new `Proxy`s, and the
# return value of our decomposition rule is no exception.
# We need to extract the underlying `Node` from the `Proxy`
# to use it in subsequent iterations of this transform.
new_node = output_proxy.node
env[node.name] = new_node
else:
# Default case: we don't have a decomposition rule for this
# node, so just copy the node over into the new graph.
new_node = new_graph.node_copy(node, lambda x: env[x.name])
env[node.name] = new_node
return fx.GraphModule(model, new_graph)
除了避免显式图操作之外,使用 Proxy 还允许您将重写规则指定为本机 Python 代码。 对于需要大量重写规则的转换(例如 vmap 或 grad),这通常可以提高规则的可读性和可维护性。 请注意,在调用 Proxy 时,我们还传递了一个指向底层变量 graph 的 tracer。 这样做是为了防止图中的操作是 n 元的(例如,加法是一个二元运算符),从而避免调用 Proxy 创建 graph tracer 的多个实例,这可能会导致意外的运行时错误。 我们推荐这种使用 Proxy 的方法,尤其是当不能安全地假设底层运算符是一元的时候。
解释器模式¶
FX 中一种有用的代码组织模式是循环遍历 Graph 中的所有 Node 并执行它们。 这可以用于多种用途,包括运行时分析流经图的值或通过使用 Proxy 进行重新追踪来转换代码。 例如,假设我们想要运行一个 GraphModule 并记录在运行时看到的节点上的 torch.Tensor 形状和 dtype 属性。 如下所示:
import torch
import torch.fx
from torch.fx.node import Node
from typing import Dict
class ShapeProp:
"""
Shape propagation. This class takes a `GraphModule`.
Then, its `propagate` method executes the `GraphModule`
node-by-node with the given arguments. As each operation
executes, the ShapeProp class stores away the shape and
element type for the output values of each operation on
the `shape` and `dtype` attributes of the operation's
`Node`.
"""
def __init__(self, mod):
self.mod = mod
self.graph = mod.graph
self.modules = dict(self.mod.named_modules())
def propagate(self, *args):
args_iter = iter(args)
env : Dict[str, Node] = {}
def load_arg(a):
return torch.fx.graph.map_arg(a, lambda n: env[n.name])
def fetch_attr(target : str):
target_atoms = target.split('.')
attr_itr = self.mod
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
attr_itr = getattr(attr_itr, atom)
return attr_itr
for node in self.graph.nodes:
if node.op == 'placeholder':
result = next(args_iter)
elif node.op == 'get_attr':
result = fetch_attr(node.target)
elif node.op == 'call_function':
result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
elif node.op == 'call_method':
self_obj, *args = load_arg(node.args)
kwargs = load_arg(node.kwargs)
result = getattr(self_obj, node.target)(*args, **kwargs)
elif node.op == 'call_module':
result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
# This is the only code specific to shape propagation.
# you can delete this `if` branch and this becomes
# a generic GraphModule interpreter.
if isinstance(result, torch.Tensor):
node.shape = result.shape
node.dtype = result.dtype
env[node.name] = result
return load_arg(self.graph.result)
如您所见,FX 的完整解释器并不是那么复杂,但它非常有用。 为了简化此模式的使用,我们提供了 Interpreter 类,该类以一种可以通过方法重写来覆盖解释器执行的某些方面的方式包含上述逻辑。
除了执行操作之外,我们还可以通过将 Proxy 值通过解释器来生成新的 Graph。 同样,我们提供了 Transformer 类来包含此模式。 Transformer 的行为类似于 Interpreter,但是您不会调用 run 方法从 Module 获取具体的输出值,而是调用 Transformer.transform() 方法来返回一个新的 GraphModule,该 GraphModule 受您安装为重写方法的任何转换规则的约束。
调试¶
简介¶
通常,在编写转换的过程中,我们的代码并不完全正确。 在这种情况下,我们可能需要进行一些调试。 关键是向后工作:首先,检查调用生成的模块的结果以证明或反驳正确性。 然后,检查并调试生成的代码。 然后,调试导致生成代码的转换过程。
如果您不熟悉调试器,请参阅辅助部分 可用的调试器。
转换编写中的常见陷阱¶
非确定性
set迭代顺序。 在 Python 中,set数据类型是无序的。 例如,使用set来包含Node之类的对象集合可能会导致意外的非确定性。 一个例子是迭代一组Node以将它们插入到Graph中。 由于set数据类型是无序的,因此输出程序中操作的顺序将是不确定的,并且可能会在程序调用之间发生变化。 推荐的替代方法是使用dict数据类型,该数据类型从 Python 3.7 开始 (以及从 cPython 3.6 开始) 是 插入有序的。 通过将要删除重复项的值存储在dict的键中,可以将dict等效地用作集合。
检查模块的正确性¶
由于大多数深度学习模块的输出由浮点 torch.Tensor 实例组成,因此检查两个 torch.nn.Module 的结果之间的等效性并不像进行简单的相等性检查那么简单。 为了解释这一点,让我们使用一个例子
import torch
import torch.fx
import torchvision.models as models
def transform(m : torch.nn.Module) -> torch.nn.Module:
gm = torch.fx.symbolic_trace(m)
# Imagine we're doing some transforms here
# <...>
gm.recompile()
return gm
resnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)
input_image = torch.randn(5, 3, 224, 224)
assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
"""
在这里,我们尝试使用 == 相等运算符来检查两个深度学习模型的值的相等性。 但是,由于该运算符返回张量而不是布尔值,因此这是定义不明确的,而且由于浮点值的比较应使用误差幅度(或 epsilon)来解决浮点运算的非交换性(有关更多详细信息,请参见 此处)。 我们可以改用 torch.allclose(),它将考虑到相对和绝对容差阈值,从而给出近似比较
assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))
这是我们工具箱中的第一个工具,用于检查与参考实现相比,转换后的模块是否按预期运行。
调试生成的代码¶
由于 FX 在 GraphModules 上生成 forward() 函数,因此使用传统的调试技术(如 print 语句或 pdb)并不那么直接。幸运的是,我们可以使用几种技术来调试生成的代码。
使用 pdb¶
调用 pdb 来单步调试运行中的程序。虽然表示 Graph 的代码不在任何源文件中,但我们仍然可以在调用 forward 传递时使用 pdb 手动单步调试它。
import torch
import torch.fx
import torchvision.models as models
def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
graph = tracer_class().trace(inp)
# Transformation logic here
# <...>
# Return new Module
return fx.GraphModule(inp, graph)
my_module = models.resnet18()
my_module_transformed = my_pass(my_module)
input_value = torch.randn(5, 3, 224, 224)
# When this line is executed at runtime, we will be dropped into an
# interactive `pdb` prompt. We can use the `step` or `s` command to
# step into the execution of the next line
import pdb; pdb.set_trace()
my_module_transformed(input_value)
打印生成的代码¶
如果您想多次运行相同的代码,那么使用 pdb 单步调试到正确的代码可能会有点乏味。在这种情况下,一种方法是简单地将生成的 forward 传递复制粘贴到您的代码中,并从那里进行检查。
# Assume that `traced` is a GraphModule that has undergone some
# number of transforms
# Copy this code for later
print(traced)
# Print the code generated from symbolic tracing. This outputs:
"""
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
"""
# Subclass the original Module
class SubclassM(M):
def __init__(self):
super().__init__()
# Paste the generated `forward` function (the one we printed and
# copied above) here
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
# Create an instance of the original, untraced Module. Then, create an
# instance of the Module with the copied `forward` function. We can
# now compare the output of both the original and the traced version.
pre_trace = M()
post_trace = SubclassM()
使用 GraphModule 中的 to_folder 函数¶
GraphModule.to_folder() 是 GraphModule 中的一个方法,允许您将生成的 FX 代码转储到一个文件夹中。虽然像 打印生成的代码 中那样将 forward 传递复制到代码中通常就足够了,但使用 to_folder 检查模块和参数可能会更容易。
m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()
运行上面的示例后,我们可以查看 foo/module.py 中的代码,并根据需要进行修改(例如,添加 print 语句或使用 pdb)来调试生成的代码。
调试转换¶
现在我们已经确定转换正在创建不正确的代码,是时候调试转换本身了。首先,我们将检查文档中的 符号追踪的限制 部分。一旦我们验证了追踪工作正常,目标就是弄清楚在我们的 GraphModule 转换过程中出了什么问题。在 编写转换 中可能有一个快速的答案,但如果没有,有几种方法可以检查我们追踪的模块
# Sample Module
class M(torch.nn.Module):
def forward(self, x, y):
return x + y
# Create an instance of `M`
m = M()
# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a
# GraphModule, so we aren't showing any sample transforms for the
# sake of brevity.
traced = symbolic_trace(m)
# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):
add = x + y; x = y = None
return add
"""
# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%y : [num_users=1] = placeholder[target=y]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
return add
"""
# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode name target args kwargs
------------- ------ ----------------------- ------ --------
placeholder x x () {}
placeholder y y () {}
call_function add <built-in function add> (x, y) {}
output output output (add,) {}
"""
使用上面的实用函数,我们可以比较应用转换之前和之后追踪的模块。有时,一个简单的视觉比较就足以追踪到错误。如果仍然不清楚发生了什么,像 pdb 这样的调试器可能是一个好的下一步。
以上面的例子为例,考虑以下代码
# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
# Get the Graph from our traced Module
g = tracer_class().trace(module)
"""
Transformations on `g` go here
"""
return fx.GraphModule(module, g)
# Transform the Graph
transformed = transform_graph(traced)
# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)
使用上面的例子,假设对 print(traced) 的调用显示我们的转换中存在错误。我们想使用调试器找到出错的地方。我们启动一个 pdb 会话。我们可以通过在 transform_graph(traced) 上中断来查看转换期间发生了什么,然后按 s 来“单步进入”对 transform_graph(traced) 的调用。
我们也可以通过编辑 print_tabular 方法来打印 Graph 中 Node 的不同属性来获得好运。(例如,我们可能想查看 Node 的 input_nodes 和 users。)
可用的调试器¶
最常见的 Python 调试器是 pdb。您可以通过在命令行中输入 python -m pdb FILENAME.py 来以“调试模式”启动程序,其中 FILENAME 是您要调试的文件的名称。之后,您可以使用 pdb 调试器命令逐步移动您的运行程序。通常在启动 pdb 时设置断点 (b LINE-NUMBER),然后调用 c 来运行程序直到该点。这可以防止您必须单步执行每行代码(使用 s 或 n)才能到达您想要检查的代码部分。或者,您可以在要中断的行之前编写 import pdb; pdb.set_trace()。如果您添加 pdb.set_trace(),您的程序将在您运行它时自动以调试模式启动。(换句话说,您可以只在命令行中输入 python FILENAME.py 而不是 python -m pdb FILENAME.py。)一旦您以调试模式运行您的文件,您就可以单步执行代码并使用某些命令检查程序的内部状态。网上有很多关于 pdb 的优秀教程,包括 RealPython 的 “Python Debugging With Pdb”。
像 PyCharm 或 VSCode 这样的 IDE 通常内置了调试器。在您的 IDE 中,您可以选择 a) 通过在您的 IDE 中拉出一个终端窗口(例如,VSCode 中的 View → Terminal)来使用 pdb,或者 b) 使用内置的调试器(通常是 pdb 的图形化包装器)。
符号追踪的限制¶
FX 使用一种 符号追踪(也称为 符号执行)系统来捕获程序语义,以便进行转换/分析。该系统是 追踪 的,因为它执行程序(实际上是 torch.nn.Module 或函数)来记录操作。它是 符号 的,因为在此执行期间流经程序的数据不是真实数据,而是符号(FX 行话中的 Proxy)。
虽然符号追踪适用于大多数神经网络代码,但它有一些限制。
动态控制流¶
符号追踪的主要限制是它目前不支持动态控制流。也就是说,条件可能取决于程序输入值的循环或 if 语句。
例如,让我们检查以下程序
def func_to_trace(x):
if x.sum() > 0:
return torch.relu(x)
else:
return torch.neg(x)
traced = torch.fx.symbolic_trace(func_to_trace)
"""
<...>
File "dyn.py", line 6, in func_to_trace
if x.sum() > 0:
File "pytorch/torch/fx/proxy.py", line 155, in __bool__
return self.tracer.to_bool(self)
File "pytorch/torch/fx/proxy.py", line 85, in to_bool
raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""
if 语句的条件依赖于 x.sum() 的值,而 x.sum() 的值又依赖于 x 的值,x 是函数输入。由于 x 可以改变(即,如果您将新的输入张量传递给跟踪的函数),这就是动态控制流。回溯会向上回溯您的代码,以显示发生这种情况的位置。
静态控制流¶
另一方面,支持所谓的静态控制流。静态控制流是循环或 if 语句,其值在调用之间不会改变。通常,在 PyTorch 程序中,此控制流出现在基于超参数对模型架构做出决策的代码中。作为一个具体的例子
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self, do_activation : bool = False):
super().__init__()
self.do_activation = do_activation
self.linear = torch.nn.Linear(512, 512)
def forward(self, x):
x = self.linear(x)
# This if-statement is so-called static control flow.
# Its condition does not depend on any input values
if self.do_activation:
x = torch.relu(x)
return x
without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)
traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):
linear_1 = self.linear(x); x = None
return linear_1
"""
traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):
linear_1 = self.linear(x); x = None
relu_1 = torch.relu(linear_1); linear_1 = None
return relu_1
"""
if-statement if self.do_activation 不依赖于任何函数输入,因此它是静态的。do_activation 可以被认为是一个超参数,并且具有该参数不同值的 MyModule 的不同实例的追踪具有不同的代码。这是一种有效的模式,符号追踪支持这种模式。
许多动态控制流的实例在语义上是静态控制流。通过消除对输入值的数据依赖性,例如通过将值移动到 Module 属性或在符号追踪期间将具体值绑定到参数,可以使这些实例支持符号追踪。
def f(x, flag):
if flag: return x
else: return x*2
fx.symbolic_trace(f) # Fails!
fx.symbolic_trace(f, concrete_args={'flag': True})
对于真正的动态控制流,包含此代码的程序部分可以被追踪为对 Method 的调用(请参阅 使用 Tracer 类自定义追踪)或函数(请参阅 wrap()),而不是通过它们进行追踪。
非 torch 函数¶
FX 使用 __torch_function__ 作为拦截调用的机制 (更多信息请参阅 技术概述)。某些函数,例如内置的 Python 函数或 math 模块中的函数,不在 __torch_function__ 的涵盖范围内,但我们仍然希望在符号追踪中捕获它们。例如:
import torch
import torch.fx
from math import sqrt
def normalize(x):
"""
Normalize `x` by the size of the batch dimension
"""
return x / sqrt(len(x))
# It's valid Python code
normalize(torch.rand(3, 4))
traced = torch.fx.symbolic_trace(normalize)
"""
<...>
File "sqrt.py", line 9, in normalize
return x / sqrt(len(x))
File "pytorch/torch/fx/proxy.py", line 161, in __len__
raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""
该错误告诉我们内置函数 len 不被支持。 我们可以使用 wrap() API,将此类函数记录到追踪中作为直接调用。
torch.fx.wrap('len')
torch.fx.wrap('sqrt')
traced = torch.fx.symbolic_trace(normalize)
print(traced.code)
"""
import math
def forward(self, x):
len_1 = len(x)
sqrt_1 = math.sqrt(len_1); len_1 = None
truediv = x / sqrt_1; x = sqrt_1 = None
return truediv
"""
使用 Tracer 类自定义追踪¶
Tracer 类是 symbolic_trace 实现的基础类。 可以通过子类化 Tracer 来自定义跟踪的行为,如下所示:
class MyCustomTracer(torch.fx.Tracer):
# Inside here you can override various methods
# to customize tracing. See the `Tracer` API
# reference
pass
# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):
def forward(self, x):
return torch.relu(x) + torch.ones(3, 4)
mod = MyModule()
traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a
# GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)
叶子模块¶
叶子模块是在符号追踪中显示为调用而不是被追踪的模块。 默认的叶子模块集合是标准的 torch.nn 模块实例集合。例如:
class MySpecialSubmodule(torch.nn.Module):
def forward(self, x):
return torch.neg(x)
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 4)
self.submod = MySpecialSubmodule()
def forward(self, x):
return self.submod(self.linear(x))
traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced though.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):
linear_1 = self.linear(x); x = None
neg_1 = torch.neg(linear_1); linear_1 = None
return neg_1
"""
可以通过重写 Tracer.is_leaf_module() 自定义叶子模块集合。
其他¶
目前无法追踪 Tensor 构造函数 (例如
torch.zeros、torch.ones、torch.rand、torch.randn、torch.sparse_coo_tensor)。可以使用确定性构造函数 (
zeros、ones),并且它们生成的值将作为常量嵌入到追踪中。 只有当这些构造函数的参数引用动态输入大小时,这才会出现问题。 在这种情况下,ones_like或zeros_like可能是可行的替代方案。非确定性构造函数 (
rand、randn) 将在追踪中嵌入单个随机值。 这可能不是预期的行为。 一种解决方法是将torch.randn包装在torch.fx.wrap函数中,然后调用该函数。
@torch.fx.wrap def torch_randn(x, shape): return torch.randn(shape) def f(x): return x + torch_randn(x, 5) fx.symbolic_trace(f)
此行为可能会在未来的版本中修复。
类型注解
支持 Python 3 风格的类型注解 (例如
func(x : torch.Tensor, y : int) -> torch.Tensor),并且符号追踪将保留它们。目前不支持 Python 2 风格的注释类型注解
# type: (torch.Tensor, int) -> torch.Tensor。目前不支持函数内局部名称上的注解。
关于
training标志和子模块的注意事项当使用函数式编程 (如
torch.nn.functional.dropout) 时,通常会将 training 参数作为self.training传入。 在 FX 追踪期间,这很可能会作为常量值被烘焙到其中。
import torch import torch.fx class DropoutRepro(torch.nn.Module): def forward(self, x): return torch.nn.functional.dropout(x, training=self.training) traced = torch.fx.symbolic_trace(DropoutRepro()) print(traced.code) """ def forward(self, x): dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False); x = None return dropout """ traced.eval() x = torch.randn(5, 3) torch.testing.assert_close(traced(x), x) """ AssertionError: Tensor-likes are not close! Mismatched elements: 15 / 15 (100.0%) Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed) Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed) """
但是,当使用标准的
nn.Dropout()子模块时,training 标志会被封装起来,并且由于保留了nn.Module对象模型,因此可以进行更改。
class DropoutRepro2(torch.nn.Module): def __init__(self): super().__init__() self.drop = torch.nn.Dropout() def forward(self, x): return self.drop(x) traced = torch.fx.symbolic_trace(DropoutRepro2()) print(traced.code) """ def forward(self, x): drop = self.drop(x); x = None return drop """ traced.eval() x = torch.randn(5, 3) torch.testing.assert_close(traced(x), x)
由于这种差异,请考虑将与
training标志动态交互的模块标记为叶子模块。
API 参考¶
- torch.fx.symbolic_trace(root, concrete_args=None)[source][source]¶
符号追踪 API
给定一个
nn.Module或函数实例root,此函数将返回一个GraphModule,该GraphModule是通过记录追踪root时看到的操作来构造的。concrete_args允许您部分专门化您的函数,无论是删除控制流还是数据结构。例如:
def f(a, b): if b == True: return a else: return a * 2
由于控制流的存在,FX 通常无法追踪此函数。 但是,我们可以使用 concrete_args 来专门化 b 的值以进行追踪。
f = fx.symbolic_trace(f, concrete_args={"b": False}) assert f(3, False) == 6
请注意,尽管您仍然可以传入不同的 b 值,但它们将被忽略。
我们还可以使用 concrete_args 从函数中消除数据结构处理。 这将使用 pytrees 来展平您的输入。 为避免过度专门化,请为不应专门化的值传入 fx.PH。 例如:
def f(x): out = 0 for v in x.values(): out += v return out f = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}}) assert f({"a": 1, "b": 2, "c": 4}) == 7
- 参数
root (Union[torch.nn.Module, Callable]) – 要被追踪并转换为 Graph 表示形式的模块或函数。
concrete_args (Optional[Dict[str, any]]) – 要部分专门化的输入
- 返回值
从
root记录的操作创建的模块。- 返回类型
注意
保证此 API 的向后兼容性。
- torch.fx.wrap(fn_or_name)[source][source]¶
可以在模块级范围调用此函数,以将 fn_or_name 注册为“叶子函数”。 “叶子函数”将被保留为 FX 追踪中的 CallFunction 节点,而不是被追踪。
# foo/bar/baz.py def my_custom_function(x, y): return x * x + y * y torch.fx.wrap("my_custom_function") def fn_to_be_traced(x, y): # When symbolic tracing, the below call to my_custom_function will be inserted into # the graph rather than tracing it. return my_custom_function(x, y)
此函数也可以等效地用作装饰器
# foo/bar/baz.py @torch.fx.wrap def my_custom_function(x, y): return x * x + y * y
可以将包装的函数视为“叶子函数”,类似于“叶子模块”的概念,也就是说,它们是在 FX 追踪中保留为调用的函数,而不是被追踪的函数。
- 参数
fn_or_name (Union[str, Callable]) – 当调用时要插入到图中的函数或全局函数的名称
注意
保证此 API 的向后兼容性。
- class torch.fx.GraphModule(*args, **kwargs)[source][source]¶
GraphModule 是从 fx.Graph 生成的 nn.Module。 Graphmodule 具有
graph属性,以及从该graph生成的code和forward属性。警告
当
graph被重新赋值时,code和forward将自动重新生成。 但是,如果在不重新分配graph属性本身的情况下编辑graph的内容,则必须调用recompile()来更新生成的代码。注意
保证此 API 的向后兼容性。
- __init__(root, graph, class_name='GraphModule')[source][source]¶
构造一个 GraphModule。
- 参数
root (Union[torch.nn.Module, Dict[str, Any]) –
root可以是一个 nn.Module 实例,也可以是一个将字符串映射到任意属性类型的字典 (Dict)。 如果root是一个 Module,则 Graph 的 Nodes 的target字段中对基于 Module 的对象的任何引用(通过限定名称)都将从root的 Module 层级结构中的相应位置复制到 GraphModule 的模块层级结构中。 如果root是一个字典,则 Node 的target中找到的限定名称将在字典的键中直接查找。 字典映射到的对象将被复制到 GraphModule 的模块层级结构中的适当位置。graph (Graph) –
graph包含此 GraphModule 应用于代码生成的节点。class_name (str) –
name表示此 GraphModule 的名称,用于调试目的。 如果未设置,则所有错误消息都将报告为源自GraphModule。 将此设置为root的原始名称或在您的转换上下文中合理的名称可能会有所帮助。
注意
保证此 API 的向后兼容性。
- add_submodule(target, m)[source][source]¶
将给定的子模块添加到
self。如果
target的子路径中尚不存在 Module,则会安装空的 Module。- 参数
- 返回值
- 是否可以插入子模块。 对于
此方法返回 True,
target表示的链中的每个对象必须满足以下任一条件:a) 尚不存在,或 b) 引用一个nn.Module(而不是参数或其他属性)。
- 返回类型
注意
保证此 API 的向后兼容性。
- delete_all_unused_submodules()[source][source]¶
从
self中删除所有未使用的子模块。如果以下任一项为真,则认为一个 Module “已使用”: 1. 它具有已使用的子模块 2. 其 forward 通过
call_module节点直接调用 3. 它具有一个非 Module 属性,该属性从get_attr节点使用。可以调用此方法来清理一个
nn.Module,而无需手动在每个未使用的子模块上调用delete_submodule。注意
保证此 API 的向后兼容性。
- delete_submodule(target)[source][source]¶
从
self中删除给定的子模块。如果
target不是有效目标,则不会删除该模块。- 参数
target (str) – 新子模块的完全限定字符串名称(有关如何指定完全限定字符串,请参见
nn.Module.get_submodule中的示例)。- 返回值
- 目标字符串是否引用了
我们要删除的子模块。 返回值
False表示target不是对子模块的有效引用。
- 返回类型
注意
保证此 API 的向后兼容性。
- print_readable(print_output=True, include_stride=False, include_device=False, colored=False)[source][source]¶
返回为当前 GraphModule 及其子 GraphModule 生成的 Python 代码
警告
此 API 尚处于实验阶段,并且不向后兼容。
- class torch.fx.Graph(owning_module=None, tracer_cls=None, tracer_extras=None)[source][source]¶
Graph是 FX 中间表示中使用的主要数据结构。 它由一系列Node组成,每个Node代表调用点(或其他语法结构)。Node的列表共同构成一个有效的 Python 函数。例如,以下代码
import torch import torch.fx class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.param = torch.nn.Parameter(torch.rand(3, 4)) self.linear = torch.nn.Linear(4, 5) def forward(self, x): return torch.topk( torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3 ) m = MyModule() gm = torch.fx.symbolic_trace(m)
将产生以下 Graph
print(gm.graph)
graph(x): %linear_weight : [num_users=1] = self.linear.weight %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) return topk_1有关
Graph中表示的操作的语义,请参见Node。注意
保证此 API 的向后兼容性。
- __init__(owning_module=None, tracer_cls=None, tracer_extras=None)[source][source]¶
构造一个空的 Graph。
注意
保证此 API 的向后兼容性。
- call_function(the_function, args=None, kwargs=None, type_expr=None)[source][source]¶
插入一个
call_functionNode到Graph中。一个call_function节点表示对一个 Python 可调用对象(callable)的调用,由the_function指定。- 参数
the_function (Callable[..., Any]) – 要调用的函数。可以是任何 PyTorch 运算符、Python 函数或
builtins或operator命名空间的成员。args (Optional[Tuple[Argument, ...]]) – 要传递给被调用函数的位置参数。
kwargs (Optional[Dict[str, Argument]]) – 要传递给被调用函数的关键字参数。
type_expr (Optional[Any]) – 可选的类型注解,表示此节点的输出将具有的 Python 类型。
- 返回值
新创建并插入的
call_function节点。- 返回类型
注意
与
Graph.create_node()一样,此方法适用相同的插入点和类型表达式规则。注意
保证此 API 的向后兼容性。
- call_method(method_name, args=None, kwargs=None, type_expr=None)[source][source]¶
插入一个
call_methodNode到Graph中。一个call_method节点表示对args的第 0 个元素调用给定方法。- 参数
method_name (str) – 要应用于 self 参数的方法的名称。 例如,如果 args[0] 是表示
Tensor的Node,那么要对该Tensor调用relu(),请将relu传递给method_name。args (Optional[Tuple[Argument, ...]]) – 要传递给被调用方法的位置参数。 请注意,这应该包括一个
self参数。kwargs (Optional[Dict[str, Argument]]) – 要传递给被调用方法的关键字参数。
type_expr (Optional[Any]) – 可选的类型注解,表示此节点的输出将具有的 Python 类型。
- 返回值
新创建并插入的
call_method节点。- 返回类型
注意
与
Graph.create_node()一样,此方法适用相同的插入点和类型表达式规则。注意
保证此 API 的向后兼容性。
- call_module(module_name, args=None, kwargs=None, type_expr=None)[source][source]¶
插入一个
call_moduleNode到Graph中。一个call_module节点表示在Module层次结构中调用Module的 forward() 函数。- 参数
module_name (str) – 要调用的
Module在Module层次结构中的限定名称。 例如,如果跟踪的Module有一个名为foo的子模块,该子模块有一个名为bar的子模块,则应将限定名称foo.bar作为module_name传递以调用该模块。args (Optional[Tuple[Argument, ...]]) – 要传递给被调用方法的位置参数。 请注意,这不应包括一个
self参数。kwargs (Optional[Dict[str, Argument]]) – 要传递给被调用方法的关键字参数。
type_expr (Optional[Any]) – 可选的类型注解,表示此节点的输出将具有的 Python 类型。
- 返回值
新创建并插入的
call_module节点。- 返回类型
注意
与
Graph.create_node()一样,此方法适用相同的插入点和类型表达式规则。注意
保证此 API 的向后兼容性。
- create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)[source][source]¶
创建一个
Node并将其添加到Graph中的当前插入点。 请注意,可以通过Graph.inserting_before()和Graph.inserting_after()设置当前插入点。- 参数
op (str) – 此节点的 opcode。 为 'call_function'、'call_method'、'get_attr'、'call_module'、'placeholder' 或 'output' 之一。 这些 opcodes 的语义在
Graphdocstring 中描述。args (Optional[Tuple[Argument, ...]]) – 是此节点的参数元组。
kwargs (Optional[Dict[str, Argument]]) – 此节点的 kwargs。
name (Optional[str]) –
Node的可选字符串名称。 这将影响 Python 生成的代码中分配给值的名称。type_expr (Optional[Any]) – 可选的类型注解,表示此节点的输出将具有的 Python 类型。
- 返回值
新创建并插入的节点。
- 返回类型
注意
保证此 API 的向后兼容性。
- eliminate_dead_code(is_impure_node=None)[source][source]¶
根据每个节点的用户数量以及节点是否具有任何副作用,从图中删除所有无用代码。 在调用之前,必须对图进行拓扑排序。
- 参数
- 返回值
该 pass 是否导致图发生改变。
- 返回类型
示例
在消除无效代码之前,下面的 a = x + 1 中的 a 没有使用者,因此可以从图中消除,而不会产生影响。
def forward(self, x): a = x + 1 return x + self.attr_1
消除无效代码后,a = x + 1 已被移除,forward 的其余部分保留。
def forward(self, x): return x + self.attr_1
警告
无效代码消除有一些启发式方法来避免移除具有副作用的节点 (参见 Node.is_impure),但总的来说覆盖率非常差,因此您应该假定此方法调用是不安全的,除非您知道您的 FX 图完全由功能性操作组成,或者您提供自己的自定义函数来检测具有副作用的节点。
注意
保证此 API 的向后兼容性。
- erase_node(to_erase)[source][source]¶
从
Graph中擦除一个Node。如果该节点在Graph中仍然有使用者,则抛出异常。- 参数
to_erase (Node) – 要从
Graph中擦除的Node。
注意
保证此 API 的向后兼容性。
- find_nodes(*, op, target=None, sort=True)[source][source]¶
允许快速查询节点
- 参数
- 返回值
具有请求的 op 和 target 的节点的迭代器。
警告
此 API 尚处于实验阶段,并且不向后兼容。
- get_attr(qualified_name, type_expr=None)[source][source]¶
在 Graph 中插入一个
get_attr节点。get_attrNode表示从Module层次结构中获取一个属性。- 参数
qualified_name (str) – 要检索的属性的完全限定名称。例如,如果跟踪的 Module 有一个名为
foo的子模块,该子模块有一个名为bar的子模块,该子模块有一个名为baz的属性,则应将限定名称foo.bar.baz作为qualified_name传递。type_expr (Optional[Any]) – 可选的类型注解,表示此节点的输出将具有的 Python 类型。
- 返回值
新创建并插入的
get_attr节点。- 返回类型
注意
此方法与
Graph.create_node应用相同的插入点和类型表达式规则。注意
保证此 API 的向后兼容性。
- graph_copy(g, val_map, return_output_node=False)[source][source]¶
将给定图中的所有节点复制到
self中。- 参数
- 返回值
self中的值现在等同于g中的输出值,如果g有一个output节点。否则为None。- 返回类型
Optional[Union[tuple[‘Argument’, …], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]
注意
保证此 API 的向后兼容性。
- inserting_after(n=None)[source][source]¶
- 设置 create_node 和伴随方法将插入到图中的点。
在 'with' 语句中使用时,这将暂时设置插入点,然后在 with 语句退出时恢复它。
with g.inserting_after(n): ... # inserting after node n ... # insert point restored to what it was previously g.inserting_after(n) # set the insert point permanently
参数
- n (Optional[Node]): 要在其后插入的节点。如果为 None,则将在整个图的开头之后插入。
整个图的开头。
- 返回值
一个资源管理器,它将在
__exit__上恢复插入点。
注意
保证此 API 的向后兼容性。
- inserting_before(n=None)[source][source]¶
- 设置 create_node 和伴随方法将插入到图中的点。
在 'with' 语句中使用时,这将暂时设置插入点,然后在 with 语句退出时恢复它。
with g.inserting_before(n): ... # inserting before node n ... # insert point restored to what it was previously g.inserting_before(n) # set the insert point permanently
参数
- n (Optional[Node]): 要在其之前插入的节点。如果为 None,则将在之前插入
整个图的开头。
- 返回值
一个资源管理器,它将在
__exit__上恢复插入点。
注意
保证此 API 的向后兼容性。
- lint()[source][source]¶
对该图运行各种检查,以确保其格式正确。具体来说:- 检查节点是否具有正确的所有权(属于此图)- 检查节点是否按拓扑顺序出现 - 如果此图拥有 GraphModule,则检查目标是否存在于该 GraphModule 中
注意
保证此 API 的向后兼容性。
- node_copy(node, arg_transform=<function Graph.<lambda>>)[source][source]¶
将节点从一个图复制到另一个图。
arg_transform需要将参数从节点所在的图转换为 self 所在的图。 示例# Copying all the nodes in `g` into `new_graph` g: torch.fx.Graph = ... new_graph = torch.fx.graph() value_remap = {} for node in g.nodes: value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
- 参数
- 返回类型
注意
保证此 API 的向后兼容性。
- property nodes: _node_list¶
获取构成此图的节点列表。
请注意,此
Node列表表示形式是一个双向链表。 迭代期间的修改(例如,删除节点,添加节点)是安全的。- 返回值
节点的双向链表。 请注意,可以在此列表上调用
reversed以切换迭代顺序。
- on_generate_code(make_transformer)[source][source]¶
注册一个转换器函数,在生成python代码时使用
- 参数
- make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc])
一个返回要注册的代码转换器的函数。 该函数由 on_generate_code 调用以获取代码转换器。
此函数还将其当前注册的代码转换器(如果未注册任何内容,则为 None)作为其输入给出,以防不希望覆盖它。 这对于将代码转换器链接在一起很有用。
- 返回值
一个上下文管理器,当在 with 语句中使用时,会自动恢复先前注册的代码转换器。
示例
gm: fx.GraphModule = ... # This is a code transformer we want to register. This code # transformer prepends a pdb import and trace statement at the very # beginning of the generated torch.fx code to allow for manual # debugging with the PDB library. def insert_pdb(body): return ["import pdb; pdb.set_trace()\n", *body] # Registers `insert_pdb`, and overwrites the current registered # code transformer (given by `_` to the lambda): gm.graph.on_generate_code(lambda _: insert_pdb) # Or alternatively, registers a code transformer which first # runs `body` through existing registered transformer, then # through `insert_pdb`: gm.graph.on_generate_code( lambda current_trans: ( lambda body: insert_pdb(current_trans(body) if current_trans else body) ) ) gm.recompile() gm(*inputs) # drops into pdb
此函数也可以用作上下文管理器,其优点是自动恢复先前注册的代码转换器
# ... continue from previous example with gm.graph.on_generate_code(lambda _: insert_pdb): # do more stuff with `gm`... gm.recompile() gm(*inputs) # drops into pdb # now previous code transformer is restored (but `gm`'s code with pdb # remains - that means you can run `gm` with pdb here too, until you # run next `recompile()`).
警告
此 API 尚处于实验阶段,并且不向后兼容。
- output(result, type_expr=None)[source][source]¶
将
outputNode插入到Graph中。output节点表示 Python 代码中的return语句。result是应返回的值。- 参数
result (Argument) – 要返回的值。
type_expr (Optional[Any]) – 可选的类型注解,表示此节点的输出将具有的 Python 类型。
注意
此方法与
Graph.create_node应用相同的插入点和类型表达式规则。注意
保证此 API 的向后兼容性。
- placeholder(name, type_expr=None, default_value)[source][source]¶
将
placeholder节点插入到 Graph 中。placeholder表示函数输入。- 参数
name (str) – 输入值的名称。 这对应于此
Graph表示的函数的 positional 参数的名称。type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点的输出将具有的 Python 类型。 在某些情况下,这对于正确的代码生成是必需的(例如,当函数随后在 TorchScript 编译中使用时)。
default_value (Any) – 此函数参数应采用的默认值。 注意:为了允许 None 作为默认值,应将 inspect.Signature.empty 作为此参数传递,以指定该参数_没有_默认值。
- 返回类型
注意
此方法与
Graph.create_node应用相同的插入点和类型表达式规则。注意
保证此 API 的向后兼容性。
- python_code(root_module, *, verbose=False, include_stride=False, include_device=False, colored=False)[source][source]¶
将此
Graph转换为有效的 Python 代码。- 参数
root_module (str) – 用于查找限定名称目标的根模块的名称。 这通常是“self”。
- 返回值
src:表示对象的 Python 源代码 globals:src 中全局名称的字典 -> 它们引用的对象。
- 返回类型
一个 PythonCode 对象,由两个字段组成
注意
保证此 API 的向后兼容性。
- class torch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)[source][source]¶
Node是表示Graph中各个操作的数据结构。在大多数情况下,Nodes 表示对各种实体(如运算符、方法和模块)的调用位置(一些例外情况包括指定函数输入和输出的节点)。每个Node都有一个由其op属性指定的函数。op的每个值的Node语义如下:placeholder表示函数输入。name属性指定此值将采用的名称。target类似地是参数的名称。args包含:1) 空;或者 2) 表示函数输入的默认参数的单个参数。kwargs可以忽略。占位符对应于图形打印输出中的函数参数(例如,x)。get_attr从模块层次结构中检索参数。name类似地是将获取结果分配到的名称。target是参数在模块层次结构中的位置的完全限定名称。args和kwargs可以忽略。call_function将一个自由函数应用于某些值。name类似地是分配给值的名称。target是要应用的函数。args和kwargs表示函数的参数,遵循 Python 调用约定。call_module将模块层次结构中的模块的forward()方法应用于给定的参数。name与之前相同。target是要调用的模块在模块层次结构中的完全限定名称。args和kwargs表示调用模块的参数,不包括 self 参数。call_method调用值上的方法。name相似。target是应用于self参数的方法的字符串名称。args和kwargs表示调用模块的参数,包括 self 参数。output在其args[0]属性中包含跟踪函数的输出。 这对应于图形打印输出中的“return”语句。
注意
保证此 API 的向后兼容性。
- property all_input_nodes: list['Node']¶
返回作为此节点输入的全部节点。 这等效于迭代
args和kwargs,并且仅收集作为节点的值。- 返回值
出现在此
Node的args和kwargs中的Nodes列表,按该顺序排列。
- append(x)[source][source]¶
在此节点之后,在图形的节点列表中插入
x。 等效于self.next.prepend(x)- 参数
x (Node) – 要放在此节点之后的节点。 必须是同一图形的成员。
注意
保证此 API 的向后兼容性。
- property args: tuple[Union[tuple['Argument', ...], collections.abc.Sequence['Argument'], collections.abc.Mapping[str, 'Argument'], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType], ...]¶
这个
Node的参数元组。参数的解释取决于节点的 opcode。 有关更多信息,请参阅Node文档字符串。允许对此属性进行赋值。 所有使用和用户的统计信息都会在赋值时自动更新。
- format_node(placeholder_names=None, maybe_return_typename=None)[source][source]¶
返回
self的描述性字符串表示形式。此方法可以不带参数使用,作为调试实用程序。
此函数也在
Graph的__str__方法内部使用。placeholder_names和maybe_return_typename中的字符串共同构成了此 Graph 的周围 GraphModule 中自动生成的forward函数的签名。 否则不应使用placeholder_names和maybe_return_typename。- 参数
- 返回值
- 如果 1) 我们将
format_node用作内部辅助函数 在
Graph的__str__方法中,并且 2)self是一个占位符节点,则返回None。 否则,返回当前节点的描述性字符串表示形式。
- 如果 1) 我们将
- 返回类型
注意
保证此 API 的向后兼容性。
- insert_arg(idx, arg)[source][source]¶
使用给定的索引将一个位置参数插入到参数列表中。
- 参数
idx (int) – 要在其之前插入元素的
self.args中的元素的索引。arg (Argument) – 要插入到
args中的新参数值
注意
保证此 API 的向后兼容性。
- is_impure()[source][source]¶
返回此操作是否为不纯的,即如果它的操作是占位符或输出,或者如果 call_function 或 call_module 是不纯的。
- 返回值
该操作是否为不纯。
- 返回类型
警告
此 API 尚处于实验阶段,并且不向后兼容。
- property kwargs: dict[str, Union[tuple['Argument', ...], collections.abc.Sequence['Argument'], collections.abc.Mapping[str, 'Argument'], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType]]¶
这个
Node的关键字参数的字典。参数的解释取决于节点的 opcode。 有关更多信息,请参阅Node文档字符串。允许对此属性进行赋值。 所有使用和用户的统计信息都会在赋值时自动更新。
- normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)[source][source]¶
返回 Python 目标的标准化参数。 这意味着 args/kwargs 将与模块/函数的签名匹配,如果 normalize_to_only_use_kwargs 为真,则仅以位置顺序返回 kwargs。 还会填充默认值。 不支持仅位置参数或 varargs 参数。
支持模块调用。
可能需要 arg_types 和 kwarg_types 以消除重载的歧义。
- 参数
root (torch.nn.Module) – 用于解析模块目标的模块。
arg_types (Optional[Tuple[Any]]) – args 的 arg 类型元组
kwarg_types (Optional[Dict[str, Any]]) – kwargs 的 arg 类型字典
normalize_to_only_use_kwargs (bool) – 是否标准化为仅使用 kwargs。
- 返回值
如果成功,则返回 NamedTuple ArgsKwargsPair 或 None。
- 返回类型
Optional[ArgsKwargsPair]
警告
此 API 尚处于实验阶段,并且不向后兼容。
- prepend(x)[源代码][源代码]¶
将 x 插入到图中节点列表中的此节点之前。 示例
Before: p -> self bx -> x -> ax After: p -> x -> self bx -> ax
- 参数
x (Node) – 要放在此节点之前的节点。 必须是同一图的成员。
注意
保证此 API 的向后兼容性。
- replace_all_uses_with(replace_with, delete_user_cb=<function Node.<lambda>>, *, propagate_meta=False)[源代码][源代码]¶
将图中
self的所有用法替换为节点replace_with。- 参数
- 返回值
进行此更改的节点列表。
- 返回类型
list[‘Node’]
注意
保证此 API 的向后兼容性。
- replace_input_with(old_input, new_input)[源代码][源代码]¶
遍历
self的输入节点,并将所有old_input实例替换为new_input。注意
保证此 API 的向后兼容性。
- property stack_trace: Optional[str]¶
返回在跟踪期间记录的 Python 堆栈跟踪(如果有)。 使用 fx.Tracer 跟踪时,此属性通常由 Tracer.create_proxy 填充。 要在跟踪期间记录堆栈跟踪以进行调试,请在 Tracer 实例上设置 record_stack_traces = True。 使用 dynamo 跟踪时,此属性将默认由 OutputGraph.create_proxy 填充。
stack_trace 的最内层帧将位于字符串的末尾。
- class torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())[源代码][源代码]¶
Tracer是实现torch.fx.symbolic_trace的符号跟踪功能的类。 调用symbolic_trace(m)等效于Tracer().trace(m)。可以对 Tracer 进行子类化以覆盖跟踪过程的各种行为。 可以在此类的方法的文档字符串中找到可覆盖的不同行为的描述。
注意
保证此 API 的向后兼容性。
- call_module(m, forward, args, kwargs)[源代码][源代码]¶
此方法指定
Tracer在遇到对nn.Module实例的调用时的行为。默认情况下,该行为是检查被调用的模块是否是叶模块,通过
is_leaf_module。 如果是,则发出一个call_module节点,该节点在Graph中引用m。 否则,正常调用Module,跟踪其forward函数中的操作。可以覆盖此方法来创建嵌套的跟踪 GraphModule,或者在跨
Module边界进行跟踪时所需的任何其他行为。- 参数
m (Module) – 要为其发出调用的模块
forward (Callable) – 要调用的
Module的 forward() 方法args (Tuple) – 模块调用站点的 args
kwargs (Dict) – 模块调用站点的 kwargs
- 返回值
Module 调用的返回值。 在发出
call_module节点的情况下,这是一个Proxy值。 否则,它是从Module调用返回的任何值。- 返回类型
注意
保证此 API 的向后兼容性。
- create_arg(a)[源代码][源代码]¶
指定在准备用作
Graph中节点参数的值时,追踪行为的方法。默认行为包括:
遍历集合类型(例如,tuple,list,dict),并在元素上递归调用
create_args。给定一个 Proxy 对象,返回对底层 IR
Node的引用。给定一个非 Proxy Tensor 对象,为各种情况发出 IR。
对于 Parameter,发出一个引用该 Parameter 的
get_attr节点。对于非 Parameter Tensor,将 Tensor 存储在一个特殊的属性中,该属性引用该属性。
此方法可以被覆盖以支持更多类型。
- 参数
a (Any) – 要作为
Graph中的Argument发出的值。- 返回值
值
a转换为适当的Argument- 返回类型
Argument
注意
保证此 API 的向后兼容性。
- create_args_for_root(root_fn, is_module, concrete_args=None)[source][source]¶
创建与
rootModule 的签名相对应的placeholder节点。此方法内省 root 的签名并相应地发出这些节点,也支持*args和**kwargs。警告
此 API 尚处于实验阶段,并且不向后兼容。
- create_node(kind, target, args, kwargs, name=None, type_expr=None)[source]¶
根据 target、args、kwargs 和 name 插入一个 graph 节点。
此方法可以被覆盖以对节点创建中使用的值进行额外的检查、验证或修改。例如,可能希望禁止记录原地操作。
注意
保证此 API 的向后兼容性。
- 返回类型
- create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)[source]¶
从给定的参数创建一个 Node,然后返回包装在 Proxy 对象中的 Node。
如果 kind = ‘placeholder’,那么我们正在创建一个表示函数参数的 Node。如果我们需要编码一个默认参数,我们使用
args元组。args对于placeholder节点通常为空。注意
保证此 API 的向后兼容性。
- get_fresh_qualname(prefix)[source][source]¶
获取一个 prefix 的新名称并返回它。此函数确保它不会与图上的现有属性冲突。
注意
保证此 API 的向后兼容性。
- 返回类型
- getattr(attr, attr_val, parameter_proxy_cache)[source][source]¶
指定此
Tracer在我们对nn.Module实例的调用调用 getattr 时的行为的方法。默认情况下,该行为是返回属性的代理值。它还将代理值存储在
parameter_proxy_cache中,以便将来的调用将重用代理而不是创建一个新的代理。可以覆盖此方法,例如,在查询参数时不返回代理。
- 参数
- 返回值
getattr 调用的返回值。
警告
此 API 尚处于实验阶段,并且不向后兼容。
- is_leaf_module(m, module_qualified_name)[source][source]¶
指定给定
nn.Module是否为“叶子”模块的方法。叶子模块是出现在 IR 中的原子单元,由
call_module调用引用。默认情况下,PyTorch 标准库命名空间 (torch.nn) 中的模块是叶子模块。所有其他模块都将被追踪,它们的组成操作将被记录,除非通过此参数另有规定。- 参数
- 返回类型
注意
保证此 API 的向后兼容性。
- iter(obj)[source]¶
- 当代理对象被迭代时调用,例如
当在控制流中使用时。通常我们不知道该怎么做,因为我们不知道代理的值,但是自定义追踪器可以使用 create_node 将更多信息附加到 graph 节点,并且可以选择返回一个迭代器。
注意
保证此 API 的向后兼容性。
- 返回类型
- keys(obj)[source]¶
- 当代理对象调用了 keys() 方法时调用。
这是当 ** 在代理上调用时发生的事情。如果 ** 应该在你的自定义追踪器中工作,则应该返回一个迭代器。
注意
保证此 API 的向后兼容性。
- 返回类型
- path_of_module(mod)[source][source]¶
辅助方法,用于查找
mod在root的模块层次结构中的限定名称。 例如,如果root有一个名为foo的子模块,该子模块有一个名为bar的子模块,则将bar传递给此函数将返回字符串“foo.bar”。注意
保证此 API 的向后兼容性。
- to_bool(obj)[source]¶
- 当代理对象被转换为布尔值时调用,例如
在控制流中使用时。 通常我们不知道该怎么做,因为我们不知道代理的值,但是自定义追踪器可以使用 create_node 将更多信息附加到图节点,并且可以选择返回一个值。
注意
保证此 API 的向后兼容性。
- 返回类型
- class torch.fx.Proxy(node, tracer=None)[source][source]¶
Proxy对象是Node包装器,在符号追踪期间流经程序,并将它们接触到的所有操作(torch函数调用、方法调用、运算符)记录到增长的 FX 图中。如果您正在进行图转换,您可以将您自己的
Proxy方法包装在一个原始的Node周围,以便您可以使用重载的运算符向Graph添加其他内容。Proxy对象无法迭代。 换句话说,如果Proxy在循环中使用或用作*args/**kwargs函数参数,则符号追踪器将抛出错误。有两种主要方法可以解决这个问题:1. 将无法追踪的逻辑分解为顶层函数,并在其上使用
fx.wrap。 2. 如果控制流是静态的(即,循环行程计数基于某些超参数),则可以将代码保留在其原始位置并重构为类似的内容for i in range(self.some_hyperparameter): indexed_item = proxied_value[i]
有关 Proxy 内部结构的更详细说明,请查看 torch/fx/README.md 中的“Proxy”部分
注意
保证此 API 的向后兼容性。
- class torch.fx.Interpreter(module, garbage_collect_values=True, graph=None)[source][source]¶
Interpreter逐个节点地执行FX图。 这种模式对于许多事情都很有用,包括编写代码转换以及分析pass。
可以覆盖 Interpreter 类中的方法以自定义执行行为。 调用层次结构中的可覆盖方法映射
run() +-- run_node +-- placeholder() +-- get_attr() +-- call_function() +-- call_method() +-- call_module() +-- output()
示例
假设我们想将所有
torch.neg实例与torch.sigmoid互换,反之亦然(包括它们的Tensor方法等效项)。 我们可以像这样子类化 Interpreterclass NegSigmSwapInterpreter(Interpreter): def call_function(self, target: Target, args: Tuple, kwargs: Dict) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(target, args, kwargs) def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any: if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(target, args, kwargs) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) input = torch.randn(3, 4) result = NegSigmSwapInterpreter(gm).run(input) torch.testing.assert_close(result, torch.neg(input).sigmoid())
- 参数
module (torch.nn.Module) – 要执行的模块
garbage_collect_values (bool) – 是否在模块执行中最后一次使用值后删除它们。 这确保了执行期间的最佳内存使用。 可以禁用此功能,例如,通过查看
Interpreter.env属性来检查执行中的所有中间值。graph (Optional[Graph]) – 如果传递,解释器将执行此图而不是 module.graph,并使用提供的 module 参数来满足任何状态请求。
注意
保证此 API 的向后兼容性。
- boxed_run(args_list)[source][source]¶
通过解释运行 module 并返回结果。 这使用“boxed”调用约定,您在其中传递一个参数列表,该列表将被解释器清除。 这确保了及时释放输入张量。
注意
保证此 API 的向后兼容性。
- call_function(target, args, kwargs)[source][source]¶
执行一个
call_function节点并返回结果。- 参数
target (Target) – 此节点的调用目标。 有关语义的详细信息,请参见 Node
args (Tuple) – 此调用的位置参数的元组
kwargs (Dict) – 此调用的关键字参数的字典
- 返回类型
- 返回
Any: 函数调用返回的值
注意
保证此 API 的向后兼容性。
- call_method(target, args, kwargs)[source][source]¶
执行一个
call_method节点并返回结果。- 参数
target (Target) – 此节点的调用目标。 有关语义的详细信息,请参见 Node
args (Tuple) – 此调用的位置参数的元组
kwargs (Dict) – 此调用的关键字参数的字典
- 返回类型
- 返回
Any: 方法调用返回的值
注意
保证此 API 的向后兼容性。
- call_module(target, args, kwargs)[source][source]¶
执行一个
call_module节点并返回结果。- 参数
target (Target) – 此节点的调用目标。 有关语义的详细信息,请参见 Node
args (Tuple) – 此调用的位置参数的元组
kwargs (Dict) – 此调用的关键字参数的字典
- 返回类型
- 返回
Any: 模块调用返回的值
注意
保证此 API 的向后兼容性。
- fetch_args_kwargs_from_env(n)[source][source]¶
从当前执行环境中获取节点
n的args和kwargs的具体值。- 参数
n (Node) – 需要获取
args和kwargs的节点。- 返回值
n的具有具体值的args和kwargs。- 返回类型
Tuple[Tuple, Dict]
注意
保证此 API 的向后兼容性。
- fetch_attr(target)[source][source]¶
从
self.module的Module层次结构中获取一个属性。- 参数
target (str) – 要获取的属性的完全限定名称
- 返回值
属性的值。
- 返回类型
Any
注意
保证此 API 的向后兼容性。
- get_attr(target, args, kwargs)[source][source]¶
执行一个
get_attr节点。 将从self.module的Module层次结构中检索属性值。- 参数
target (Target) – 此节点的调用目标。 有关语义的详细信息,请参见 Node
args (Tuple) – 此调用的位置参数的元组
kwargs (Dict) – 此调用的关键字参数的字典
- 返回值
检索到的属性的值
- 返回类型
Any
注意
保证此 API 的向后兼容性。
- map_nodes_to_values(args, n)[source][source]¶
递归地遍历
args,并在当前执行环境中查找每个Node的具体值。- 参数
args (Argument) – 在其中查找具体值的数据结构
n (Node) –
args所属的节点。 这仅用于错误报告。
- 返回类型
Optional[Union[tuple[‘Argument’, …], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]
注意
保证此 API 的向后兼容性。
- output(target, args, kwargs)[source][source]¶
执行一个
output节点。 实际上,这只是检索output节点引用的值并返回它。- 参数
target (Target) – 此节点的调用目标。 有关语义的详细信息,请参见 Node
args (Tuple) – 此调用的位置参数的元组
kwargs (Dict) – 此调用的关键字参数的字典
- 返回值
输出节点引用的返回值
- 返回类型
Any
注意
保证此 API 的向后兼容性。
- placeholder(target, args, kwargs)[source][source]¶
执行一个
placeholder节点。 请注意,这是有状态的:Interpreter维护传递给run的参数的内部迭代器,并且此方法返回该迭代器的 next()。- 参数
target (Target) – 此节点的调用目标。 有关语义的详细信息,请参见 Node
args (Tuple) – 此调用的位置参数的元组
kwargs (Dict) – 此调用的关键字参数的字典
- 返回值
检索到的参数值。
- 返回类型
Any
注意
保证此 API 的向后兼容性。
- class torch.fx.Transformer(module)[source][source]¶
Transformer是一种特殊类型的解释器,可生成新的Module。 它公开了一个transform()方法,该方法返回转换后的Module。Transformer不需要参数来运行,就像Interpreter一样。Transformer完全以符号方式工作。示例
假设我们要将
torch.neg的所有实例与torch.sigmoid交换,反之亦然(包括它们的Tensor方法等效项)。 我们可以像这样对Transformer进行子类化class NegSigmSwapXformer(Transformer): def call_function( self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] ) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(target, args, kwargs) def call_method( self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] ) -> Any: if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(target, args, kwargs) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform() input = torch.randn(3, 4) torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
- 参数
module (GraphModule) – 要转换的
Module。
注意
保证此 API 的向后兼容性。
- get_attr(target, args, kwargs)[source][source]¶
执行一个
get_attr节点。 在Transformer中,这将重写为将新的get_attr节点插入到输出图中。- 参数
target (Target) – 此节点的调用目标。 有关语义的详细信息,请参见 Node
args (Tuple) – 此调用的位置参数的元组
kwargs (Dict) – 此调用的关键字参数的字典
- 返回类型
注意
保证此 API 的向后兼容性。
- placeholder(target, args, kwargs)[source][source]¶
执行一个
placeholder节点。在Transformer中,它被重写以在输出图中插入一个新的placeholder。- 参数
target (Target) – 此节点的调用目标。 有关语义的详细信息,请参见 Node
args (Tuple) – 此调用的位置参数的元组
kwargs (Dict) – 此调用的关键字参数的字典
- 返回类型
注意
保证此 API 的向后兼容性。
- torch.fx.replace_pattern(gm, pattern, replacement)[source][source]¶
在 GraphModule (
gm) 的 Graph 中匹配所有可能的非重叠运算符及其数据依赖项集合 (pattern),然后用另一个子图 (replacement) 替换每个匹配的子图。- 参数
gm (GraphModule) – 包装要操作的 Graph 的 GraphModule
pattern (Union[Callable, GraphModule]) – 要在
gm中匹配以进行替换的子图replacement (Union[Callable, GraphModule]) – 用于替换
pattern的子图
- 返回值
一个
Match对象列表,表示原始图中与pattern匹配的位置。如果没有匹配项,则列表为空。Match定义为class Match(NamedTuple): # Node from which the match was found anchor: Node # Maps nodes in the pattern subgraph to nodes in the larger graph nodes_map: Dict[Node, Node]
- 返回类型
List[Match]
示例
import torch from torch.fx import symbolic_trace, subgraph_rewriter class M(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x, w1, w2): m1 = torch.cat([w1, w2]).sum() m2 = torch.cat([w1, w2]).sum() return x + torch.max(m1) + torch.max(m2) def pattern(w1, w2): return torch.cat([w1, w2]) def replacement(w1, w2): return torch.stack([w1, w2]) traced_module = symbolic_trace(M()) subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
上面的代码将首先在
traced_module的forward方法中匹配pattern。模式匹配是基于使用-定义关系完成的,而不是基于节点名称。 例如,如果pattern中有p = torch.cat([a, b]),则可以在原始forward函数中匹配m = torch.cat([a, b]),即使变量名称不同(p与m)。pattern中的return语句仅根据其值进行匹配;它可能匹配也可能不匹配到较大图中的return语句。换句话说,该模式不必扩展到较大图的末尾。当模式匹配时,它将从较大的函数中删除,并由
replacement替换。如果较大函数中存在pattern的多个匹配项,则每个非重叠的匹配项将被替换。 在匹配重叠的情况下,将替换重叠匹配项集中找到的第一个匹配项。(“第一个”在此定义为节点使用-定义关系拓扑排序中的第一个。在大多数情况下,第一个节点是直接出现在self之后的参数,而最后一个节点是函数返回的任何内容。)需要注意的一个重要事项是,
patternCallable 的参数必须在 Callable 本身中使用,并且replacementCallable 的参数必须与模式匹配。 第一个规则是为什么在上面的代码块中,forward函数具有参数x, w1, w2,而pattern函数仅具有参数w1, w2。pattern不使用x,因此不应将x指定为参数。 作为第二个规则的示例,请考虑替换def pattern(x, y): return torch.neg(x) + torch.relu(y)
为
def replacement(x, y): return torch.relu(x)
在这种情况下,即使参数
y未在replacement中使用,replacement也需要与pattern相同数量的参数(x和y)。调用
subgraph_rewriter.replace_pattern后,生成的 Python 代码如下所示def forward(self, x, w1, w2): stack_1 = torch.stack([w1, w2]) sum_1 = stack_1.sum() stack_2 = torch.stack([w1, w2]) sum_2 = stack_2.sum() max_1 = torch.max(sum_1) add_1 = x + max_1 max_2 = torch.max(sum_2) add_2 = add_1 + max_2 return add_2
注意
保证此 API 的向后兼容性。