快捷方式

torch.fx

概述

FX 是一个工具包,开发人员可以使用它来转换 nn.Module 实例。FX 由三个主要组件组成:符号跟踪器中间表示Python 代码生成。这些组件实际应用的演示

import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

module = MyModule()

from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
    %x : [num_users=1] = placeholder[target=x]
    %param : [num_users=1] = get_attr[target=param]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
"""

符号跟踪器执行 Python 代码的“符号执行”。它通过代码传递虚假值,称为代理。记录对这些代理的操作。有关符号跟踪的更多信息,请参阅 symbolic_trace()Tracer 文档。

中间表示是符号跟踪期间记录的操作的容器。它由表示函数输入、调用点(对函数、方法或 torch.nn.Module 实例)和返回值的节点列表组成。有关 IR 的更多信息,请参阅 Graph 的文档。IR 是应用转换的格式。

Python 代码生成使 FX 成为 Python 到 Python(或模块到模块)转换工具包。对于每个 Graph IR,我们可以创建与 Graph 语义匹配的有效 Python 代码。此功能封装在 GraphModule 中,它是一个 torch.nn.Module 实例,它包含一个 Graph 以及从 Graph 生成的 forward 方法。

总而言之,此组件管道(符号跟踪 -> 中间表示 -> 转换 -> Python 代码生成)构成了 FX 的 Python 到 Python 转换管道。此外,这些组件可以单独使用。例如,符号跟踪可以独立使用以捕获代码的一种形式以进行分析(而不是转换)目的。代码生成可用于以编程方式生成模型,例如从配置文件生成。FX 有很多用途!

可以在 示例 存储库中找到几个示例转换。

编写转换

什么是 FX 转换?从本质上讲,它是一个类似于这样的函数。

import torch
import torch.fx

def transform(m: nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    # Step 1: Acquire a Graph representing the code in `m`

    # NOTE: torch.fx.symbolic_trace is a wrapper around a call to
    # fx.Tracer.trace and constructing a GraphModule. We'll
    # split that out in our transform to allow the caller to
    # customize tracing behavior.
    graph : torch.fx.Graph = tracer_class().trace(m)

    # Step 2: Modify this Graph or create a new one
    graph = ...

    # Step 3: Construct a Module to return
    return torch.fx.GraphModule(m, graph)

您的转换将接收一个 torch.nn.Module,从中获取一个 Graph,进行一些修改,并返回一个新的 torch.nn.Module。您应该将 FX 转换返回的 torch.nn.Module 视为与常规 torch.nn.Module 相同——您可以将其传递给另一个 FX 转换,可以将其传递给 TorchScript,或者可以运行它。确保 FX 转换的输入和输出是 torch.nn.Module 将允许组合性。

注意

也可以修改现有的 GraphModule 而不是创建一个新的,如下所示

import torch
import torch.fx

def transform(m : nn.Module) -> nn.Module:
    gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)

    # Modify gm.graph
    # <...>

    # Recompile the forward() method of `gm` from its Graph
    gm.recompile()

    return gm

请注意,您必须调用 GraphModule.recompile() 以使 GraphModule 上生成的 forward() 方法与修改后的 Graph 保持同步。

鉴于您传入了一个已被跟踪到 Graphtorch.nn.Module,现在您可以采用两种主要方法来构建新的 Graph

图的快速入门

关于图语义的完整处理可以在Graph文档中找到,但我们在这里将介绍基础知识。Graph是一种数据结构,用于表示GraphModule上的一个方法。这需要的信息是

  • 方法的输入是什么?

  • 方法内部运行的操作是什么?

  • 方法的输出(即返回值)是什么?

所有这三个概念都用Node实例表示。让我们通过一个简短的例子来看看这意味着什么

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(
            self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

gm.graph.print_tabular()

这里我们定义一个模块MyModule用于演示目的,实例化它,对其进行符号跟踪,然后调用Graph.print_tabular()方法打印一个表格,显示该Graph的节点

操作码

名称

目标

参数

关键字参数

占位符

x

x

()

{}

获取属性

线性权重

linear.weight

()

{}

调用函数

加法1

<内置函数 add>

(x, linear_weight)

{}

调用模块

线性1

线性

(add_1,)

{}

调用方法

ReLU 1

ReLU

(linear_1,)

{}

调用函数

求和1

<内置方法 sum …>

(relu_1,)

{'dim': -1}

调用函数

topk 1

<内置方法 topk …>

(sum_1, 3)

{}

输出

输出

输出

(topk_1,)

{}

我们可以使用这些信息来回答我们上面提出的问题。

  • 方法的输入是什么?在 FX 中,方法输入通过特殊的placeholder节点指定。在本例中,我们有一个placeholder节点,其targetx,这意味着我们有一个名为 x 的单个(非自身)参数。

  • 方法中的操作是什么?get_attrcall_functioncall_modulecall_method节点表示方法中的操作。所有这些语义的完整处理可以在Node文档中找到。

  • 方法的返回值是什么?Graph中的返回值由一个特殊的output节点指定。

鉴于我们现在了解了 FX 中代码表示的基本知识,我们现在可以探索如何编辑Graph

图操作

直接图操作

构建这个新的Graph的一种方法是直接操作旧的图。为了帮助做到这一点,我们可以简单地获取从符号跟踪中获得的Graph并对其进行修改。例如,假设我们希望用torch.mul()调用替换torch.add()调用。

import torch
import torch.fx

# Sample module
class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

def transform(m: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph : fx.Graph = tracer_class().trace(m)
    # FX represents its Graph as an ordered list of
    # nodes, so we can iterate through them.
    for node in graph.nodes:
        # Checks if we're calling a function (i.e:
        # torch.add)
        if node.op == 'call_function':
            # The target attribute is the function
            # that call_function calls.
            if node.target == torch.add:
                node.target = torch.mul

    graph.lint() # Does some checks to make sure the
                 # Graph is well-formed.

    return fx.GraphModule(m, graph)

我们还可以进行更复杂的Graph重写,例如删除或追加节点。为了帮助进行这些转换,FX 具有用于转换图的实用函数,这些函数可以在Graph文档中找到。下面是一个使用这些 API 追加torch.relu()调用的示例。

# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
    # Insert a new `call_function` node calling `torch.relu`
    new_node = traced.graph.call_function(
        torch.relu, args=(node,))

    # We want all places that used the value of `node` to
    # now use that value after the `relu` call we've added.
    # We use the `replace_all_uses_with` API to do this.
    node.replace_all_uses_with(new_node)

对于仅包含替换的简单转换,您还可以使用子图重写器

使用 replace_pattern() 进行子图重写

FX 还提供了比直接图操作更高级别的自动化。该replace_pattern()API 本质上是用于编辑Graph的“查找/替换”工具。它允许您指定一个patternreplacement函数,它将遍历这些函数,查找pattern图中操作组的实例,并用replacement图的副本替换这些实例。这有助于极大地自动化乏味的图操作代码,因为随着转换变得越来越复杂,这些代码可能会变得难以管理。

代理/重新跟踪

操作Graph的另一种方法是重用符号跟踪中使用的Proxy机制。例如,假设我们想要编写一个将 PyTorch 函数分解成更小操作的转换。它将每个F.relu(x)调用转换为(x > 0) * x。一种可能性是在F.relu之后插入比较和乘法以执行必要的图重写,然后清理原始的F.relu。但是,我们可以通过使用Proxy对象自动将操作追加到Graph来自动化此过程。

要使用此方法,我们将要插入的操作编写为常规 PyTorch 代码,并使用Proxy对象作为参数调用该代码。这些Proxy对象将捕获对其执行的操作并将其追加到Graph

# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):
    return (x > 0) * x

decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition

def decompose(model: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
    """
    Decompose `model` into smaller constituent operations.
    Currently,this only supports decomposing ReLU into its
    mathematical definition: (x > 0) * x
    """
    graph : fx.Graph = tracer_class().trace(model)
    new_graph = fx.Graph()
    env = {}
    tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
    for node in graph.nodes:
        if node.op == 'call_function' and node.target in decomposition_rules:
            # By wrapping the arguments with proxies,
            # we can dispatch to the appropriate
            # decomposition rule and implicitly add it
            # to the Graph by symbolically tracing it.
            proxy_args = [
                fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
            output_proxy = decomposition_rules[node.target](*proxy_args)

            # Operations on `Proxy` always yield new `Proxy`s, and the
            # return value of our decomposition rule is no exception.
            # We need to extract the underlying `Node` from the `Proxy`
            # to use it in subsequent iterations of this transform.
            new_node = output_proxy.node
            env[node.name] = new_node
        else:
            # Default case: we don't have a decomposition rule for this
            # node, so just copy the node over into the new graph.
            new_node = new_graph.node_copy(node, lambda x: env[x.name])
            env[node.name] = new_node
    return fx.GraphModule(model, new_graph)

除了避免显式图操作外,使用Proxy还允许您将重写规则指定为本机 Python 代码。对于需要大量重写规则的转换(例如 vmap 或 grad),这通常可以提高规则的可读性和可维护性。请注意,在调用Proxy时,我们还传递了一个指向底层变量graph的跟踪器。这样做是为了防止图中的操作是 n 元的(例如,add 是一个二元运算符)时,对Proxy的调用不会创建图跟踪器的多个实例,这可能导致意外的运行时错误。我们建议使用这种Proxy方法,尤其是在不能安全地假设底层运算符是一元的时。

使用Proxy进行Graph操作的示例可以在此处找到。

解释器模式

FX 中一个有用的代码组织模式是循环遍历Graph中的所有Node并执行它们。这可用于多种用途,包括运行时分析流经图的值或通过使用Proxy重新跟踪来转换代码。例如,假设我们想要运行一个GraphModule并在运行时记录节点上的torch.Tensor形状和数据类型属性。这可能看起来像

import torch
import torch.fx
from torch.fx.node import Node

from typing import Dict

class ShapeProp:
    """
    Shape propagation. This class takes a `GraphModule`.
    Then, its `propagate` method executes the `GraphModule`
    node-by-node with the given arguments. As each operation
    executes, the ShapeProp class stores away the shape and
    element type for the output values of each operation on
    the `shape` and `dtype` attributes of the operation's
    `Node`.
    """
    def __init__(self, mod):
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        args_iter = iter(args)
        env : Dict[str, Node] = {}

        def load_arg(a):
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target : str):
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))

            # This is the only code specific to shape propagation.
            # you can delete this `if` branch and this becomes
            # a generic GraphModule interpreter.
            if isinstance(result, torch.Tensor):
                node.shape = result.shape
                node.dtype = result.dtype

            env[node.name] = result

        return load_arg(self.graph.result)

如您所见,FX 的完整解释器并不复杂,但它非常有用。为了方便使用此模式,我们提供了 Interpreter 类,它以一种可以覆盖解释器执行的某些方面的方式包含了上述逻辑。

除了执行操作之外,我们还可以通过将 Proxy 值馈送到解释器来生成新的 Graph。类似地,我们提供了 Transformer 类来包含此模式。Transformer 的行为类似于 Interpreter,但它不是调用 run 方法从模块获取具体的输出值,而是调用 Transformer.transform() 方法返回一个新的 GraphModule,该模块已受到您作为覆盖方法安装的任何转换规则的影响。

解释器模式示例

调试

简介

在编写转换的过程中,我们的代码通常不会完全正确。在这种情况下,我们可能需要进行一些调试。关键是反向工作:首先,检查调用生成的模块的结果以证明或反驳正确性。然后,检查和调试生成的代码。然后,调试导致生成代码的转换过程。

如果您不熟悉调试器,请参阅辅助部分 可用的调试器

转换编写中的常见陷阱

  • 非确定性 set 迭代顺序。在 Python 中,set 数据类型是无序的。例如,使用 set 来包含像 Node 这样的对象的集合会导致意外的非确定性。一个例子是遍历一组 Node 并将其插入 Graph 中。因为 set 数据类型是无序的,所以输出程序中操作的顺序将是非确定性的,并且可以在程序调用之间发生变化。建议的替代方法是使用 dict 数据类型,该类型从 Python 3.7(以及 cPython 3.6)开始是 插入有序的。通过将要进行重复数据删除的值存储在 dict 的键中,dict 可以等效地用作集合。

检查模块的正确性

由于大多数深度学习模块的输出由浮点 torch.Tensor 实例组成,因此检查两个 torch.nn.Module 结果之间的等价性并不像进行简单的相等性检查那样简单。为了说明这一点,让我们举一个例子

import torch
import torch.fx
import torchvision.models as models

def transform(m : torch.nn.Module) -> torch.nn.Module:
    gm = torch.fx.symbolic_trace(m)

    # Imagine we're doing some transforms here
    # <...>

    gm.recompile()

    return gm

resnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)

input_image = torch.randn(5, 3, 224, 224)

assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
"""

在这里,我们尝试使用 == 相等运算符检查两个深度学习模型的值的相等性。但是,由于运算符返回张量而不是布尔值,以及浮点值的比较应使用误差范围(或 epsilon)来解释浮点运算的非交换性(有关更多详细信息,请参阅 此处),因此这在定义上是不好的。我们可以改用 torch.allclose(),它将提供一个近似的比较,同时考虑相对和绝对容差阈值

assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))

这是我们工具箱中第一个工具,用于检查转换后的模块是否符合我们的预期,并与参考实现进行比较。

调试生成的代码

由于 FX 在 GraphModule 上生成 forward() 函数,因此使用传统的调试技术(如 print 语句或 pdb)并不那么简单。幸运的是,我们有一些技术可用于调试生成的代码。

使用 pdb

调用 pdb 以进入正在运行的程序。虽然表示 Graph 的代码不在任何源文件中,但我们仍然可以使用 pdb 手动进入它,当调用前向传递时。

import torch
import torch.fx
import torchvision.models as models

def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph = tracer_class().trace(inp)
    # Transformation logic here
    # <...>

    # Return new Module
    return fx.GraphModule(inp, graph)

my_module = models.resnet18()
my_module_transformed = my_pass(my_module)

input_value = torch.randn(5, 3, 224, 224)

# When this line is executed at runtime, we will be dropped into an
# interactive `pdb` prompt. We can use the `step` or `s` command to
# step into the execution of the next line
import pdb; pdb.set_trace()

my_module_transformed(input_value)

使用 GraphModule 中的 to_folder 函数

GraphModule.to_folder()GraphModule 中的一种方法,它允许您将生成的 FX 代码转储到一个文件夹中。虽然将前向传递复制到代码中通常足以满足 打印生成的代码 的要求,但使用 to_folder 检查模块和参数可能更容易。

m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()

运行上述示例后,我们可以查看 foo/module.py 中的代码并根据需要修改它(例如,添加 print 语句或使用 pdb)来调试生成的代码。

调试转换

现在我们已经确定转换正在创建不正确的代码,是时候调试转换本身了。首先,我们将检查文档中的 符号跟踪的局限性 部分。一旦我们验证跟踪按预期工作,目标就变成了找出我们的 GraphModule 转换过程中哪里出了问题。在 编写转换 中可能有一个快速答案,但如果没有,则有几种方法可以检查我们跟踪的模块

# Sample Module
class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# Create an instance of `M`
m = M()

# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a
# GraphModule, so we aren't showing any sample transforms for the
# sake of brevity.
traced = symbolic_trace(m)

# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):
    add = x + y;  x = y = None
    return add
"""

# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
    return add
"""

# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode         name    target                   args    kwargs
-------------  ------  -----------------------  ------  --------
placeholder    x       x                        ()      {}
placeholder    y       y                        ()      {}
call_function  add     <built-in function add>  (x, y)  {}
output         output  output                   (add,)  {}
"""

使用上面的实用程序函数,我们可以比较应用转换之前和之后的跟踪模块。有时,简单的视觉比较足以跟踪错误。如果仍然不清楚问题出在哪里,那么像 pdb 这样的调试器可能是下一步的好选择。

根据上面的示例,请考虑以下代码

# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
    # Get the Graph from our traced Module
    g = tracer_class().trace(module)

    """
    Transformations on `g` go here
    """

    return fx.GraphModule(module, g)

# Transform the Graph
transformed = transform_graph(traced)

# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)

使用上面的示例,假设对 print(traced) 的调用向我们显示了转换中的错误。我们想使用调试器找出问题所在。我们启动一个 pdb 会话。我们可以通过在 transform_graph(traced) 上中断,然后按 s 来“单步进入”对 transform_graph(traced) 的调用来查看转换期间发生了什么。

我们也可以通过编辑 print_tabular 方法来打印图中节点的不同属性来获得好运。(例如,我们可能想查看节点的 input_nodesusers。)

可用的调试器

Python 最常用的调试器是 pdb。你可以通过在命令行输入 python -m pdb FILENAME.py 来以“调试模式”启动你的程序,其中 FILENAME 是你想要调试的文件名。之后,你可以使用 pdb调试器命令 来逐步遍历你的正在运行的程序。通常在启动 pdb 时设置断点(b LINE-NUMBER),然后调用 c 以运行程序直到该点。这可以避免你必须逐步遍历每行执行(使用 sn)来到达你想要检查的代码部分。或者,你可以在想要中断的行之前编写 import pdb; pdb.set_trace()。如果你添加了 pdb.set_trace(),那么当你运行程序时,它会自动进入调试模式。(换句话说,你可以只在命令行中输入 python FILENAME.py,而不是 python -m pdb FILENAME.py。)一旦你的文件在调试模式下运行,你就可以逐步遍历代码并使用某些命令检查程序的内部状态。网上有很多关于 pdb 的优秀教程,包括 RealPython 的 “使用 Pdb 进行 Python 调试”

像 PyCharm 或 VSCode 这样的 IDE 通常内置了调试器。在你的 IDE 中,你可以选择 a) 使用 pdb(在你的 IDE 中调出一个终端窗口,例如在 VSCode 中选择“查看”→“终端”),或者 b) 使用内置的调试器(通常是 pdb 的图形化包装器)。

符号追踪的局限性

FX 使用一个 **符号追踪** 系统(也称为 符号执行)来以可转换/可分析的形式捕获程序的语义。该系统是 **追踪** 的,因为它执行程序(实际上是一个 torch.nn.Module 或函数)以记录操作。它是 **符号** 的,因为在执行过程中流经程序的数据不是真实数据,而是符号(在 FX 中称为 Proxy)。

虽然符号追踪适用于大多数神经网络代码,但它也有一些局限性。

动态控制流

符号追踪的主要局限性在于它目前不支持 **动态控制流**。也就是说,循环或 if 语句的条件可能取决于程序的输入值。

例如,让我们检查以下程序

def func_to_trace(x):
    if x.sum() > 0:
        return torch.relu(x)
    else:
        return torch.neg(x)

traced = torch.fx.symbolic_trace(func_to_trace)
"""
  <...>
  File "dyn.py", line 6, in func_to_trace
    if x.sum() > 0:
  File "pytorch/torch/fx/proxy.py", line 155, in __bool__
    return self.tracer.to_bool(self)
  File "pytorch/torch/fx/proxy.py", line 85, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""

if 语句的条件依赖于 x.sum() 的值,而 x.sum() 又依赖于 x 的值,x 是函数的输入。由于 x 可能会改变(例如,如果你向被追踪的函数传递一个新的输入张量),这就是 **动态控制流**。回溯会向上遍历你的代码,以显示这种情况发生的位置。

静态控制流

另一方面,所谓的 **静态控制流** 是受支持的。静态控制流是指循环或 if 语句,其值在不同调用之间不会改变。通常,在 PyTorch 程序中,这种控制流是针对根据超参数做出模型架构决策的代码而产生的。举个具体的例子

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self, do_activation : bool = False):
        super().__init__()
        self.do_activation = do_activation
        self.linear = torch.nn.Linear(512, 512)

    def forward(self, x):
        x = self.linear(x)
        # This if-statement is so-called static control flow.
        # Its condition does not depend on any input values
        if self.do_activation:
            x = torch.relu(x)
        return x

without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)

traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    return linear_1
"""

traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    relu_1 = torch.relu(linear_1);  linear_1 = None
    return relu_1
"""

if self.do_activation 这个 if 语句不依赖于任何函数输入,因此它是静态的。do_activation 可以被认为是一个超参数,并且具有不同 do_activation 值的不同 MyModule 实例的追踪具有不同的代码。这是一个由符号追踪支持的有效模式。

许多动态控制流的实例在语义上是静态控制流。这些实例可以通过消除对输入值的依赖关系来使其支持符号追踪,例如将值移动到 Module 属性或在符号追踪期间将具体值绑定到参数。

def f(x, flag):
    if flag: return x
    else: return x*2

fx.symbolic_trace(f) # Fails!

fx.symbolic_trace(f, concrete_args={'flag': True})

在真正动态控制流的情况下,包含此代码的程序部分可以被追踪为对 Method(请参阅 使用 Tracer 类自定义追踪)或函数(请参阅 wrap())的调用,而不是追踪它们本身。

torch 函数

FX 使用 __torch_function__ 作为拦截调用的机制(有关此内容的更多信息,请参阅 技术概述)。某些函数,例如内置 Python 函数或 math 模块中的函数,不受 __torch_function__ 覆盖,但我们仍然希望在符号追踪中捕获它们。例如

import torch
import torch.fx
from math import sqrt

def normalize(x):
    """
    Normalize `x` by the size of the batch dimension
    """
    return x / sqrt(len(x))

# It's valid Python code
normalize(torch.rand(3, 4))

traced = torch.fx.symbolic_trace(normalize)
"""
  <...>
  File "sqrt.py", line 9, in normalize
    return x / sqrt(len(x))
  File "pytorch/torch/fx/proxy.py", line 161, in __len__
    raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""

错误告诉我们,内置函数 len 不受支持。我们可以使用 wrap() API 使得此类函数在追踪中被记录为直接调用。

torch.fx.wrap('len')
torch.fx.wrap('sqrt')

traced = torch.fx.symbolic_trace(normalize)

print(traced.code)
"""
import math
def forward(self, x):
    len_1 = len(x)
    sqrt_1 = math.sqrt(len_1);  len_1 = None
    truediv = x / sqrt_1;  x = sqrt_1 = None
    return truediv
"""

使用 Tracer 类自定义追踪

Tracer 类是 symbolic_trace 实现的基础类。可以通过继承 Tracer 来自定义追踪的行为,如下所示

class MyCustomTracer(torch.fx.Tracer):
    # Inside here you can override various methods
    # to customize tracing. See the `Tracer` API
    # reference
    pass


# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + torch.ones(3, 4)

mod = MyModule()

traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a
# GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)

叶子模块

叶子模块是在符号追踪中显示为调用而不是被追踪的模块。叶子模块的默认集合是标准 torch.nn 模块实例的集合。例如

class MySpecialSubmodule(torch.nn.Module):
    def forward(self, x):
        return torch.neg(x)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)
        self.submod = MySpecialSubmodule()

    def forward(self, x):
        return self.submod(self.linear(x))

traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced though.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    neg_1 = torch.neg(linear_1);  linear_1 = None
    return neg_1
"""

可以通过覆盖 Tracer.is_leaf_module() 来自定义叶子模块的集合。

其他

  • 张量构造器(例如 torch.zerostorch.onestorch.randtorch.randntorch.sparse_coo_tensor)目前不可追踪。

    • 可以使用确定性构造器(zerosones),它们生成的值将作为常量嵌入到追踪中。只有当这些构造器的参数引用动态输入大小时,这才会成为问题。在这种情况下,ones_likezeros_like 可能是一个可行的替代方案。

    • 非确定性构造器(randrandn)将有一个随机值嵌入到追踪中。这可能不是预期的行为。一种解决方法是将 torch.randn 包装在一个 torch.fx.wrap 函数中,然后调用该函数。

    @torch.fx.wrap
    def torch_randn(x, shape):
        return torch.randn(shape)
    
    def f(x):
        return x + torch_randn(x, 5)
    fx.symbolic_trace(f)
    
    • 此行为可能会在将来的版本中修复。

  • 类型注解

    • Python 3 样式的类型注解(例如 func(x : torch.Tensor, y : int) -> torch.Tensor)受支持,并且会由符号追踪保留。

    • Python 2 样式的注释类型注解 # type: (torch.Tensor, int) -> torch.Tensor 目前不受支持。

    • 函数内局部名称上的注解目前不受支持。

  • 关于 training 标志和子模块的问题

    • 当使用诸如 torch.nn.functional.dropout 之类的函数时,将训练参数作为 self.training 传入是很常见的。在 FX 追踪期间,这很可能被烘焙成一个常量值。

    import torch
    import torch.fx
    
    class DropoutRepro(torch.nn.Module):
      def forward(self, x):
        return torch.nn.functional.dropout(x, training=self.training)
    
    
    traced = torch.fx.symbolic_trace(DropoutRepro())
    print(traced.code)
    """
    def forward(self, x):
      dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False);  x = None
      return dropout
    """
    
    traced.eval()
    
    x = torch.randn(5, 3)
    torch.testing.assert_close(traced(x), x)
    """
    AssertionError: Tensor-likes are not close!
    
    Mismatched elements: 15 / 15 (100.0%)
    Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed)
    Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed)
    """
    
    • 但是,当使用标准的 nn.Dropout() 子模块时,训练标志会被封装,并且由于保留了 nn.Module 对象模型,因此可以更改。

    class DropoutRepro2(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.drop = torch.nn.Dropout()
    
      def forward(self, x):
        return self.drop(x)
    
    traced = torch.fx.symbolic_trace(DropoutRepro2())
    print(traced.code)
    """
    def forward(self, x):
      drop = self.drop(x);  x = None
      return drop
    """
    
    traced.eval()
    
    x = torch.randn(5, 3)
    torch.testing.assert_close(traced(x), x)
    
  • 由于这种差异,请考虑将与 training 标志动态交互的模块标记为叶子模块。

API 参考

torch.fx.symbolic_trace(root, concrete_args=None)[source]

符号追踪 API

给定一个 nn.Module 或函数实例 root,此函数将返回一个通过记录在跟踪 root 期间看到的操作而构建的 GraphModule

concrete_args 允许您部分专门化您的函数,无论是否要移除控制流或数据结构。

例如

def f(a, b):
    if b == True:
        return a
    else:
        return a*2

由于存在控制流,FX 通常无法跟踪此内容。但是,我们可以使用 concrete_argsb 的值进行专门化以跟踪此内容

f = fx.symbolic_trace(f, concrete_args={'b': False})
assert f(3, False)  == 6

请注意,尽管您仍然可以传入 b 的不同值,但这些值将被忽略。

我们还可以使用 concrete_args 从函数中消除数据结构处理。这将使用 pytrees 来展平您的输入。为了避免过度专门化,请为不应专门化的值传入 fx.PH。例如

def f(x):
    out = 0
    for v in x.values():
        out += v
    return out
f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
assert f({'a': 1, 'b': 2, 'c': 4}) == 7
参数
  • root (Union[torch.nn.Module, Callable]) – 要跟踪并转换为图形表示的模块或函数。

  • concrete_args (Optional[Dict[str, any]]) – 要部分专门化的输入

返回值

root 记录的操作创建的模块。

返回类型

GraphModule

注意

此 API 的向后兼容性得到保证。

torch.fx.wrap(fn_or_name)[source]

此函数可以在模块级作用域调用,以将 fn_or_name 注册为“叶子函数”。“叶子函数”将在 FX 跟踪中保留为 CallFunction 节点,而不是被跟踪。

# foo/bar/baz.py
def my_custom_function(x, y):
    return x * x + y * y

torch.fx.wrap('my_custom_function')

def fn_to_be_traced(x, y):
    # When symbolic tracing, the below call to my_custom_function will be inserted into
    # the graph rather than tracing it.
    return my_custom_function(x, y)

此函数也可以等效地用作装饰器。

# foo/bar/baz.py
@torch.fx.wrap
def my_custom_function(x, y):
    return x * x + y * y

包装后的函数可以被认为是“叶子函数”,类似于“叶子模块”的概念,也就是说,它们是在 FX 跟踪中作为调用保留的函数,而不是被跟踪。

参数

fn_or_name (Union[str, Callable]) – 当调用该函数时,要插入图形中的函数或全局函数的名称

注意

此 API 的向后兼容性得到保证。

class torch.fx.GraphModule(*args, **kwargs)[source]

GraphModule 是从 fx.Graph 生成的 nn.Module。Graphmodule 具有 graph 属性,以及从该 graph 生成的 codeforward 属性。

警告

当重新分配 graph 时,codeforward 将自动重新生成。但是,如果您在不重新分配 graph 属性本身的情况下编辑 graph 的内容,则必须调用 recompile() 来更新生成的代码。

注意

此 API 的向后兼容性得到保证。

__init__(root, graph, class_name='GraphModule')[source]

构造一个 GraphModule。

参数
  • root (Union[torch.nn.Module, Dict[str, Any]) – root 可以是 nn.Module 实例或将字符串映射到任何属性类型的字典。如果 root 是一个模块,则图形的节点的 target 字段中对基于模块的对象(通过限定名称)的任何引用都将从 root 的模块层次结构中的相应位置复制到 GraphModule 的模块层次结构中。如果 root 是一个字典,则节点的 target 中找到的限定名称将在字典的键中直接查找。字典映射到的对象将复制到 GraphModule 模块层次结构中的相应位置。

  • graph (Graph) – graph 包含此 GraphModule 应用于代码生成的节点

  • class_name (str) – name 表示此 GraphModule 的名称,用于调试目的。如果未设置,则所有错误消息都将报告为源自 GraphModule。将其设置为 root 的原始名称或在转换上下文中具有意义的名称可能会有所帮助。

注意

此 API 的向后兼容性得到保证。

add_submodule(target, m)[source]

将给定的子模块添加到 self 中。

如果它们是 target 的子路径,则此操作将在尚不存在的位置安装空的模块。

参数
  • target (str) – 新子模块的完全限定字符串名称(请参阅 nn.Module.get_submodule 中的示例,了解如何指定完全限定字符串。)

  • m (Module) – 子模块本身;我们要安装在当前模块中的实际对象

返回值

子模块是否可以插入。为了

此方法返回 True,由 target 表示的链中的每个对象必须是 a) 尚未存在,或 b) 引用 nn.Module(而不是参数或其他属性)

返回类型

bool

注意

此 API 的向后兼容性得到保证。

property code: str

返回从此 GraphModule 底层的 Graph 生成的 Python 代码。

delete_all_unused_submodules()[source]

self 中删除所有未使用的子模块。

如果以下任何一项为真,则模块被认为是“已使用”:1. 它具有已使用的子项 2. 其前向被通过 call_module 节点直接调用 3. 它具有从 get_attr 节点使用的非模块属性

此方法可以用来清理 nn.Module,而无需手动在每个未使用的子模块上调用 delete_submodule

注意

此 API 的向后兼容性得到保证。

delete_submodule(target)[source]

self 中删除给定的子模块。

如果 target 不是有效目标,则不会删除该模块。

参数

target (str) – 新子模块的完全限定字符串名称(请参阅 nn.Module.get_submodule 中的示例,了解如何指定完全限定字符串。)

返回值

目标字符串是否引用了

我们要删除的子模块。返回值 False 表示 target 不是对子模块的有效引用。

返回类型

bool

注意

此 API 的向后兼容性得到保证。

property graph: Graph

返回此 GraphModule 底层的 Graph

print_readable(print_output=True, include_stride=False, include_device=False, colored=False)[source]

返回为当前 GraphModule 及其子 GraphModule 生成的 Python 代码

警告

此 API 为实验性 API,并且**不**保证向后兼容性。

recompile()[source]

重新编译此 GraphModule,方法是从其 graph 属性开始。在编辑包含的 graph 后应调用此方法,否则此 GraphModule 的生成代码将过时。

注意

此 API 的向后兼容性得到保证。

返回类型

Python代码

to_folder(folder, module_name='FxModule')[source]
将模块转储到 folder 中,并使用 module_name 命名,以便可以使用

from <folder> import <module_name> 导入

参数

folder (Union[str, os.PathLike]): 用于写入代码的文件夹

module_name (str): 在写入代码时用于 Module 的顶级名称

写入代码

警告

此 API 为实验性 API,并且**不**保证向后兼容性。

class torch.fx.Graph(owning_module=None, tracer_cls=None, tracer_extras=None)[source]

Graph 是 FX 中间表示中使用的主数据结构。它由一系列 Node 组成,每个节点代表调用点(或其他语法结构)。这些 Node 的列表合在一起构成了一个有效的 Python 函数。

例如,以下代码

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

将生成以下图形

print(gm.graph)
graph(x):
    %linear_weight : [num_users=1] = self.linear.weight
    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
    %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
    %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
    %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
    return topk_1

有关 Graph 中表示的操作的语义,请参阅 Node

注意

此 API 的向后兼容性得到保证。

__init__(owning_module=None, tracer_cls=None, tracer_extras=None)[source]

构造一个空的图形。

注意

此 API 的向后兼容性得到保证。

call_function(the_function, args=None, kwargs=None, type_expr=None)[source]

Graph 中插入一个 call_function Nodecall_function 节点表示对由 the_function 指定的 Python 可调用对象的调用。

参数
  • the_function (Callable[..., Any]) – 要调用的函数。可以是任何 PyTorch 运算符、Python 函数或 builtinsoperator 命名空间的成员。

  • args (Optional[Tuple[Argument, ...]]) – 要传递给被调用函数的位置参数。

  • kwargs (Optional[Dict[str, Argument]]) – 要传递给被调用函数的关键字参数

  • type_expr (Optional[Any]) – 一个可选的类型注释,表示此节点输出将具有的 Python 类型。

返回值

新创建并插入的 call_function 节点。

返回类型

节点

注意

此方法与 Graph.create_node() 采用相同的插入点和类型表达式规则。

注意

此 API 的向后兼容性得到保证。

call_method(method_name, args=None, kwargs=None, type_expr=None)[source]

Graph 中插入一个 call_method Nodecall_method 节点表示对 args 的第 0 个元素上的给定方法的调用。

参数
  • method_name (str) – 要应用于 self 参数的方法名称。例如,如果 args[0] 是表示 TensorNode,则要对该 Tensor 调用 relu(),请将 relu 传递给 method_name

  • args (Optional[Tuple[Argument, ...]]) – 要传递给被调用方法的位置参数。请注意,这应该包含一个 self 参数。

  • kwargs (Optional[Dict[str, Argument]]) – 要传递给被调用方法的关键字参数

  • type_expr (Optional[Any]) – 一个可选的类型注释,表示此节点输出将具有的 Python 类型。

返回值

新创建并插入的 call_method 节点。

返回类型

节点

注意

此方法与 Graph.create_node() 采用相同的插入点和类型表达式规则。

注意

此 API 的向后兼容性得到保证。

call_module(module_name, args=None, kwargs=None, type_expr=None)[source]

Graph 中插入一个 call_module Nodecall_module 节点表示对 Module 层次结构中 Module 的 forward() 函数的调用。

参数
  • module_name (str) – 要调用的 ModuleModule 层次结构中的限定名称。例如,如果被跟踪的 Module 具有名为 foo 的子模块,该子模块又具有名为 bar 的子模块,则应将限定名称 foo.bar 作为 module_name 传递以调用该模块。

  • args (Optional[Tuple[Argument, ...]]) – 要传递给被调用方法的位置参数。请注意,这不应包含 self 参数。

  • kwargs (Optional[Dict[str, Argument]]) – 要传递给被调用方法的关键字参数

  • type_expr (Optional[Any]) – 一个可选的类型注释,表示此节点输出将具有的 Python 类型。

返回值

新创建并插入的 call_module 节点。

返回类型

节点

注意

此方法与 Graph.create_node() 采用相同的插入点和类型表达式规则。

注意

此 API 的向后兼容性得到保证。

create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)[source]

创建一个 Node 并将其添加到当前插入点的 Graph 中。请注意,可以通过 Graph.inserting_before()Graph.inserting_after() 设置当前插入点。

参数
  • op (str) – 该节点的操作码。可以是 ‘call_function’,‘call_method’,‘get_attr’,‘call_module’,‘placeholder’ 或 ‘output’ 之一。这些操作码的语义在 Graph 的文档字符串中进行了描述。

  • args (Optional[Tuple[Argument, ...]]) – 该节点的参数元组。

  • kwargs (Optional[Dict[str, Argument]]) – 该节点的关键字参数。

  • name (Optional[str]) – Node 的可选字符串名称。这将影响在生成的 Python 代码中分配给该值的名称。

  • type_expr (Optional[Any]) – 一个可选的类型注释,表示此节点输出将具有的 Python 类型。

返回值

新创建并插入的节点。

返回类型

节点

注意

此 API 的向后兼容性得到保证。

eliminate_dead_code(is_impure_node=None)[source]

根据每个节点的用户数量以及节点是否具有任何副作用,从图中删除所有死代码。在调用之前,图必须是拓扑排序的。

参数
  • is_impure_node (Optional[Callable[[Node], bool]]) – 返回以下内容的函数:

  • None (节点是否是不纯的。如果为) –

  • to (则默认行为为) –

  • Node.is_impure. (使用) –

返回值

图是否因传递而发生更改。

返回类型

bool

示例

在消除死代码之前,下面 a = x + 1 中的 a 没有用户,因此可以从图中删除而不会产生影响。

def forward(self, x):
    a = x + 1
    return x + self.attr_1

在消除死代码之后,a = x + 1 已被删除,其余的 forward 保留。

def forward(self, x):
    return x + self.attr_1

警告

死代码消除有一些启发式方法来避免删除有副作用的节点(请参阅 Node.is_impure),但总的来说覆盖率非常差,因此您应该假设除非您知道您的 FX 图完全由函数操作组成或您提供了自己的自定义函数来检测有副作用的节点,否则调用此方法是不安全的。

注意

此 API 的向后兼容性得到保证。

erase_node(to_erase)[source]

Graph 中删除一个 Node。如果 Graph 中仍然存在该节点的用户,则会抛出异常。

参数

to_erase (Node) – 要从 Graph 中删除的 Node

注意

此 API 的向后兼容性得到保证。

find_nodes(*, op, target=None, sort=True)[source]

允许快速查询节点。

参数
  • op (str) – 操作的名称。

  • target (Optional[Target]) – 节点的目标。对于 call_function,目标是必需的。对于其他操作,目标是可选的。

  • sort (bool) – 是否按节点在图中出现的顺序返回节点。

返回值

具有请求的操作和目标的节点的可迭代对象。

警告

此 API 为实验性 API,并且**不**保证向后兼容性。

get_attr(qualified_name, type_expr=None)[source]

get_attr 节点插入到 Graph 中。 get_attr Node 表示从 Module 层次结构中获取属性。

参数
  • qualified_name (str) – 要检索的属性的完全限定名称。例如,如果跟踪的 Module 具有名为 foo 的子模块,该子模块具有名为 bar 的子模块,该子模块具有名为 baz 的属性,则应将完全限定名称 foo.bar.baz 作为 qualified_name 传递。

  • type_expr (Optional[Any]) – 一个可选的类型注释,表示此节点输出将具有的 Python 类型。

返回值

新创建并插入的 get_attr 节点。

返回类型

节点

注意

此方法与 Graph.create_node 采用相同的插入点和类型表达式规则。

注意

此 API 的向后兼容性得到保证。

graph_copy(g, val_map, return_output_node=False)[source]

将给定图中的所有节点复制到 self 中。

参数
  • g (Graph) – 要从中复制节点的源图。

  • val_map (Dict[Node, Node]) – 一个字典,其中将填充来自 g 中的节点到 self 中的节点的映射。请注意,可以将 val_map 与其中已经存在的值一起传递,以覆盖某些值的复制。

返回值

如果 g 有一个 output 节点,则 self 中的值现在等效于 g 中的输出值。否则为 None

返回类型

Optional[Union[Tuple[Any, …], List[Any], Dict[str, Any], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]

注意

此 API 的向后兼容性得到保证。

inserting_after(n=None)[source]
设置 create_node 和配套方法将插入图中的位置。

在 ‘with’ 语句中使用时,这将临时设置插入点,然后在 ‘with’ 语句退出时恢复它。

with g.inserting_after(n):
    ... # inserting after node n
... # insert point restored to what it was previously
g.inserting_after(n) #  set the insert point permanently

参数

n (Optional[Node]): 要在其之前插入的节点。如果为 None,则将在

整个图的开头插入。

返回值

将在 __exit__ 上恢复插入点的资源管理器。

注意

此 API 的向后兼容性得到保证。

inserting_before(n=None)[source]
设置 create_node 和配套方法将插入图中的位置。

在 ‘with’ 语句中使用时,这将临时设置插入点,然后在 ‘with’ 语句退出时恢复它。

with g.inserting_before(n):
    ... # inserting before node n
... # insert point restored to what it was previously
g.inserting_before(n) #  set the insert point permanently

参数

n (Optional[Node]): 要在其之前插入的节点。如果为 None,则将在

整个图的开头插入。

返回值

将在 __exit__ 上恢复插入点的资源管理器。

注意

此 API 的向后兼容性得到保证。

lint()[source]

对该图运行各种检查,以确保其格式正确。具体包括: - 检查节点是否具有正确的归属权(由该图拥有) - 检查节点是否按拓扑顺序出现 - 如果该图具有拥有 GraphModule,则检查目标是否存在于该 GraphModule 中

注意

此 API 的向后兼容性得到保证。

node_copy(node, arg_transform=<function Graph.<lambda>>)[source]

将一个图中的节点复制到另一个图中。 arg_transform 需要将节点所在图的参数转换为 self 所在图的参数。示例

# Copying all the nodes in `g` into `new_graph`
g : torch.fx.Graph = ...
new_graph = torch.fx.graph()
value_remap = {}
for node in g.nodes:
    value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n])
参数
  • node (Node) – 要复制到 self 中的节点。

  • arg_transform (Callable[[Node], Argument]) – 一个函数,用于将节点的 argskwargs 中的 Node 参数转换为 self 中的等效参数。在最简单的情况下,这应该从一个表中检索值,该表将原始图中的节点映射到 self

返回类型

节点

注意

此 API 的向后兼容性得到保证。

property nodes: _node_list

获取构成此图的节点列表。

请注意,此 Node 列表表示是一个双向链表。迭代期间的修改(例如删除节点、添加节点)是安全的。

返回值

节点的双向链表。请注意,可以对该列表调用 reversed 以切换迭代顺序。

on_generate_code(make_transformer)[source]

在生成 Python 代码时注册一个转换器函数

参数
make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc])

一个返回要注册的代码转换器的函数。此函数由 on_generate_code 调用以获取代码转换器。

此函数还将其输入作为当前注册的代码转换器(如果未注册任何内容,则为 None),以防不希望覆盖它。这对于将代码转换器链接在一起很有用。

返回值

一个上下文管理器,当在 with 语句中使用时,自动恢复先前注册的代码转换器。

示例

gm: fx.GraphModule = ...

# This is a code transformer we want to register. This code
# transformer prepends a pdb import and trace statement at the very
# beginning of the generated torch.fx code to allow for manual
# debugging with the PDB library.
def insert_pdb(body):
    return ["import pdb; pdb.set_trace()\n", *body]

# Registers `insert_pdb`, and overwrites the current registered
# code transformer (given by `_` to the lambda):
gm.graph.on_generate_code(
    lambda _: insert_pdb
)

# Or alternatively, registers a code transformer which first
# runs `body` through existing registered transformer, then
# through `insert_pdb`:
gm.graph.on_generate_code(
    lambda current_trans: (
        lambda body: insert_pdb(
            current_trans(body) if current_trans
            else body
        )
    )
)

gm.recompile()
gm(*inputs)  # drops into pdb

此函数还可以用作上下文管理器,好处是可以自动恢复先前注册的代码转换器

# ... continue from previous example

with gm.graph.on_generate_code(lambda _: insert_pdb):
    # do more stuff with `gm`...
    gm.recompile()
    gm(*inputs)  # drops into pdb

# now previous code transformer is restored (but `gm`'s code with pdb
# remains - that means you can run `gm` with pdb here too, until you
# run next `recompile()`).

警告

此 API 为实验性 API,并且**不**保证向后兼容性。

output(result, type_expr=None)[source]

将一个 output Node 插入到 Graph 中。一个 output 节点表示 Python 代码中的 return 语句。 result 是应该返回的值。

参数
  • result (Argument) – 要返回的值。

  • type_expr (Optional[Any]) – 一个可选的类型注释,表示此节点输出将具有的 Python 类型。

注意

此方法与 Graph.create_node 采用相同的插入点和类型表达式规则。

注意

此 API 的向后兼容性得到保证。

placeholder(name, type_expr=None, default_value)[source]

将一个 placeholder 节点插入到图中。一个 placeholder 表示函数输入。

参数
  • name (str) – 输入值的名称。这对应于此 Graph 表示的函数的位置参数的名称。

  • type_expr (Optional[Any]) – 一个可选的类型注解,表示此节点输出将具有的 Python 类型。在某些情况下,这对于正确的代码生成是必需的(例如,当该函数随后用于 TorchScript 编译时)。

  • default_value (Any) – 此函数参数应采用的默认值。注意:为了允许 None 作为默认值,应将 inspect.Signature.empty 作为此参数传递,以指定参数_没有_默认值。

返回类型

节点

注意

此方法与 Graph.create_node 采用相同的插入点和类型表达式规则。

注意

此 API 的向后兼容性得到保证。

print_tabular()[source]

以表格格式打印图的中间表示。请注意,此 API 需要安装 tabulate 模块。

注意

此 API 的向后兼容性得到保证。

process_inputs(*args)[source]

处理参数,以便可以将其传递到 FX 图中。

警告

此 API 为实验性 API,并且**不**保证向后兼容性。

process_outputs(out)[source]

警告

此 API 为实验性 API,并且**不**保证向后兼容性。

python_code(root_module, *, verbose=False, include_stride=False, include_device=False, colored=False)[source]

将此 Graph 转换为有效的 Python 代码。

参数

root_module (str) – 要在其上查找限定名称目标的根模块的名称。这通常是“self”。

返回值

src:表示对象的 Python 源代码 globals:src 中全局名称的字典 -> 它们引用的对象。

返回类型

一个 PythonCode 对象,包含两个字段

注意

此 API 的向后兼容性得到保证。

set_codegen(codegen)[source]

警告

此 API 为实验性 API,并且**不**保证向后兼容性。

class torch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)[source]

Node 是表示 Graph 中单个操作的数据结构。在大多数情况下,节点表示对各种实体的调用点,例如运算符、方法和模块(一些例外包括指定函数输入和输出的节点)。每个 Node 都具有由其 op 属性指定的函数。Node 对每个 op 值的语义如下

  • placeholder 表示函数输入。 name 属性指定此值将采用的名称。 target 同样是参数的名称。 args 包含:1)无,或 2)一个表示函数输入的默认参数的单个参数。 kwargs 是无关紧要的。占位符对应于图打印输出中的函数参数(例如 x)。

  • get_attr 从模块层次结构中检索参数。 name 同样是获取结果分配的名称。 target 是参数在模块层次结构中的位置的完全限定名称。 argskwargs 是无关紧要的

  • call_function 将自由函数应用于某些值。 name 同样是要分配的值的名称。 target 是要应用的函数。 argskwargs 表示函数的参数,遵循 Python 调用约定

  • call_module 将模块层次结构中的 forward() 方法应用于给定的参数。 name 与之前相同。 target 是要在模块层次结构中调用的模块的完全限定名称。 argskwargs 表示要用于调用模块的参数,不包括 self 参数

  • call_method 在某个值上调用方法。 name 类似。 target 是要应用于 self 参数的方法的字符串名称。 argskwargs 表示要用于调用模块的参数,包括 self 参数

  • output 在其 args[0] 属性中包含已跟踪函数的输出。这对应于 Graph 打印输出中的“return”语句。

注意

此 API 的向后兼容性得到保证。

property all_input_nodes: List[Node]

返回所有作为此节点输入的节点。这等效于遍历 argskwargs,并仅收集是节点的值。

返回值

Nodeargskwargs 中出现的 Nodes 列表,按顺序排列。

append(x)[source]

在图中节点列表中,将 x 插入到此节点之后。等效于 self.next.prepend(x)

参数

x (Node) – 要放在此节点之后的节点。必须是同一图的成员。

注意

此 API 的向后兼容性得到保证。

property args: Tuple[Optional[Union[Tuple[Any, ...], List[Any], Dict[str, Any], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]], ...]

Node 的参数元组。参数的解释取决于节点的操作码。有关更多信息,请参阅 Node 的文档字符串。

允许对该属性进行赋值。所有使用和用户的核算会在赋值时自动更新。

format_node(placeholder_names=None, maybe_return_typename=None)[source]

返回 self 的描述性字符串表示形式。

此方法可以用作调试工具,无需任何参数。

此函数还在 Graph__str__ 方法中内部使用。 placeholder_namesmaybe_return_typename 中的字符串共同构成了此 Graph 所在 GraphModule 中自动生成的 forward 函数的签名。 placeholder_namesmaybe_return_typename 不应以其他方式使用。

参数
  • placeholder_names (Optional[List[str]]) – 一个列表,用于存储表示在生成的 forward 函数中的占位符的格式化字符串。仅供内部使用。

  • maybe_return_typename (Optional[List[str]]) – 一个单元素列表,用于存储表示生成的 forward 函数输出的格式化字符串。仅供内部使用。

返回值

如果 1) 我们正在 Graph__str__ 方法中使用 format_node 作为内部辅助函数,

并且 2) self 是一个占位符节点,则返回 None。否则,返回当前节点的描述性字符串表示形式。

返回类型

str

注意

此 API 的向后兼容性得到保证。

insert_arg(idx, arg)[source]

在给定索引处插入一个位置参数到参数列表中。

参数
  • idx (int) – 要在其之前插入元素的 self.args 中的索引。

  • arg (Argument) – 要插入到 args 中的新参数值。

注意

此 API 的向后兼容性得到保证。

is_impure()[source]

返回此操作是否是不纯的,即其操作是否为占位符或输出,或者是否是调用不纯的 call_function 或 call_module。

返回值

操作是否是不纯的。

返回类型

bool

警告

此 API 为实验性 API,并且**不**保证向后兼容性。

property kwargs: Dict[str, Optional[Union[Tuple[Any, ...], List[Any], Dict[str, Any], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]

Node 的关键字参数字典。参数的解释取决于节点的操作码。有关更多信息,请参阅 Node 的文档字符串。

允许对该属性进行赋值。所有使用和用户的核算会在赋值时自动更新。

property next: Node

返回节点链接列表中的下一个 Node

返回值

节点链接列表中的下一个 Node

normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)[source]

返回规范化后的 Python 目标参数。这意味着 args/kwargs 将与模块/函数的签名相匹配,如果 normalize_to_only_use_kwargs 为真,则以位置顺序返回仅包含 kwargs 的参数。还会填充默认值。不支持仅限位置的参数或可变参数。

支持模块调用。

可能需要 arg_typeskwarg_types 来消除重载歧义。

参数
  • root (torch.nn.Module) – 用于解析模块目标的模块。

  • arg_types (Optional[Tuple[Any]]) – args 的参数类型元组

  • kwarg_types (Optional[Dict[str, Any]]) – kwargs 的参数类型字典

  • normalize_to_only_use_kwargs (bool) – 是否规范化为仅使用 kwargs。

返回值

返回命名元组 ArgsKwargsPair,如果未成功则返回 None

返回类型

Optional[ArgsKwargsPair]

警告

此 API 为实验性 API,并且**不**保证向后兼容性。

prepend(x)[source]

在图中的节点列表中,将 x 插入到此节点之前。示例

Before: p -> self
        bx -> x -> ax
After:  p -> x -> self
        bx -> ax
参数

x (Node) – 要放在此节点之前的节点。必须是同一图的成员。

注意

此 API 的向后兼容性得到保证。

property prev: Node

返回节点链接列表中的上一个 Node

返回值

节点链接列表中的上一个 Node

replace_all_uses_with(replace_with, delete_user_cb=<function Node.<lambda>>, *, propagate_meta=False)[source]

用节点 replace_with 替换图中所有对 self 的使用。

参数
  • replace_with (Node) – 用于替换所有对 self 使用的节点。

  • delete_user_cb (Callable) – 用于确定是否应删除 self 节点的给定用户的回调函数。

  • propagate_meta (bool) – 是否将原始节点的 .meta 字段上的所有属性复制到替换节点上。出于安全考虑,仅当替换节点尚不存在 .meta 字段时,此操作才有效。

返回值

已对此更改进行操作的节点列表。

返回类型

List[Node]

注意

此 API 的向后兼容性得到保证。

replace_input_with(old_input, new_input)[source]

循环遍历 self 的输入节点,并将所有 old_input 实例替换为 new_input

参数
  • old_input (Node) – 要替换的旧输入节点。

  • new_input (Node) – 用于替换 old_input 的新输入节点。

注意

此 API 的向后兼容性得到保证。

property stack_trace: Optional[str]

返回在跟踪期间记录的 Python 堆栈跟踪(如果有)。当使用 fx.Tracer 进行跟踪时,此属性通常由 Tracer.create_proxy 填充。为了在跟踪期间出于调试目的记录堆栈跟踪,请在 Tracer 实例上设置 record_stack_traces = True。当使用 dynamo 进行跟踪时,此属性将默认由 OutputGraph.create_proxy 填充。

stack_trace 将字符串末尾处具有最内部的帧。

update_arg(idx, arg)[source]

更新现有位置参数以包含新值 arg。调用后,self.args[idx] == arg

参数
  • idx (int) – 要更新的元素在 self.args 中的索引

  • arg (Argument) – 要写入 args 的新参数值

注意

此 API 的向后兼容性得到保证。

update_kwarg(key, arg)[source]

更新现有的关键字参数,使其包含新的值 arg。调用后,self.kwargs[key] == arg 将成立。

参数
  • key (str) – self.kwargs 中要更新的元素的键

  • arg (Argument) – 要写入 kwargs 的新参数值

注意

此 API 的向后兼容性得到保证。

class torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())[source]

Tracer 是实现 torch.fx.symbolic_trace 符号跟踪功能的类。调用 symbolic_trace(m) 等效于 Tracer().trace(m)

Tracer 可以被子类化以覆盖跟踪过程中的各种行为。可以覆盖的不同行为在该类的​​方法的文档字符串中进行了描述。

注意

此 API 的向后兼容性得到保证。

call_module(m, forward, args, kwargs)[source]

此方法指定当 Tracer 遇到对 nn.Module 实例的调用时的行为。

默认情况下,行为是通过 is_leaf_module 检查被调用的模块是否为叶子模块。如果是,则在 Graph 中发出一个引用 mcall_module 节点。否则,正常调用 Module,跟踪其 forward 函数中的操作。

此方法可以被覆盖以(例如)创建嵌套的已跟踪 GraphModule,或者您在跨越 Module 边界进行跟踪时所需的任何其他行为。

参数
  • m (Module) – 发出调用的模块

  • forward (Callable) – 要调用的 Module 的 forward() 方法

  • args (Tuple) – 模块调用处的参数

  • kwargs (Dict) – 模块调用处的关键字参数

返回值

来自 Module 调用的返回值。在发出 call_module 节点的情况下,这是一个 Proxy 值。否则,它是从 Module 调用返回的任何值。

返回类型

任意

注意

此 API 的向后兼容性得到保证。

create_arg(a)[source]

一种方法,用于指定在准备用作 Graph 中节点的参数的值时跟踪的行为。

默认情况下,行为包括

  1. 遍历集合类型(例如元组、列表、字典)并在元素上递归调用 create_args

  2. 给定一个 Proxy 对象,返回对底层 IR Node 的引用

  3. 给定一个非 Proxy 张量对象,为各种情况发出 IR

    • 对于参数,发出一个引用该参数的 get_attr 节点

    • 对于非参数张量,将张量存储在一个引用该属性的特殊属性中。

此方法可以被覆盖以支持更多类型。

参数

a (Any) – 要作为 Graph 中的 Argument 发出的值。

返回值

将值 a 转换为相应的 Argument

返回类型

Argument

注意

此 API 的向后兼容性得到保证。

create_args_for_root(root_fn, is_module, concrete_args=None)[source]

创建对应于 root 模块签名的 placeholder 节点。此方法内省根的签名并相应地发出这些节点,还支持 *args**kwargs

警告

此 API 为实验性 API,并且**不**保证向后兼容性。

create_node(kind, target, args, kwargs, name=None, type_expr=None)

给定目标、参数、关键字参数和名称插入图节点。

此方法可以被覆盖以对节点创建中使用的值进行额外的检查、验证或修改。例如,可能希望不允许记录就地操作。

注意

此 API 的向后兼容性得到保证。

返回类型

节点

create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)

根据给定的参数创建节点,然后返回用 Proxy 对象包装的节点。

如果 kind = 'placeholder',则我们正在创建一个表示函数参数的节点。如果我们需要编码默认参数,则使用 args 元组。对于 placeholder 节点,否则 args 为空。

注意

此 API 的向后兼容性得到保证。

get_fresh_qualname(prefix)[source]

获取前缀的新名称并返回它。此函数确保它不会与图上现有的属性冲突。

注意

此 API 的向后兼容性得到保证。

返回类型

str

getattr(attr, attr_val, parameter_proxy_cache)[source]

此方法指定当我们在对 nn.Module 实例的调用上调用 getattr 时,此 Tracer 的行为。

默认情况下,行为是返回属性的代理值。它还将代理值存储在 parameter_proxy_cache 中,以便未来的调用重用代理而不是创建新的代理。

此方法可以被覆盖以(例如)在查询参数时不返回代理。

参数
  • attr (str) – 要查询的属性的名称

  • attr_val (Any) – 属性的值

  • parameter_proxy_cache (Dict[str, Any]) – 属性名称到代理的缓存

返回值

来自 getattr 调用的返回值。

警告

此 API 为实验性 API,并且**不**保证向后兼容性。

is_leaf_module(m, module_qualified_name)[source]

一种方法,用于指定给定的 nn.Module 是否为“叶子”模块。

叶子模块是在 IR 中出现的原子单元,由 call_module 调用引用。默认情况下,PyTorch 标准库命名空间 (torch.nn) 中的模块是叶子模块。所有其他模块都将被跟踪,并记录其组成操作,除非通过此参数另行指定。

参数
  • m (Module) – 要查询的模块

  • module_qualified_name (str) – 此模块的根路径。例如,如果您有一个模块层次结构,其中子模块 foo 包含子模块 bar,而 bar 又包含子模块 baz,则该模块将在此处显示为限定名称 foo.bar.baz

返回类型

bool

注意

此 API 的向后兼容性得到保证。

iter(obj)
当代理对象被迭代时调用,例如

在控制流中使用时。通常我们不知道该怎么做,因为我们不知道代理的值,但自定义跟踪器可以使用 create_node 将更多信息附加到图节点,并可以选择返回迭代器。

注意

此 API 的向后兼容性得到保证。

返回类型

迭代器

keys(obj)
当代理对象调用 keys() 方法时调用。

当在代理上调用 ** 时会发生这种情况。这应该返回一个迭代器,如果 ** 应该在您的自定义跟踪器中工作。

注意

此 API 的向后兼容性得到保证。

返回类型

任意

path_of_module(mod)[source]

辅助方法,用于在 root 的模块层次结构中查找 mod 的限定名称。例如,如果 root 有一个名为 foo 的子模块,该子模块又有一个名为 bar 的子模块,将 bar 传递给此函数将返回字符串“foo.bar”。

参数

mod (str) – 要检索限定名称的 Module

返回类型

str

注意

此 API 的向后兼容性得到保证。

proxy(node)

注意

此 API 的向后兼容性得到保证。

返回类型

代理

to_bool(obj)
当代理对象被转换为布尔值时调用,例如

在控制流中使用时。通常我们不知道该怎么做,因为我们不知道代理的值,但自定义跟踪器可以使用 create_node 将更多信息附加到图节点,并可以选择返回一个值。

注意

此 API 的向后兼容性得到保证。

返回类型

bool

trace(root, concrete_args=None)[source]

跟踪 root 并返回相应的 FX Graph 表示。root 可以是 nn.Module 实例或 Python 可调用对象。

请注意,在此调用之后,self.root 可能与此处传入的 root 不同。例如,当将自由函数传递给 trace() 时,我们将创建一个 nn.Module 实例用作根并向其中添加嵌入式常量。

参数
  • root (Union[Module, Callable]) – 要跟踪的 Module 或函数。此参数的向后兼容性得到保证。

  • concrete_args (Optional[Dict[str, any]]) – 不应被视为代理的具体参数。此参数为实验性参数,其向后兼容性保证。

返回值

表示传入 root 语义的 Graph

返回类型

注意

此 API 的向后兼容性得到保证。

class torch.fx.Proxy(node, tracer=None)[source]

Proxy 对象是 Node 包装器,在符号跟踪期间流经程序,并将所有它们触及的操作(torch 函数调用、方法调用、运算符)记录到不断增长的 FX 图中。

如果您正在执行图转换,则可以在原始 Node 周围包装您自己的 Proxy 方法,以便您可以使用重载运算符向 Graph 添加其他内容。

Proxy 对象不能被迭代。换句话说,如果在循环中或作为 *args/**kwargs 函数参数使用 Proxy,符号跟踪器将抛出错误。

有两种主要方法可以解决这个问题:1. 将不可跟踪的逻辑分解到顶级函数中,并对其使用 fx.wrap。2. 如果控制流是静态的(即循环次数基于某些超参数),则可以保留代码的原始位置,并将其重构为类似以下内容

for i in range(self.some_hyperparameter):
    indexed_item = proxied_value[i]

有关代理内部的更详细说明,请查看 torch/fx/README.md 中的“代理”部分。

注意

此 API 的向后兼容性得到保证。

class torch.fx.Interpreter(module, garbage_collect_values=True, graph=None)[source]

解释器逐节点执行 FX 图。此模式可用于许多事情,包括编写代码转换以及分析过程。

可以重写解释器类中的方法以自定义执行的行为。根据调用层次结构的可重写方法映射

run()
    +-- run_node
        +-- placeholder()
        +-- get_attr()
        +-- call_function()
        +-- call_method()
        +-- call_module()
        +-- output()

示例

假设我们希望将所有 torch.neg 实例与 torch.sigmoid 实例互换(包括它们的 Tensor 方法等效项)。我们可以像这样子类化解释器

class NegSigmSwapInterpreter(Interpreter):
    def call_function(self, target : Target,
                      args : Tuple, kwargs : Dict) -> Any:
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(n)

    def call_method(self, target : Target,
                    args : Tuple, kwargs : Dict) -> Any:
        if target == 'neg':
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(n)

def fn(x):
    return torch.sigmoid(x).neg()

gm = torch.fx.symbolic_trace(fn)
input = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(input)
torch.testing.assert_close(result, torch.neg(input).sigmoid())
参数
  • module (torch.nn.Module) – 要执行的模块

  • garbage_collect_values (bool) – 是否在模块执行中最后一次使用值后删除它们。这确保了执行期间的最佳内存使用。可以禁用此功能,例如,通过查看 Interpreter.env 属性来检查执行中的所有中间值。

  • graph (Optional[Graph]) – 如果传递,解释器将执行此图而不是 module.graph,使用提供的 module 参数来满足对状态的任何请求。

注意

此 API 的向后兼容性得到保证。

boxed_run(args_list)[source]

通过解释运行 module 并返回结果。这使用“boxed”调用约定,您传递一个参数列表,该列表将由解释器清除。这确保了输入张量会立即释放。

注意

此 API 的向后兼容性得到保证。

call_function(target, args, kwargs)[source]

执行 call_function 节点并返回结果。

参数
  • target (Target) – 此节点的调用目标。有关语义详细信息,请参见 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回类型

任意

返回值

Any: 函数调用返回的值

注意

此 API 的向后兼容性得到保证。

call_method(target, args, kwargs)[source]

执行 call_method 节点并返回结果。

参数
  • target (Target) – 此节点的调用目标。有关语义详细信息,请参见 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回类型

任意

返回值

Any: 方法调用返回的值

注意

此 API 的向后兼容性得到保证。

call_module(target, args, kwargs)[source]

执行 call_module 节点并返回结果。

参数
  • target (Target) – 此节点的调用目标。有关语义详细信息,请参见 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回类型

任意

返回值

Any: 模块调用返回的值

注意

此 API 的向后兼容性得到保证。

fetch_args_kwargs_from_env(n)[source]

从当前执行环境中获取节点 nargskwargs 的具体值。

参数

n (Node) – 应该为其获取 argskwargs 的节点。

返回值

argskwargs 具有 n 的具体值。

返回类型

Tuple[Tuple, Dict]

注意

此 API 的向后兼容性得到保证。

fetch_attr(target)[source]

self.moduleModule层次结构中获取属性。

参数

target (str) – 要获取的属性的完全限定名称

返回值

属性的值。

返回类型

任意

注意

此 API 的向后兼容性得到保证。

get_attr(target, args, kwargs)[source]

执行get_attr节点。将从self.moduleModule层次结构中检索属性值。

参数
  • target (Target) – 此节点的调用目标。有关语义详细信息,请参见 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回值

检索到的属性的值

返回类型

任意

注意

此 API 的向后兼容性得到保证。

map_nodes_to_values(args, n)[source]

递归遍历args,并在当前执行环境中查找每个Node的具体值。

参数
  • args (Argument) – 在其中查找具体值的数据结构

  • n (Node) – args所属的节点。这仅用于错误报告。

返回类型

Optional[Union[Tuple[Any, …], List[Any], Dict[str, Any], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]

注意

此 API 的向后兼容性得到保证。

output(target, args, kwargs)[source]

执行output节点。这实际上只是检索output节点引用的值并返回它。

参数
  • target (Target) – 此节点的调用目标。有关语义详细信息,请参见 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回值

输出节点引用的返回值

返回类型

任意

注意

此 API 的向后兼容性得到保证。

placeholder(target, args, kwargs)[source]

执行placeholder节点。请注意,这是有状态的:Interpreter维护一个传递给run的参数的内部迭代器,此方法返回该迭代器上的next()。

参数
  • target (Target) – 此节点的调用目标。有关语义详细信息,请参见 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回值

检索到的参数值。

返回类型

任意

注意

此 API 的向后兼容性得到保证。

run(*args, initial_env=None, enable_io_processing=True)[source]

通过解释运行module并返回结果。

参数
  • *args – 要运行的Module的参数,按位置顺序排列

  • initial_env (Optional[Dict[Node, Any]]) – 可选的执行起始环境。这是一个将Node映射到任何值的字典。例如,这可以用来预先填充某些Nodes的结果,以便在解释器中仅进行部分评估。

  • enable_io_processing (bool) – 如果为真,则在使用输入和输出之前,我们首先使用图的process_inputs和process_outputs函数处理它们。

返回值

执行Module返回的值

返回类型

任意

注意

此 API 的向后兼容性得到保证。

run_node(n)[source]

运行特定的节点n并返回结果。根据node.op调用占位符、get_attr、call_function、call_method、call_module或output。

参数

n (Node) – 要执行的节点

返回值

执行n的结果

返回类型

任意

注意

此 API 的向后兼容性得到保证。

class torch.fx.Transformer(module)[source]

Transformer是一种特殊的解释器类型,它生成一个新的Module。它公开了transform()方法,该方法返回转换后的ModuleTransformer不需要运行参数,就像Interpreter一样。Transformer完全以符号方式工作。

示例

假设我们要将所有torch.neg的实例与torch.sigmoid互换(包括它们的Tensor方法等效项)。我们可以像这样继承Transformer

class NegSigmSwapXformer(Transformer):
    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(n)

    def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        if target == 'neg':
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(n)

def fn(x):
    return torch.sigmoid(x).neg()

gm = torch.fx.symbolic_trace(fn)

transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
input = torch.randn(3, 4)
torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
参数

module (GraphModule) – 要转换的Module

注意

此 API 的向后兼容性得到保证。

call_function(target, args, kwargs)[source]

注意

此 API 的向后兼容性得到保证。

返回类型

任意

call_module(target, args, kwargs)[source]

注意

此 API 的向后兼容性得到保证。

返回类型

任意

get_attr(target, args, kwargs)[source]

执行get_attr节点。在Transformer中,覆盖此方法以将新的get_attr节点插入输出图。

参数
  • target (Target) – 此节点的调用目标。有关语义详细信息,请参见 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回类型

代理

注意

此 API 的向后兼容性得到保证。

placeholder(target, args, kwargs)[source]

执行placeholder节点。在Transformer中,覆盖此方法以将新的placeholder插入输出图。

参数
  • target (Target) – 此节点的调用目标。有关语义详细信息,请参见 Node

  • args (Tuple) – 此调用的位置参数元组

  • kwargs (Dict) – 此调用的关键字参数字典

返回类型

代理

注意

此 API 的向后兼容性得到保证。

transform()[source]

转换self.module并返回转换后的GraphModule

注意

此 API 的向后兼容性得到保证。

返回类型

GraphModule

torch.fx.replace_pattern(gm, pattern, replacement)[source]

在GraphModule(gm)的图中匹配所有可能的非重叠运算符集及其数据依赖项(pattern),然后将每个匹配的子图替换为另一个子图(replacement)。

参数
返回值

表示原始图中 pattern 匹配位置的 Match 对象列表。如果未找到匹配项,则列表为空。 Match 定义如下:

class Match(NamedTuple):
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]

返回类型

List[Match]

示例

import torch
from torch.fx import symbolic_trace, subgraph_rewriter

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w1, w2]).sum()
        return x + torch.max(m1) + torch.max(m2)

def pattern(w1, w2):
    return torch.cat([w1, w2]).sum()

def replacement(w1, w2):
    return torch.stack([w1, w2])

traced_module = symbolic_trace(M())

subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

上述代码将首先在 traced_moduleforward 方法中匹配 pattern。模式匹配基于用-定义关系,而不是节点名称。例如,如果您在 pattern 中有 p = torch.cat([a, b]),则可以在原始 forward 函数中匹配 m = torch.cat([a, b]),尽管变量名称不同(pm)。

pattern 中的 return 语句仅基于其值进行匹配;它可能匹配也可能不匹配较大图中的 return 语句。换句话说,模式不必扩展到较大图的末尾。

匹配模式后,它将从较大函数中删除,并替换为 replacement。如果在较大函数中有多个 pattern 匹配项,则将替换每个不重叠的匹配项。在匹配重叠的情况下,将替换找到的第一个重叠匹配项。(“第一个”在此定义为节点用-定义关系的拓扑排序中的第一个。在大多数情况下,第一个节点是紧随 self 之后的参数,而最后一个节点是函数返回的内容。)

需要注意的一点是,pattern Callable 的参数必须在 Callable 本身中使用,并且 replacement Callable 的参数必须与模式匹配。第一个规则是为什么在上面的代码块中,forward 函数具有参数 x, w1, w2,但 pattern 函数仅具有参数 w1, w2pattern 未使用 x,因此不应将其指定为参数。作为第二个规则的示例,请考虑替换

def pattern(x, y):
    return torch.neg(x) + torch.relu(y)

def replacement(x, y):
    return torch.relu(x)

在这种情况下,replacement 需要与 pattern 相同数量的参数(xy),即使参数 yreplacement 中未使用。

调用 subgraph_rewriter.replace_pattern 后,生成的 Python 代码如下所示

def forward(self, x, w1, w2):
    stack_1 = torch.stack([w1, w2])
    sum_1 = stack_1.sum()
    stack_2 = torch.stack([w1, w2])
    sum_2 = stack_2.sum()
    max_1 = torch.max(sum_1)
    add_1 = x + max_1
    max_2 = torch.max(sum_2)
    add_2 = add_1 + max_2
    return add_2

注意

此 API 的向后兼容性得到保证。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取适合初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源