torch.fx¶
概述¶
FX 是一个工具包,供开发者用来转换 nn.Module
实例。FX 包含三个主要组件:符号跟踪器、中间表示和Python 代码生成。这些组件的实际应用示例
import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
module = MyModule()
from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%param : [num_users=1] = get_attr[target=param]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
%linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
return clamp
"""
# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
param = self.param
add = x + param; x = param = None
linear = self.linear(add); add = None
clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
return clamp
"""
符号跟踪器对 Python 代码执行“符号执行”。它将称为代理的虚假值馈送到代码中。记录对这些代理的操作。有关符号跟踪的更多信息,请参阅 symbolic_trace()
和 Tracer
文档。
中间表示是符号跟踪期间记录的操作的容器。它包含一个节点列表,这些节点表示函数输入、调用点(到函数、方法或 torch.nn.Module
实例)以及返回值。有关 IR 的更多信息,请参阅 Graph
的文档。IR 是应用转换的格式。
Python 代码生成是使 FX 成为 Python 到 Python(或模块到模块)转换工具包的原因。对于每个 Graph IR,我们可以创建与 Graph 语义匹配的有效 Python 代码。此功能封装在 GraphModule
中,它是一个 torch.nn.Module
实例,它包含一个 Graph
以及从 Graph 生成的 forward
方法。
总而言之,这个组件管道(符号跟踪 -> 中间表示 -> 转换 -> Python 代码生成)构成了 FX 的 Python 到 Python 转换管道。此外,这些组件可以单独使用。例如,符号跟踪可以单独使用来捕获代码的一种形式以进行分析(而不是转换)目的。代码生成可用于以编程方式生成模型,例如从配置文件生成模型。FX 有很多用途!
可以在 示例 存储库中找到一些示例转换。
编写转换¶
什么是 FX 转换?本质上,它是一个看起来像这样的函数。
import torch
import torch.fx
def transform(m: nn.Module,
tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
# Step 1: Acquire a Graph representing the code in `m`
# NOTE: torch.fx.symbolic_trace is a wrapper around a call to
# fx.Tracer.trace and constructing a GraphModule. We'll
# split that out in our transform to allow the caller to
# customize tracing behavior.
graph : torch.fx.Graph = tracer_class().trace(m)
# Step 2: Modify this Graph or create a new one
graph = ...
# Step 3: Construct a Module to return
return torch.fx.GraphModule(m, graph)
您的转换将接收一个 torch.nn.Module
,从中获取一个 Graph
,进行一些修改,并返回一个新的 torch.nn.Module
。您应该将 FX 转换返回的 torch.nn.Module
视为与常规 torch.nn.Module
相同 - 您可以在另一个 FX 转换中传递它,可以将其传递给 TorchScript,也可以运行它。确保 FX 转换的输入和输出是 torch.nn.Module
将允许可组合性。
注意
也可以修改现有的 GraphModule
而不是创建一个新的,如下所示
import torch
import torch.fx
def transform(m : nn.Module) -> nn.Module:
gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)
# Modify gm.graph
# <...>
# Recompile the forward() method of `gm` from its Graph
gm.recompile()
return gm
请注意,您必须调用 GraphModule.recompile()
以使生成的 forward()
方法与修改后的 Graph
同步。
鉴于您已经传递了一个被跟踪到 Graph
的 torch.nn.Module
,现在您可以采用两种主要方法来构建新的 Graph
。
关于图的快速入门¶
可以在 Graph
文档中找到对图语义的完整处理,但我们将在本文中介绍基础知识。一个 Graph
是一个数据结构,它表示 GraphModule
上的方法。这需要的信息是
方法的输入是什么?
方法内部运行的操作是什么?
方法的输出(即返回值)是什么?
这三个概念都用 Node
实例表示。让我们看一个简短的例子来了解这意味着什么。
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return torch.topk(torch.sum(
self.linear(x + self.linear.weight).relu(), dim=-1), 3)
m = MyModule()
gm = torch.fx.symbolic_trace(m)
gm.graph.print_tabular()
在这里,我们定义了一个名为 MyModule
的模块用于演示目的,实例化它,对其进行符号跟踪,然后调用 Graph.print_tabular()
方法来打印一个表格,显示此 Graph
的节点。
操作码
名称
目标
参数
关键字参数
占位符
x
x
()
{}
获取属性
linear_weight
linear.weight
()
{}
调用函数
add_1
<内置函数 add>
(x, linear_weight)
{}
调用模块
linear_1
linear
(add_1,)
{}
调用方法
relu_1
relu
(linear_1,)
{}
调用函数
sum_1
<内置方法 sum …>
(relu_1,)
{'dim': -1}
调用函数
topk_1
<内置方法 topk …>
(sum_1, 3)
{}
output
output
output
(topk_1,)
{}
我们可以利用这些信息来回答我们之前提出的问题。
方法的输入是什么?在 FX 中,方法输入通过特殊的
placeholder
节点指定。在本例中,我们有一个placeholder
节点,其target
为x
,这意味着我们有一个名为 x 的单一(非自身)参数。方法内部有哪些操作?
get_attr
、call_function
、call_module
和call_method
节点代表方法中的操作。所有这些语义的完整处理可以在Node
文档中找到。方法的返回值是什么?
Graph
中的返回值由一个特殊的output
节点指定。
既然我们现在了解了代码在 FX 中如何表示的基本知识,我们现在可以探索如何编辑 Graph
。
图操作¶
直接图操作¶
构建这个新的 Graph
的一种方法是直接操作旧的图。为了帮助我们做到这一点,我们可以简单地获取从符号跟踪中获得的 Graph
并对其进行修改。例如,假设我们希望用 torch.mul()
调用替换 torch.add()
调用。
import torch
import torch.fx
# Sample module
class M(torch.nn.Module):
def forward(self, x, y):
return torch.add(x, y)
def transform(m: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
graph : fx.Graph = tracer_class().trace(m)
# FX represents its Graph as an ordered list of
# nodes, so we can iterate through them.
for node in graph.nodes:
# Checks if we're calling a function (i.e:
# torch.add)
if node.op == 'call_function':
# The target attribute is the function
# that call_function calls.
if node.target == torch.add:
node.target = torch.mul
graph.lint() # Does some checks to make sure the
# Graph is well-formed.
return fx.GraphModule(m, graph)
我们还可以进行更复杂的 Graph
重写,例如删除或追加节点。为了帮助进行这些转换,FX 提供了一些用于转换图的实用函数,这些函数可以在 Graph
文档中找到。下面是一个使用这些 API 追加 torch.relu()
调用的示例。
# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
# Insert a new `call_function` node calling `torch.relu`
new_node = traced.graph.call_function(
torch.relu, args=(node,))
# We want all places that used the value of `node` to
# now use that value after the `relu` call we've added.
# We use the `replace_all_uses_with` API to do this.
node.replace_all_uses_with(new_node)
对于仅包含替换的简单转换,您还可以使用 子图重写器。
使用 replace_pattern() 进行子图重写¶
FX 还提供了一种基于直接图形操作的自动化级别。 replace_pattern()
API 本质上是一个用于编辑 Graph
的“查找/替换”工具。它允许您指定一个 pattern
和 replacement
函数,它将跟踪这些函数,找到 pattern
图中操作组的实例,并将这些实例替换为 replacement
图的副本。这可以帮助您极大地自动化繁琐的图形操作代码,因为随着转换变得更加复杂,这些代码会变得难以管理。
代理/重新跟踪¶
另一种操作 Graph
的方法是重用符号跟踪中使用的 Proxy
机制。例如,假设我们想要编写一个将 PyTorch 函数分解成更小操作的转换。它将把每个 F.relu(x)
调用转换为 (x > 0) * x
。一种可能性是在 F.relu
之后执行必要的图形重写以插入比较和乘法,然后清理原始的 F.relu
。但是,我们可以使用 Proxy
对象来自动将操作记录到 Graph
中,从而自动化此过程。
要使用此方法,我们将要插入的操作编写为常规 PyTorch 代码,并使用 Proxy
对象作为参数调用该代码。这些 Proxy
对象将捕获对它们执行的操作并将它们追加到 Graph
中。
# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):
return (x > 0) * x
decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition
def decompose(model: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
"""
Decompose `model` into smaller constituent operations.
Currently,this only supports decomposing ReLU into its
mathematical definition: (x > 0) * x
"""
graph : fx.Graph = tracer_class().trace(model)
new_graph = fx.Graph()
env = {}
tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
for node in graph.nodes:
if node.op == 'call_function' and node.target in decomposition_rules:
# By wrapping the arguments with proxies,
# we can dispatch to the appropriate
# decomposition rule and implicitly add it
# to the Graph by symbolically tracing it.
proxy_args = [
fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
output_proxy = decomposition_rules[node.target](*proxy_args)
# Operations on `Proxy` always yield new `Proxy`s, and the
# return value of our decomposition rule is no exception.
# We need to extract the underlying `Node` from the `Proxy`
# to use it in subsequent iterations of this transform.
new_node = output_proxy.node
env[node.name] = new_node
else:
# Default case: we don't have a decomposition rule for this
# node, so just copy the node over into the new graph.
new_node = new_graph.node_copy(node, lambda x: env[x.name])
env[node.name] = new_node
return fx.GraphModule(model, new_graph)
除了避免显式图形操作之外,使用 Proxy
还允许您将重写规则指定为原生 Python 代码。对于需要大量重写规则的转换(例如 vmap 或 grad),这通常可以提高规则的可读性和可维护性。请注意,在调用 Proxy
时,我们还传递了一个指向基础变量 graph 的跟踪器。这样做是为了防止图形中的操作是 n 元的(例如,add 是一个二元运算符),对 Proxy
的调用不会创建图形跟踪器的多个实例,这会导致意外的运行时错误。我们建议使用这种使用 Proxy
的方法,尤其是在无法安全地假设基础运算符为一元运算符时。
解释器模式¶
在 FX 中,一个有用的代码组织模式是循环遍历 Node
在 Graph
中执行它们。这可以用于多种用途,包括对流经图的值进行运行时分析或通过使用 Proxy
进行代码重跟踪来转换代码。例如,假设我们想要运行一个 GraphModule
并记录 torch.Tensor
在运行时看到节点上的形状和数据类型属性。这可能看起来像
import torch
import torch.fx
from torch.fx.node import Node
from typing import Dict
class ShapeProp:
"""
Shape propagation. This class takes a `GraphModule`.
Then, its `propagate` method executes the `GraphModule`
node-by-node with the given arguments. As each operation
executes, the ShapeProp class stores away the shape and
element type for the output values of each operation on
the `shape` and `dtype` attributes of the operation's
`Node`.
"""
def __init__(self, mod):
self.mod = mod
self.graph = mod.graph
self.modules = dict(self.mod.named_modules())
def propagate(self, *args):
args_iter = iter(args)
env : Dict[str, Node] = {}
def load_arg(a):
return torch.fx.graph.map_arg(a, lambda n: env[n.name])
def fetch_attr(target : str):
target_atoms = target.split('.')
attr_itr = self.mod
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}")
attr_itr = getattr(attr_itr, atom)
return attr_itr
for node in self.graph.nodes:
if node.op == 'placeholder':
result = next(args_iter)
elif node.op == 'get_attr':
result = fetch_attr(node.target)
elif node.op == 'call_function':
result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
elif node.op == 'call_method':
self_obj, *args = load_arg(node.args)
kwargs = load_arg(node.kwargs)
result = getattr(self_obj, node.target)(*args, **kwargs)
elif node.op == 'call_module':
result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
# This is the only code specific to shape propagation.
# you can delete this `if` branch and this becomes
# a generic GraphModule interpreter.
if isinstance(result, torch.Tensor):
node.shape = result.shape
node.dtype = result.dtype
env[node.name] = result
return load_arg(self.graph.result)
如您所见,FX 的完整解释器并不复杂,但它非常有用。为了简化使用这种模式,我们提供了 Interpreter
类,它以一种可以覆盖解释器执行的某些方面的方式包含了上述逻辑,方法是通过方法覆盖。
除了执行操作之外,我们还可以通过将 Proxy
值馈送到解释器来生成一个新的 Graph。类似地,我们提供了 Transformer
类来包含这种模式。 Transformer
的行为类似于 Interpreter
,但不是调用 run
方法从模块获取具体输出值,而是调用 Transformer.transform()
方法返回一个新的 GraphModule
,该模块已受到您作为覆盖方法安装的任何转换规则的影响。
调试¶
介绍¶
在编写转换的过程中,我们的代码可能并不完全正确。在这种情况下,我们可能需要进行一些调试。关键是反向工作:首先,检查调用生成的模块的结果以证明或反驳正确性。然后,检查和调试生成的代码。然后,调试导致生成代码的转换过程。
如果您不熟悉调试器,请参阅辅助部分 可用调试器.
转换编写中的常见陷阱¶
非确定性
set
迭代顺序。在 Python 中,set
数据类型是无序的。例如,使用set
来包含像Node
这样的对象集合会导致意外的非确定性。一个例子是迭代一组Node
以将它们插入到Graph
中。因为set
数据类型是无序的,所以输出程序中操作的顺序将是非确定性的,并且可以在程序调用之间发生变化。建议的替代方法是使用dict
数据类型,该类型是 插入有序的,从 Python 3.7(以及从 cPython 3.6)开始。一个dict
可以等效地用作一个集合,通过将要进行重复数据删除的值存储在dict
的键中。
检查模块的正确性¶
由于大多数深度学习模块的输出由浮点数 torch.Tensor
实例组成,因此检查两个 torch.nn.Module
结果之间的等效性并不像简单地进行相等性检查那样简单。为了说明这一点,让我们举个例子
import torch
import torch.fx
import torchvision.models as models
def transform(m : torch.nn.Module) -> torch.nn.Module:
gm = torch.fx.symbolic_trace(m)
# Imagine we're doing some transforms here
# <...>
gm.recompile()
return gm
resnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)
input_image = torch.randn(5, 3, 224, 224)
assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
"""
在这里,我们尝试使用 ==
等号运算符检查两个深度学习模型的值是否相等。然而,这并不明确,因为该运算符返回的是张量而不是布尔值,而且由于浮点数运算的非交换性,浮点数值的比较应该使用误差范围(或 epsilon)来进行(有关更多详细信息,请参见 这里)。我们可以使用 torch.allclose()
来代替,它将提供一个近似比较,并考虑相对和绝对容差阈值
assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))
这是我们工具箱中第一个工具,用于检查转换后的模块与参考实现相比是否按预期工作。
调试生成的代码¶
由于 FX 在 GraphModule
上生成 forward()
函数,因此使用传统的调试技术(如 print
语句或 pdb
)并不像以前那样简单。幸运的是,我们有几种技术可以用来调试生成的代码。
使用 pdb
¶
调用 pdb
以进入正在运行的程序。虽然表示 Graph
的代码不在任何源文件中,但我们仍然可以使用 pdb
在调用前向传递时手动进入它。
import torch
import torch.fx
import torchvision.models as models
def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
graph = tracer_class().trace(inp)
# Transformation logic here
# <...>
# Return new Module
return fx.GraphModule(inp, graph)
my_module = models.resnet18()
my_module_transformed = my_pass(my_module)
input_value = torch.randn(5, 3, 224, 224)
# When this line is executed at runtime, we will be dropped into an
# interactive `pdb` prompt. We can use the `step` or `s` command to
# step into the execution of the next line
import pdb; pdb.set_trace()
my_module_transformed(input_value)
打印生成的代码¶
如果您想多次运行相同的代码,那么使用 pdb
一步步调试到正确的代码可能会很繁琐。在这种情况下,一种方法是简单地将生成的 forward
代码复制粘贴到您的代码中,然后从那里进行检查。
# Assume that `traced` is a GraphModule that has undergone some
# number of transforms
# Copy this code for later
print(traced)
# Print the code generated from symbolic tracing. This outputs:
"""
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
"""
# Subclass the original Module
class SubclassM(M):
def __init__(self):
super().__init__()
# Paste the generated `forward` function (the one we printed and
# copied above) here
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
# Create an instance of the original, untraced Module. Then, create an
# instance of the Module with the copied `forward` function. We can
# now compare the output of both the original and the traced version.
pre_trace = M()
post_trace = SubclassM()
使用 GraphModule
中的 to_folder
函数¶
GraphModule.to_folder()
是 GraphModule
中的一个方法,它允许您将生成的 FX 代码转储到一个文件夹中。虽然将 forward 代码复制到代码中通常就足够了,如 打印生成的代码 所示,但使用 to_folder
检查模块和参数可能更容易。
m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()
运行上面的示例后,我们可以查看 foo/module.py
中的代码,并根据需要进行修改(例如,添加 print
语句或使用 pdb
)来调试生成的代码。
调试转换¶
现在我们已经确定转换生成了错误的代码,是时候调试转换本身了。首先,我们将检查文档中的 符号追踪的局限性 部分。一旦我们验证了追踪按预期工作,目标就变成了弄清楚我们的 GraphModule
转换过程中出了什么问题。在 编写转换 中可能会有一个快速的答案,但如果没有,有几种方法可以检查我们追踪的模块。
# Sample Module
class M(torch.nn.Module):
def forward(self, x, y):
return x + y
# Create an instance of `M`
m = M()
# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a
# GraphModule, so we aren't showing any sample transforms for the
# sake of brevity.
traced = symbolic_trace(m)
# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):
add = x + y; x = y = None
return add
"""
# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%y : [num_users=1] = placeholder[target=y]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
return add
"""
# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode name target args kwargs
------------- ------ ----------------------- ------ --------
placeholder x x () {}
placeholder y y () {}
call_function add <built-in function add> (x, y) {}
output output output (add,) {}
"""
使用上面的实用函数,我们可以比较我们在应用转换之前和之后追踪的模块。有时,简单的视觉比较足以追踪到错误。如果仍然不清楚出了什么问题,像 pdb
这样的调试器可能是下一步的良好选择。
参考上面的例子,考虑以下代码
# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
# Get the Graph from our traced Module
g = tracer_class().trace(module)
"""
Transformations on `g` go here
"""
return fx.GraphModule(module, g)
# Transform the Graph
transformed = transform_graph(traced)
# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)
使用上面的例子,假设对 print(traced)
的调用显示我们的转换中存在错误。我们想使用调试器找出问题所在。我们启动一个 pdb
会话。我们可以通过在 transform_graph(traced)
上断点,然后按 s
来“单步进入”对 transform_graph(traced)
的调用来查看转换期间发生了什么。
我们也可以通过编辑 print_tabular
方法来打印图中节点的不同属性,从而获得好运。(例如,我们可能想查看节点的 input_nodes
和 users
。)
可用的调试器¶
The most common Python debugger is
pdb. You can start
your program in “debug mode” with pdb
by typing
python -m pdb FILENAME.py
into the command line, where FILENAME
is the name of the file you want to debug. After that, you can use the
pdb
debugger commands
to move through your running program stepwise. It’s common to set a
breakpoint (b LINE-NUMBER
) when you start pdb
, then call c
to
run the program until that point. This prevents you from having to step
through each line of execution (using s
or n
) to get to the part
of the code you want to examine. Alternatively, you can write
import pdb; pdb.set_trace()
before the line you want to break at.
If you add pdb.set_trace()
, your program will automatically start
in debug mode when you run it. (In other words, you can just type
python FILENAME.py
into the command line instead of
python -m pdb FILENAME.py
.) Once you’re running your file in
debug mode, you can step through the code and examine your program’s
internal state using certain commands. There are many excellent
tutorials on pdb
online, including RealPython’s
“Python Debugging With Pdb”.
像 PyCharm 或 VSCode 这样的 IDE 通常内置了调试器。在你的 IDE 中,你可以选择 a) 使用 pdb
,方法是在你的 IDE 中打开一个终端窗口(例如,在 VSCode 中选择“视图”→“终端”),或者 b) 使用内置的调试器(通常是 pdb
的图形包装器)。
符号跟踪的局限性¶
FX 使用一个 **符号跟踪** 系统(也称为 符号执行)来捕获程序语义的可转换/可分析形式。该系统是 **跟踪** 的,因为它执行程序(实际上是一个 torch.nn.Module
或函数)来记录操作。它是 **符号** 的,因为在执行过程中流经程序的数据不是真实数据,而是符号(在 FX 中称为 Proxy
)。
虽然符号跟踪适用于大多数神经网络代码,但它也有一些局限性。
动态控制流¶
符号追踪的主要限制是它目前不支持*动态控制流*。也就是说,循环或if
语句,其中条件可能取决于程序的输入值。
例如,让我们检查以下程序
def func_to_trace(x):
if x.sum() > 0:
return torch.relu(x)
else:
return torch.neg(x)
traced = torch.fx.symbolic_trace(func_to_trace)
"""
<...>
File "dyn.py", line 6, in func_to_trace
if x.sum() > 0:
File "pytorch/torch/fx/proxy.py", line 155, in __bool__
return self.tracer.to_bool(self)
File "pytorch/torch/fx/proxy.py", line 85, in to_bool
raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""
if
语句的条件依赖于x.sum()
的值,而x.sum()
的值又依赖于x
的值,这是一个函数输入。由于x
可以改变(例如,如果将新的输入张量传递给被追踪的函数),这就是*动态控制流*。回溯会向上遍历你的代码,以显示这种情况发生的位置。
静态控制流¶
另一方面,所谓的*静态控制流*是支持的。静态控制流是循环或if
语句,其值在调用之间不会改变。通常,在 PyTorch 程序中,这种控制流是针对根据超参数做出模型架构决策的代码而产生的。作为一个具体的例子
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self, do_activation : bool = False):
super().__init__()
self.do_activation = do_activation
self.linear = torch.nn.Linear(512, 512)
def forward(self, x):
x = self.linear(x)
# This if-statement is so-called static control flow.
# Its condition does not depend on any input values
if self.do_activation:
x = torch.relu(x)
return x
without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)
traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):
linear_1 = self.linear(x); x = None
return linear_1
"""
traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):
linear_1 = self.linear(x); x = None
relu_1 = torch.relu(linear_1); linear_1 = None
return relu_1
"""
if 语句if self.do_activation
不依赖于任何函数输入,因此它是静态的。do_activation
可以被认为是一个超参数,并且具有不同参数值的MyModule
的不同实例的跟踪具有不同的代码。这是一个有效的模式,由符号追踪支持。
许多动态控制流的实例在语义上是静态控制流。这些实例可以通过消除对输入值的依赖关系来支持符号追踪,例如,将值移动到Module
属性,或在符号追踪期间将具体值绑定到参数。
def f(x, flag):
if flag: return x
else: return x*2
fx.symbolic_trace(f) # Fails!
fx.symbolic_trace(f, concrete_args={'flag': True})
在真正动态控制流的情况下,包含此代码的程序部分可以被追踪为对方法的调用(参见使用 Tracer 类自定义追踪)或函数(参见wrap()
),而不是通过它们进行追踪。
非-torch
函数¶
FX 使用 __torch_function__
作为其拦截调用的机制(有关此机制的更多信息,请参阅 技术概述)。某些函数,例如内置 Python 函数或 math
模块中的函数,不受 __torch_function__
涵盖,但我们仍然希望在符号跟踪中捕获它们。例如
import torch
import torch.fx
from math import sqrt
def normalize(x):
"""
Normalize `x` by the size of the batch dimension
"""
return x / sqrt(len(x))
# It's valid Python code
normalize(torch.rand(3, 4))
traced = torch.fx.symbolic_trace(normalize)
"""
<...>
File "sqrt.py", line 9, in normalize
return x / sqrt(len(x))
File "pytorch/torch/fx/proxy.py", line 161, in __len__
raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""
错误告诉我们,内置函数 len
不受支持。我们可以使用 wrap()
API 使此类函数在跟踪中作为直接调用被记录下来。
torch.fx.wrap('len')
torch.fx.wrap('sqrt')
traced = torch.fx.symbolic_trace(normalize)
print(traced.code)
"""
import math
def forward(self, x):
len_1 = len(x)
sqrt_1 = math.sqrt(len_1); len_1 = None
truediv = x / sqrt_1; x = sqrt_1 = None
return truediv
"""
使用 Tracer
类自定义跟踪¶
Tracer
类是 symbolic_trace
实现的基础类。可以通过子类化 Tracer 来自定义跟踪的行为,如下所示
class MyCustomTracer(torch.fx.Tracer):
# Inside here you can override various methods
# to customize tracing. See the `Tracer` API
# reference
pass
# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):
def forward(self, x):
return torch.relu(x) + torch.ones(3, 4)
mod = MyModule()
traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a
# GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)
叶模块¶
叶模块是在符号跟踪中显示为调用而不是被跟踪的模块。默认的叶模块集是标准 torch.nn
模块实例集。例如
class MySpecialSubmodule(torch.nn.Module):
def forward(self, x):
return torch.neg(x)
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 4)
self.submod = MySpecialSubmodule()
def forward(self, x):
return self.submod(self.linear(x))
traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced though.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):
linear_1 = self.linear(x); x = None
neg_1 = torch.neg(linear_1); linear_1 = None
return neg_1
"""
可以通过覆盖 Tracer.is_leaf_module()
来自定义叶模块集。
杂项¶
张量构造函数(例如
torch.zeros
、torch.ones
、torch.rand
、torch.randn
、torch.sparse_coo_tensor
)目前不可跟踪。确定性构造函数(
zeros
,ones
)可以使用,它们产生的值将作为常量嵌入到跟踪中。只有当这些构造函数的参数引用动态输入大小时,这才会出现问题。在这种情况下,ones_like
或zeros_like
可能是一个可行的替代方案。非确定性构造函数(
rand
,randn
)将在跟踪中嵌入单个随机值。这可能不是预期的行为。一种解决方法是将torch.randn
包裹在一个torch.fx.wrap
函数中,并调用该函数。
@torch.fx.wrap def torch_randn(x, shape): return torch.randn(shape) def f(x): return x + torch_randn(x, 5) fx.symbolic_trace(f)
此行为可能会在将来的版本中修复。
类型注释
支持 Python 3 风格的类型注释(例如
func(x : torch.Tensor, y : int) -> torch.Tensor
),并且符号跟踪将保留它们。目前不支持 Python 2 风格的注释类型注释
# type: (torch.Tensor, int) -> torch.Tensor
。目前不支持函数内局部名称的注释。
关于
training
标志和子模块的注意事项当使用诸如
torch.nn.functional.dropout
之类的函数时,训练参数通常作为self.training
传入。在 FX 跟踪期间,这很可能被烘焙为一个常量值。
import torch import torch.fx class DropoutRepro(torch.nn.Module): def forward(self, x): return torch.nn.functional.dropout(x, training=self.training) traced = torch.fx.symbolic_trace(DropoutRepro()) print(traced.code) """ def forward(self, x): dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False); x = None return dropout """ traced.eval() x = torch.randn(5, 3) torch.testing.assert_close(traced(x), x) """ AssertionError: Tensor-likes are not close! Mismatched elements: 15 / 15 (100.0%) Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed) Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed) """
但是,当使用标准
nn.Dropout()
子模块时,训练标志被封装,并且由于nn.Module
对象模型的保留,可以更改。
class DropoutRepro2(torch.nn.Module): def __init__(self): super().__init__() self.drop = torch.nn.Dropout() def forward(self, x): return self.drop(x) traced = torch.fx.symbolic_trace(DropoutRepro2()) print(traced.code) """ def forward(self, x): drop = self.drop(x); x = None return drop """ traced.eval() x = torch.randn(5, 3) torch.testing.assert_close(traced(x), x)
由于这种差异,请考虑将与
training
标志动态交互的模块标记为叶模块。
API 参考¶
- torch.fx.symbolic_trace(root, concrete_args=None)[source]¶
符号跟踪 API
给定一个
nn.Module
或函数实例root
,此函数将返回一个GraphModule
,该模块通过跟踪root
中看到的操作来构建。concrete_args
允许您对函数进行部分特化,无论是为了移除控制流还是数据结构。例如
def f(a, b): if b == True: return a else: return a*2
由于存在控制流,FX 通常无法跟踪此代码。但是,我们可以使用 concrete_args 对 b 的值进行特化,以跟踪此代码
f = fx.symbolic_trace(f, concrete_args={'b': False}) assert f(3, False) == 6
请注意,虽然您仍然可以传入不同的 b 值,但它们将被忽略。
我们还可以使用 concrete_args 从函数中消除数据结构处理。这将使用 pytrees 来扁平化您的输入。为了避免过度特化,请为不应特化的值传入 fx.PH。例如
def f(x): out = 0 for v in x.values(): out += v return out f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}}) assert f({'a': 1, 'b': 2, 'c': 4}) == 7
- 参数
root (Union[torch.nn.Module, Callable]) – 要跟踪并转换为图形表示的模块或函数。
concrete_args (Optional[Dict[str, any]]) – 要部分特化的输入
- 返回值
从
root
中记录的操作创建的模块。- 返回类型
注意
此 API 的向后兼容性得到保证。
- torch.fx.wrap(fn_or_name)[source]¶
此函数可以在模块级作用域中调用,以将 fn_or_name 注册为“叶函数”。“叶函数”将在 FX 跟踪中保留为 CallFunction 节点,而不是被跟踪。
# foo/bar/baz.py def my_custom_function(x, y): return x * x + y * y torch.fx.wrap('my_custom_function') def fn_to_be_traced(x, y): # When symbolic tracing, the below call to my_custom_function will be inserted into # the graph rather than tracing it. return my_custom_function(x, y)
此函数也可以等效地用作装饰器。
# foo/bar/baz.py @torch.fx.wrap def my_custom_function(x, y): return x * x + y * y
包装函数可以被认为是“叶函数”,类似于“叶模块”的概念,即它们是作为 FX 跟踪中的调用保留的函数,而不是被跟踪的函数。
- 参数
fn_or_name (Union[str, Callable]) – 当调用时,要插入图中的函数或全局函数的名称。
注意
此 API 的向后兼容性得到保证。
- class torch.fx.GraphModule(*args, **kwargs)[source]¶
GraphModule 是从 fx.Graph 生成的 nn.Module。Graphmodule 具有
graph
属性,以及从该graph
生成的code
和forward
属性。警告
当
graph
被重新分配时,code
和forward
将被自动重新生成。但是,如果您在不重新分配graph
属性本身的情况下编辑graph
的内容,则必须调用recompile()
来更新生成的代码。注意
此 API 的向后兼容性得到保证。
- __init__(root, graph, class_name='GraphModule')[source]¶
构造一个 GraphModule。
- 参数
root (Union[torch.nn.Module, Dict[str, Any]) –
root
可以是 nn.Module 实例或一个将字符串映射到任何属性类型的字典。如果root
是一个模块,则 Graph 中的节点的target
字段中对基于模块的对象(通过限定名称)的任何引用将从root
的模块层次结构中的相应位置复制到 GraphModule 的模块层次结构中。如果root
是一个字典,则在节点的target
中找到的限定名称将直接在字典的键中查找。字典映射到的对象将被复制到 GraphModule 的模块层次结构中的相应位置。graph (Graph) –
graph
包含此 GraphModule 应用于代码生成的节点class_name (str) –
name
表示此 GraphModule 的名称,用于调试目的。如果未设置,所有错误消息都将报告为源自GraphModule
。将此设置为root
的原始名称或在您的转换上下文中具有意义的名称可能会有所帮助。
注意
此 API 的向后兼容性得到保证。
- add_submodule(target, m)[source]¶
将给定的子模块添加到
self
中。如果子模块是
target
的子路径,则会安装空的模块,即使它们不存在。- 参数
- 返回值
- 子模块是否可以插入。对于
此方法要返回 True,
target
所表示的链中的每个对象必须满足以下条件之一:a) 尚未存在,或 b) 引用nn.Module
(而不是参数或其他属性)。
- 返回类型
注意
此 API 的向后兼容性得到保证。
- delete_all_unused_submodules()[source]¶
从
self
中删除所有未使用的子模块。如果满足以下条件之一,则模块被视为“已使用”:1. 它具有已使用的子模块 2. 它的 forward 方法通过
call_module
节点直接调用 3. 它具有一个非模块属性,该属性从get_attr
节点使用。此方法可以用来清理一个
nn.Module
,而无需手动调用每个未使用的子模块的delete_submodule
。注意
此 API 的向后兼容性得到保证。
- delete_submodule(target)[source]¶
从
self
中删除给定的子模块。如果
target
不是有效的目标,则不会删除该模块。- 参数
target (str) – 新子模块的完全限定字符串名称(有关如何指定完全限定字符串的示例,请参见
nn.Module.get_submodule
)。- 返回值
- 目标字符串是否引用了
我们要删除的子模块。返回值为
False
表示target
不是对子模块的有效引用。
- 返回类型
注意
此 API 的向后兼容性得到保证。
- print_readable(print_output=True)[source]¶
返回为当前 GraphModule 及其子 GraphModule 生成的 Python 代码。
警告
此 API 处于实验阶段,不 向后兼容。
- class torch.fx.Graph(owning_module=None, tracer_cls=None, tracer_extras=None)[source]¶
Graph
是 FX 中间表示中使用的主要数据结构。它由一系列Node
组成,每个节点代表调用站点(或其他语法结构)。Node
的列表共同构成一个有效的 Python 函数。例如,以下代码
import torch import torch.fx class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.param = torch.nn.Parameter(torch.rand(3, 4)) self.linear = torch.nn.Linear(4, 5) def forward(self, x): return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3) m = MyModule() gm = torch.fx.symbolic_trace(m)
将生成以下 Graph
print(gm.graph)
graph(x): %linear_weight : [num_users=1] = self.linear.weight %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) return topk_1
有关
Graph
中表示的操作的语义,请参阅Node
。注意
此 API 的向后兼容性得到保证。
- __init__(owning_module=None, tracer_cls=None, tracer_extras=None)[source]¶
构造一个空的图。
注意
此 API 的向后兼容性得到保证。
- call_function(the_function, args=None, kwargs=None, type_expr=None)[source]¶
在
Graph
中插入一个call_function
Node
。一个call_function
节点表示对由the_function
指定的 Python 可调用对象的调用。- 参数
the_function (Callable[..., Any]) – 要调用的函数。可以是任何 PyTorch 运算符、Python 函数或
builtins
或operator
命名空间的成员。args (Optional[Tuple[Argument, ...]]) – 传递给被调用函数的位置参数。
kwargs (可选[Dict[str, Argument]]) – 传递给被调用函数的关键字参数
type_expr (可选[Any]) – 一个可选的类型注解,表示此节点输出的 Python 类型。
- 返回值
新创建并插入的
call_function
节点。- 返回类型
注意
此方法与
Graph.create_node()
具有相同的插入点和类型表达式规则。注意
此 API 的向后兼容性得到保证。
- call_method(method_name, args=None, kwargs=None, type_expr=None)[source]¶
在
Graph
中插入一个call_method
Node
。一个call_method
节点表示对args
中第 0 个元素的给定方法的调用。- 参数
- 返回值
新创建并插入的
call_method
节点。- 返回类型
注意
此方法与
Graph.create_node()
具有相同的插入点和类型表达式规则。注意
此 API 的向后兼容性得到保证。
- call_module(module_name, args=None, kwargs=None, type_expr=None)[源代码]¶
在
Graph
中插入一个call_module
Node
。一个call_module
节点表示对Module
层次结构中Module
的 forward() 函数的调用。- 参数
- 返回值
新创建并插入的
call_module
节点。- 返回类型
注意
此方法与
Graph.create_node()
具有相同的插入点和类型表达式规则。注意
此 API 的向后兼容性得到保证。
- create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)[source]¶
在当前插入点创建
Node
并将其添加到Graph
中。请注意,当前插入点可以通过Graph.inserting_before()
和Graph.inserting_after()
设置。- 参数
op (str) – 此节点的操作码。可以是 ‘call_function’,‘call_method’,‘get_attr’,‘call_module’,‘placeholder’ 或 ‘output’ 之一。这些操作码的语义在
Graph
文档字符串中描述。args (Optional[Tuple[Argument, ...]]) – 是此节点的参数元组。
kwargs (Optional[Dict[str, Argument]]) – 此节点的关键字参数
name (Optional[str]) –
Node
的可选字符串名称。这将影响在生成的 Python 代码中分配给该值的名称。type_expr (可选[Any]) – 一个可选的类型注解,表示此节点输出的 Python 类型。
- 返回值
新创建并插入的节点。
- 返回类型
注意
此 API 的向后兼容性得到保证。
- eliminate_dead_code()[source]¶
根据每个节点的用户数量以及节点是否有副作用,从图中删除所有死代码。调用之前必须对图进行拓扑排序。
- 返回值
该过程是否改变了图。
- 返回类型
示例
在消除死代码之前,下面 a = x + 1 中的 a 没有用户,因此可以从图中删除,而不会产生影响。
def forward(self, x): a = x + 1 return x + self.attr_1
在消除死代码之后,a = x + 1 已被删除,forward 的其余部分保留。
def forward(self, x): return x + self.attr_1
警告
死代码消除有一些启发式方法来避免删除有副作用的节点(参见 Node.is_impure),但总的来说覆盖率很差,因此您应该假设除非您知道您的 FX 图完全由函数操作组成,否则此方法不安全调用。
注意
此 API 的向后兼容性得到保证。
- erase_node(to_erase)[source]¶
从
Graph
中删除一个Node
。如果Graph
中仍然存在该节点的用户,则会抛出异常。- 参数
to_erase (Node) – 要从
Graph
中删除的Node
。
注意
此 API 的向后兼容性得到保证。
- get_attr(qualified_name, type_expr=None)[source]¶
在 Graph 中插入一个
get_attr
节点。一个get_attr
Node
代表从Module
层次结构中获取属性。- 参数
qualified_name (str) – 要检索的属性的完全限定名称。例如,如果跟踪的模块有一个名为
foo
的子模块,该子模块有一个名为bar
的子模块,该子模块有一个名为baz
的属性,则完全限定名称foo.bar.baz
应作为qualified_name
传递。type_expr (可选[Any]) – 一个可选的类型注解,表示此节点输出的 Python 类型。
- 返回值
新创建并插入的
get_attr
节点。- 返回类型
注意
此方法与
Graph.create_node
采用相同的插入点和类型表达式规则。注意
此 API 的向后兼容性得到保证。
- graph_copy(g, val_map, return_output_node=False)[source]¶
将给定图中的所有节点复制到
self
中。- 参数
- 返回值
现在等效于
g
中输出值的self
中的值,如果g
有一个output
节点。否则为None
。- 返回类型
可选[联合[元组[任意, …], 列表[任意], 字典[str, 任意], 切片, 范围, 节点, str, int, float, bool, 复数, 数据类型, 张量, 设备, 内存格式, 布局, 运算符重载]]
注意
此 API 的向后兼容性得到保证。
- inserting_after(n=None)[源代码]¶
- 设置创建节点和配套方法插入图表的点。
在“with”语句中使用时,这将临时设置插入点,然后在“with”语句退出时恢复。
with g.inserting_after(n): ... # inserting after node n ... # insert point restored to what it was previously g.inserting_after(n) # set the insert point permanently
参数
- n (可选[节点]): 要插入的节点之前的节点。如果为 None,则将在
整个图的开头之后插入。
- 返回值
一个资源管理器,它将在
__exit__
上恢复插入点。
注意
此 API 的向后兼容性得到保证。
- inserting_before(n=None)[source]¶
- 设置创建节点和配套方法插入图表的点。
在“with”语句中使用时,这将临时设置插入点,然后在“with”语句退出时恢复。
with g.inserting_before(n): ... # inserting before node n ... # insert point restored to what it was previously g.inserting_before(n) # set the insert point permanently
参数
- n (可选[节点]): 要插入的节点之前的节点。如果为 None,则将在
整个图的开头之后插入。
- 返回值
一个资源管理器,它将在
__exit__
上恢复插入点。
注意
此 API 的向后兼容性得到保证。
- lint()[source]¶
对该图运行各种检查以确保其格式正确。特别是: - 检查节点是否具有正确的所有权(由该图拥有) - 检查节点是否按拓扑顺序出现 - 如果该图具有拥有 GraphModule,则检查目标是否存在于该 GraphModule 中
注意
此 API 的向后兼容性得到保证。
- node_copy(node, arg_transform=<function Graph.<lambda>>)[source]¶
将一个节点从一个图复制到另一个图。
arg_transform
需要将参数从节点的图转换为 self 的图。示例# Copying all the nodes in `g` into `new_graph` g : torch.fx.Graph = ... new_graph = torch.fx.graph() value_remap = {} for node in g.nodes: value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n])
- 参数
- 返回类型
注意
此 API 的向后兼容性得到保证。
- property nodes: _node_list¶
获取构成此图的节点列表。
请注意,此
Node
列表表示是一个双向链表。在迭代期间进行的修改(例如删除节点、添加节点)是安全的。- 返回值
节点的双向链表。请注意,可以对该列表调用
reversed
来切换迭代顺序。
- on_generate_code(make_transformer)[source]¶
在生成 Python 代码时注册一个转换器函数
- 参数
- make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc])
一个返回要注册的代码转换器的函数。此函数由 on_generate_code 调用以获取代码转换器。
此函数还以当前注册的代码转换器(如果未注册任何内容,则为 None)作为其输入,以防不希望覆盖它。这对于将代码转换器链接在一起很有用。
- 返回值
一个上下文管理器,当在 with 语句中使用时,会自动恢复先前注册的代码转换器。
示例
gm: fx.GraphModule = ... # This is a code transformer we want to register. This code # transformer prepends a pdb import and trace statement at the very # beginning of the generated torch.fx code to allow for manual # debugging with the PDB library. def insert_pdb(body): return ["import pdb; pdb.set_trace()\n", *body] # Registers `insert_pdb`, and overwrites the current registered # code transformer (given by `_` to the lambda): gm.graph.on_generate_code( lambda _: insert_pdb ) # Or alternatively, registers a code transformer which first # runs `body` through existing registered transformer, then # through `insert_pdb`: gm.graph.on_generate_code( lambda current_trans: ( lambda body: insert_pdb( current_trans(body) if current_trans else body ) ) ) gm.recompile() gm(*inputs) # drops into pdb
此函数也可以用作上下文管理器,它具有自动恢复先前注册的代码转换器的优点
# ... continue from previous example with gm.graph.on_generate_code(lambda _: insert_pdb): # do more stuff with `gm`... gm.recompile() gm(*inputs) # drops into pdb # now previous code transformer is restored (but `gm`'s code with pdb # remains - that means you can run `gm` with pdb here too, until you # run next `recompile()`).
警告
此 API 处于实验阶段,不 向后兼容。
- output(result, type_expr=None)[source]¶
在
Graph
中插入一个output
Node
。一个output
节点表示 Python 代码中的return
语句。result
是要返回的值。- 参数
结果 (参数) – 要返回的值。
type_expr (可选[Any]) – 一个可选的类型注解,表示此节点输出的 Python 类型。
注意
此方法与
Graph.create_node
采用相同的插入点和类型表达式规则。注意
此 API 的向后兼容性得到保证。
- 占位符(名称, 类型表达式=None, 默认值)[源代码]¶
在图中插入一个
占位符
节点。一个占位符
代表一个函数输入。- 参数
名称 (str) – 输入值的名称。这对应于此
图
所代表的函数的位置参数的名称。类型表达式 (可选[任何]) – 一个可选的类型注释,表示此节点输出将具有的 Python 类型。在某些情况下,这对于正确的代码生成是必需的(例如,当函数随后用于 TorchScript 编译时)。
默认值 (任何) – 此函数参数应采用的默认值。注意:为了允许 None 作为默认值,应将 inspect.Signature.empty 作为此参数传递,以指定该参数 _没有_ 默认值。
- 返回类型
注意
此方法与
Graph.create_node
采用相同的插入点和类型表达式规则。注意
此 API 的向后兼容性得到保证。
- class torch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)[source]¶
Node
是表示Graph
中单个操作的数据结构。在大多数情况下,节点表示对各种实体的调用点,例如运算符、方法和模块(一些例外包括指定函数输入和输出的节点)。每个Node
都具有由其op
属性指定的函数。每个op
值的Node
语义如下placeholder
表示函数输入。name
属性指定此值将采用的名称。target
同样是参数的名称。args
包含以下内容之一:1) 没有任何内容,或 2) 表示函数输入的默认参数的单个参数。kwargs
是无关紧要的。占位符对应于图打印输出中的函数参数(例如x
)。get_attr
从模块层次结构中检索参数。name
同样是获取结果的名称。target
是参数在模块层次结构中的位置的完全限定名称。args
和kwargs
是无关紧要的。call_function
将自由函数应用于某些值。name
同样是分配给值的名称。target
是要应用的函数。args
和kwargs
代表函数的参数,遵循 Python 调用约定。call_module
将模块层次结构中的模块的forward()
方法应用于给定的参数。name
与之前相同。target
是模块层次结构中要调用的模块的完全限定名称。args
和kwargs
代表调用模块的参数,*不包括 self 参数*。call_method
在值上调用方法。name
与之前相同。target
是要应用于self
参数的方法的字符串名称。args
和kwargs
代表调用模块的参数,*包括 self 参数*。output
在其args[0]
属性中包含跟踪函数的输出。 这对应于 Graph 打印中的“return”语句。
注意
此 API 的向后兼容性得到保证。
- property all_input_nodes: List[Node]¶
返回所有作为此节点输入的节点。这等效于遍历
args
和kwargs
,并仅收集是节点的值。- 返回值
出现在此
Node
的args
和kwargs
中的Nodes
列表,按顺序排列。
- append(x)[source]¶
在图中节点列表中,将
x
插入到此节点之后。等效于self.next.prepend(x)
- 参数
x (Node) – 要放在此节点之后的节点。必须是同一图的成员。
注意
此 API 的向后兼容性得到保证。
- property args: Tuple[Optional[Union[Tuple[Any, ...], List[Any], Dict[str, Any], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload]], ...]¶
此
Node
的参数元组。参数的解释取决于节点的 opcode。有关更多信息,请参阅Node
文档字符串。允许对此属性进行赋值。所有使用和用户的使用情况都会在赋值时自动更新。
- format_node(placeholder_names=None, maybe_return_typename=None)[source]¶
返回
self
的描述性字符串表示。此方法可以用作调试工具,无需参数。
此函数还在
Graph
的__str__
方法中内部使用。placeholder_names
和maybe_return_typename
中的字符串共同构成了此 Graph 周围 GraphModule 中自动生成的forward
函数的签名。placeholder_names
和maybe_return_typename
不应以其他方式使用。- 参数
- 返回值
- 如果 1) 我们正在使用
format_node
作为内部辅助函数 在
Graph
的__str__
方法中,以及 2)self
是一个占位符节点,则返回None
。否则,返回当前节点的描述性字符串表示。
- 如果 1) 我们正在使用
- 返回类型
注意
此 API 的向后兼容性得到保证。
- insert_arg(idx, arg)[source]¶
在给定索引处将位置参数插入参数列表。
- 参数
idx (整数) – 要在
self.args
中插入之前元素的索引。arg (参数) – 要插入
args
的新参数值
注意
此 API 的向后兼容性得到保证。
- is_impure()[source]¶
返回此操作是否为不纯的,即其操作是否为占位符或输出,或者是否为不纯的 call_function 或 call_module。
- 返回值
如果操作是不纯的或不是不纯的。
- 返回类型
警告
此 API 处于实验阶段,不 向后兼容。
- property kwargs: Dict[str, Optional[Union[Tuple[Any, ...], List[Any], Dict[str, Any], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload]]]¶
此
Node
的关键字参数字典。参数的解释取决于节点的操作码。有关更多信息,请参阅Node
文档字符串。允许对此属性进行赋值。所有使用和用户的使用情况都会在赋值时自动更新。
- normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)[source]¶
返回规范化的参数到 Python 目标。这意味着 args/kwargs 将与模块/函数的签名匹配,如果 normalize_to_only_use_kwargs 为真,则仅按位置顺序返回 kwargs。还会填充默认值。不支持仅位置参数或可变参数参数。
支持模块调用。
可能需要 arg_types 和 kwarg_types 来消除重载歧义。
- 参数
root (torch.nn.Module) – 要解析模块目标的模块。
arg_types (Optional[Tuple[Any]]) – 参数类型的元组
kwarg_types (Optional[Dict[str, Any]]) – 关键字参数类型的字典
normalize_to_only_use_kwargs (bool) – 是否只使用关键字参数进行规范化。
- 返回值
返回 NamedTuple ArgsKwargsPair,如果未成功则返回 None。
- 返回类型
Optional[ArgsKwargsPair]
警告
此 API 处于实验阶段,不 向后兼容。
- prepend(x)[source]¶
在图中的节点列表中,将 x 插入到此节点之前。示例
Before: p -> self bx -> x -> ax After: p -> x -> self bx -> ax
- 参数
x (Node) – 要放在此节点之前的节点。必须是同一图的成员。
注意
此 API 的向后兼容性得到保证。
- replace_all_uses_with(replace_with, delete_user_cb=<function Node.<lambda>>, *, propagate_meta=False)[source]¶
在图中用节点
replace_with
替换所有对self
的使用。- 参数
- 返回值
此更改所做的节点列表。
- 返回类型
注意
此 API 的向后兼容性得到保证。
- replace_input_with(old_input, new_input)[source]¶
遍历
self
的输入节点,并将所有old_input
实例替换为new_input
。注意
此 API 的向后兼容性得到保证。
- property stack_trace: Optional[str]¶
返回在跟踪期间记录的 Python 堆栈跟踪(如果有)。当使用 fx.Tracer 跟踪时,此属性通常由 Tracer.create_proxy 填充。为了在跟踪期间记录堆栈跟踪以进行调试,请在 Tracer 实例上设置 record_stack_traces = True。当使用 dynamo 跟踪时,此属性将默认由 OutputGraph.create_proxy 填充。
stack_trace 将在字符串的末尾包含最内层的帧。
- class torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())[source]¶
Tracer
是实现torch.fx.symbolic_trace
符号跟踪功能的类。调用symbolic_trace(m)
等同于Tracer().trace(m)
。可以对 Tracer 进行子类化,以覆盖跟踪过程的各种行为。可以在此类方法的文档字符串中找到可以覆盖的不同行为。
注意
此 API 的向后兼容性得到保证。
- call_module(m, forward, args, kwargs)[source]¶
此方法指定
Tracer
在遇到对nn.Module
实例的调用时的行为。默认情况下,该行为是检查调用的模块是否为叶模块,方法是使用
is_leaf_module
。如果是,则在Graph
中发出一个引用m
的call_module
节点。否则,正常调用Module
,跟踪其forward
函数中的操作。此方法可以被重写,例如,创建嵌套的跟踪 GraphModules,或者在跨越
Module
边界时你想要的任何其他行为。- 参数
m (Module) – 要发出调用的模块
forward (Callable) – 要调用的
Module
的 forward() 方法args (Tuple) – 模块调用点的参数
kwargs (Dict) – 模块调用点的关键字参数
- 返回值
来自 Module 调用的返回值。在发出
call_module
节点的情况下,这是一个Proxy
值。否则,它是从Module
调用返回的任何值。- 返回类型
注意
此 API 的向后兼容性得到保证。
- create_arg(a)[source]¶
一种指定在准备要作为
Graph
中节点的参数使用的值时跟踪行为的方法。默认情况下,该行为包括
遍历集合类型(例如元组、列表、字典)并递归地对元素调用
create_args
。给定一个 Proxy 对象,返回对底层 IR
Node
的引用给定一个非 Proxy 张量对象,为各种情况发出 IR
对于一个参数,发出一个
get_attr
节点,该节点引用该参数对于一个非参数张量,将张量存储在一个特殊属性中,该属性引用该属性。
此方法可以被重写以支持更多类型。
- 参数
a (Any) – 要作为
Graph
中的Argument
发出的值。- 返回值
值
a
转换为适当的Argument
- 返回类型
可选[联合[元组[任意, …], 列表[任意], 字典[str, 任意], 切片, 范围, 节点, str, int, float, bool, 复数, 数据类型, 张量, 设备, 内存格式, 布局, 运算符重载]]
注意
此 API 的向后兼容性得到保证。
- create_args_for_root(root_fn, is_module, concrete_args=None)[source]¶
为
root
模块的签名创建placeholder
节点。此方法会检查 root 的签名并相应地发出这些节点,也支持*args
和**kwargs
。警告
此 API 处于实验阶段,不 向后兼容。
- create_node(kind, target, args, kwargs, name=None, type_expr=None)¶
插入给定目标、参数、关键字参数和名称的图形节点。
此方法可以被重写以对节点创建中使用的值进行额外的检查、验证或修改。例如,您可能希望禁止记录就地操作。
注意
此 API 的向后兼容性得到保证。
- 返回类型
- create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)¶
从给定参数创建节点,然后返回用代理对象包装的节点。
如果 kind 等于 'placeholder',那么我们正在创建一个表示函数参数的节点。如果我们需要编码默认参数,我们使用
args
元组。对于placeholder
节点,args
通常为空。注意
此 API 的向后兼容性得到保证。
- getattr(attr, attr_val, parameter_proxy_cache)[source]¶
当我们调用
nn.Module
实例的 getattr 时,此方法指定了Tracer
的行为。默认情况下,行为是返回属性的代理值。它还将代理值存储在
parameter_proxy_cache
中,以便将来调用重用代理而不是创建新的代理。此方法可以被覆盖,例如,在查询参数时不返回代理。
- 参数
- 返回值
getattr 调用返回的值。
警告
此 API 处于实验阶段,不 向后兼容。
- is_leaf_module(m, module_qualified_name)[source]¶
一个方法来指定给定的
nn.Module
是否是“叶子”模块。叶子模块是 IR 中出现的原子单元,由
call_module
调用引用。默认情况下,PyTorch 标准库命名空间 (torch.nn) 中的模块是叶子模块。所有其他模块都将被跟踪,并且它们的组成操作将被记录,除非通过此参数另行指定。- 参数
- 返回类型
注意
此 API 的向后兼容性得到保证。
- iter(obj)¶
- 当代理对象被迭代时调用,例如
在控制流中使用。通常我们不知道该怎么做,因为我们不知道代理的值,但自定义跟踪器可以使用 create_node 将更多信息附加到图节点,并可以选择返回一个迭代器。
注意
此 API 的向后兼容性得到保证。
- 返回类型
- keys(obj)¶
- 当代理对象调用 keys() 方法时调用。
这就是在代理上调用 ** 时发生的事情。这应该返回一个迭代器,它应该在您的自定义跟踪器中工作。
注意
此 API 的向后兼容性得到保证。
- 返回类型
- path_of_module(mod)[source]¶
辅助方法,用于在
root
的模块层次结构中查找mod
的限定名称。例如,如果root
有一个名为foo
的子模块,该子模块有一个名为bar
的子模块,将bar
传递到此函数将返回字符串“foo.bar”。注意
此 API 的向后兼容性得到保证。
- to_bool(obj)¶
- 当代理对象被转换为布尔值时调用,例如
在控制流中使用。通常我们不知道该怎么做,因为我们不知道代理的值,但自定义跟踪器可以使用 create_node 将更多信息附加到图节点,并可以选择返回一个值。
注意
此 API 的向后兼容性得到保证。
- 返回类型
- class torch.fx.Proxy(node, tracer=None)[source]¶
Proxy
对象是Node
包装器,在符号跟踪期间流经程序,并将它们触及的所有操作(torch
函数调用、方法调用、运算符)记录到不断增长的 FX 图中。如果您正在进行图形转换,您可以将自己的
Proxy
方法包装在原始Node
周围,以便您可以使用重载运算符向Graph
添加其他内容。Proxy
对象不可迭代。换句话说,如果在循环中或作为*args
/**kwargs
函数参数使用Proxy
,符号跟踪器将抛出错误。有两种主要方法可以解决这个问题:1. 将不可跟踪的逻辑分解到顶层函数中,并对其使用
fx.wrap
。2. 如果控制流是静态的(即循环次数基于某个超参数),则代码可以保留在原始位置并重构为类似于以下内容:for i in range(self.some_hyperparameter): indexed_item = proxied_value[i]
有关 Proxy 内部结构的更详细说明,请查看 torch/fx/OVERVIEW.md 中的“Proxy”部分。
注意
此 API 的向后兼容性得到保证。
- class torch.fx.Interpreter(module, garbage_collect_values=True, graph=None)[source]¶
解释器逐节点执行 FX 图。这种模式对于许多事情很有用,包括编写代码转换以及分析过程。
解释器类中的方法可以被重写以自定义执行的行为。可重写方法的映射(按调用层次结构)
run() +-- run_node +-- placeholder() +-- get_attr() +-- call_function() +-- call_method() +-- call_module() +-- output()
示例
假设我们要将所有
torch.neg
实例替换为torch.sigmoid
,反之亦然(包括它们的Tensor
方法等效项)。我们可以像这样子类化解释器class NegSigmSwapInterpreter(Interpreter): def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(n) def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any: if target == 'neg': call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(n) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) input = torch.randn(3, 4) result = NegSigmSwapInterpreter(gm).run(input) torch.testing.assert_close(result, torch.neg(input).sigmoid())
- 参数
module (torch.nn.Module) – 要执行的模块
garbage_collect_values (bool) – 是否在模块执行过程中删除值最后一次使用后的值。这确保了执行期间最佳的内存使用。可以禁用此功能,例如,通过查看
Interpreter.env
属性来检查执行中的所有中间值。graph (Optional[Graph]) – 如果传递,解释器将执行此图,而不是 module.graph,使用提供的 module 参数来满足对状态的任何请求。
注意
此 API 的向后兼容性得到保证。
- boxed_run(args_list)[source]¶
通过解释运行 module 并返回结果。这使用“boxed”调用约定,您传递一个参数列表,该列表将由解释器清除。这确保了输入张量会立即释放。
注意
此 API 的向后兼容性得到保证。
- call_function(target, args, kwargs)[source]¶
执行
call_function
节点并返回结果。- 参数
target (Target) – 此节点的调用目标。有关语义的详细信息,请参见 Node
args (Tuple) – 此调用的位置参数元组
kwargs (Dict) – 此调用的关键字参数字典
- 返回类型
- 返回值
Any: 函数调用返回的值
注意
此 API 的向后兼容性得到保证。
- call_method(target, args, kwargs)[source]¶
执行
call_method
节点并返回结果。- 参数
target (Target) – 此节点的调用目标。有关语义的详细信息,请参见 Node
args (Tuple) – 此调用的位置参数元组
kwargs (Dict) – 此调用的关键字参数字典
- 返回类型
- 返回值
Any: 方法调用返回的值
注意
此 API 的向后兼容性得到保证。
- call_module(target, args, kwargs)[source]¶
执行一个
call_module
节点并返回结果。- 参数
target (Target) – 此节点的调用目标。有关语义的详细信息,请参见 Node
args (Tuple) – 此调用的位置参数元组
kwargs (Dict) – 此调用的关键字参数字典
- 返回类型
- 返回值
Any: 模块调用返回的值
注意
此 API 的向后兼容性得到保证。
- fetch_args_kwargs_from_env(n)[source]¶
从当前执行环境中获取节点
n
的args
和kwargs
的具体值。- 参数
n (Node) – 需要获取
args
和kwargs
的节点。- 返回值
args
和kwargs
具有n
的具体值。- 返回类型
Tuple[Tuple, Dict]
注意
此 API 的向后兼容性得到保证。
- fetch_attr(target)[source]¶
从
self.module
的Module
层次结构中获取属性。- 参数
target (str) – 要获取的属性的完全限定名称
- 返回值
属性的值。
- 返回类型
任何
注意
此 API 的向后兼容性得到保证。
- get_attr(target, args, kwargs)[source]¶
执行一个
get_attr
节点。将从self.module
的Module
层次结构中检索属性值。- 参数
target (Target) – 此节点的调用目标。有关语义的详细信息,请参见 Node
args (Tuple) – 此调用的位置参数元组
kwargs (Dict) – 此调用的关键字参数字典
- 返回值
检索到的属性的值
- 返回类型
任何
注意
此 API 的向后兼容性得到保证。
- map_nodes_to_values(args, n)[source]¶
递归地遍历
args
并查找当前执行环境中每个Node
的具体值。- 参数
args (Argument) – 用于查找具体值的数据结构
n (Node) –
args
所属的节点。这仅用于错误报告。
- 返回类型
可选[联合[元组[任意, …], 列表[任意], 字典[str, 任意], 切片, 范围, 节点, str, int, float, bool, 复数, 数据类型, 张量, 设备, 内存格式, 布局, 运算符重载]]
注意
此 API 的向后兼容性得到保证。
- output(target, args, kwargs)[source]¶
执行一个
output
节点。这实际上只是检索output
节点引用的值并返回它。- 参数
target (Target) – 此节点的调用目标。有关语义的详细信息,请参见 Node
args (Tuple) – 此调用的位置参数元组
kwargs (Dict) – 此调用的关键字参数字典
- 返回值
输出节点引用的返回值
- 返回类型
任何
注意
此 API 的向后兼容性得到保证。
- placeholder(target, args, kwargs)[source]¶
执行一个
placeholder
节点。注意,这是有状态的:Interpreter
在传递给run
的参数上维护一个内部迭代器,此方法返回该迭代器的 next()。- 参数
target (Target) – 此节点的调用目标。有关语义的详细信息,请参见 Node
args (Tuple) – 此调用的位置参数元组
kwargs (Dict) – 此调用的关键字参数字典
- 返回值
检索到的参数值。
- 返回类型
任何
注意
此 API 的向后兼容性得到保证。
- class torch.fx.Transformer(module)[source]¶
Transformer
是一种特殊的解释器类型,它会生成一个新的Module
。它公开了一个transform()
方法,该方法返回转换后的Module
。Transformer
不需要像Interpreter
那样运行参数。Transformer
完全以符号方式工作。示例
假设我们要将所有
torch.neg
实例替换为torch.sigmoid
,反之亦然(包括它们的Tensor
方法等效项)。我们可以像这样子类化Transformer
class NegSigmSwapXformer(Transformer): def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(n) def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: if target == 'neg': call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(n) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform() input = torch.randn(3, 4) torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
- 参数
module (GraphModule) – 要转换的
Module
。
注意
此 API 的向后兼容性得到保证。
- get_attr(target, args, kwargs)[source]¶
执行一个
get_attr
节点。在Transformer
中,这被重写为在输出图中插入一个新的get_attr
节点。- 参数
target (Target) – 此节点的调用目标。有关语义的详细信息,请参见 Node
args (Tuple) – 此调用的位置参数元组
kwargs (Dict) – 此调用的关键字参数字典
- 返回类型
注意
此 API 的向后兼容性得到保证。
- placeholder(target, args, kwargs)[source]¶
执行一个
placeholder
节点。在Transformer
中,这被重写为在输出图中插入一个新的placeholder
。- 参数
target (Target) – 此节点的调用目标。有关语义的详细信息,请参见 Node
args (Tuple) – 此调用的位置参数元组
kwargs (Dict) – 此调用的关键字参数字典
- 返回类型
注意
此 API 的向后兼容性得到保证。
- torch.fx.replace_pattern(gm, pattern, replacement)[source]¶
在 GraphModule (
gm
) 的图中匹配所有可能的非重叠运算符集及其数据依赖项 (pattern
),然后用另一个子图 (replacement
) 替换每个匹配的子图。- 参数
gm (GraphModule) – 包含要操作的图的 GraphModule
pattern (Union[Callable, GraphModule]) – 要在
gm
中匹配以进行替换的子图替换 (Union[Callable, GraphModule]) – 用来替换
pattern
的子图
- 返回值
一个
Match
对象列表,表示在原始图中pattern
被匹配到的位置。如果没有任何匹配,则列表为空。Match
定义为class Match(NamedTuple): # Node from which the match was found anchor: Node # Maps nodes in the pattern subgraph to nodes in the larger graph nodes_map: Dict[Node, Node]
- 返回类型
List[Match]
示例
import torch from torch.fx import symbolic_trace, subgraph_rewriter class M(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, w1, w2): m1 = torch.cat([w1, w2]).sum() m2 = torch.cat([w1, w2]).sum() return x + torch.max(m1) + torch.max(m2) def pattern(w1, w2): return torch.cat([w1, w2]).sum() def replacement(w1, w2): return torch.stack([w1, w2]) traced_module = symbolic_trace(M()) subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
上面的代码将首先在
traced_module
的forward
方法中匹配pattern
。模式匹配是基于使用-定义关系,而不是节点名称。例如,如果你在pattern
中有p = torch.cat([a, b])
,你可以在原始的forward
函数中匹配m = torch.cat([a, b])
,尽管变量名不同(p
vsm
)。pattern
中的return
语句仅基于其值进行匹配;它可能匹配也可能不匹配到更大图中的return
语句。换句话说,模式不必扩展到更大图的末尾。当模式匹配时,它将从更大的函数中删除,并用
replacement
替换。如果在更大的函数中pattern
有多个匹配项,则每个不重叠的匹配项都将被替换。在匹配重叠的情况下,将替换重叠匹配集中找到的第一个匹配项。(“第一个”在这里定义为节点使用-定义关系的拓扑排序中的第一个。在大多数情况下,第一个节点是直接出现在self
之后的参数,而最后一个节点是函数返回的任何内容。)需要注意的一点是,
pattern
可调用对象的参数必须在可调用对象本身中使用,而replacement
可调用对象的参数必须与模式匹配。第一个规则是为什么在上面的代码块中,forward
函数具有参数x, w1, w2
,但pattern
函数只有参数w1, w2
。pattern
不使用x
,因此它不应该将x
指定为参数。作为第二个规则的示例,考虑替换def pattern(x, y): return torch.neg(x) + torch.relu(y)
与
def replacement(x, y): return torch.relu(x)
在这种情况下,
replacement
需要与pattern
相同数量的参数(x
和y
),即使参数y
在replacement
中未使用。调用
subgraph_rewriter.replace_pattern
后,生成的 Python 代码如下所示def forward(self, x, w1, w2): stack_1 = torch.stack([w1, w2]) sum_1 = stack_1.sum() stack_2 = torch.stack([w1, w2]) sum_2 = stack_2.sum() max_1 = torch.max(sum_1) add_1 = x + max_1 max_2 = torch.max(sum_2) add_2 = add_1 + max_2 return add_2
注意
此 API 的向后兼容性得到保证。