ExportDB

ExportDB is a centralized dataset of supported and unsupported export cases. It is targeted at users who want to understand specifically which kinds of code are supported, the subtle nuances of export, and how to modify their existing code to be compatible with export. Note that this is not an exhaustive list of everything exportdb supports, but it covers the most common and most confusing use cases users will run into.

If you have a feature that you think needs a stronger guarantee from us to be supported in export, please create an issue in the pytorch/pytorch repository with the module: export label.
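
Every entry below follows the same recipe: define a torch.nn.Module, build example inputs, and call torch.export.export on them. As a quick orientation, here is a minimal, self-contained sketch of that workflow (the Toy module and its shapes are invented for illustration):

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return x.relu() + 1

model = Toy()
example_args = (torch.randn(3, 2),)

# Trace the module into an ExportedProgram.
ep = torch.export.export(model, example_args)

# The exported artifact can be inspected and re-run like a regular module.
print(ep)                              # prints the captured ATen graph
print(ep.module()(torch.randn(3, 2)))  # executes the graph on new inputs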

Supported

assume_constant_result

Note

Tags: torch.escape-hatch

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch
import torch._dynamo as torchdynamo


class AssumeConstantResult(torch.nn.Module):
    """
    Applying the `assume_constant_result` decorator to burn non-traceable code in as a constant.
    """

    @torchdynamo.assume_constant_result
    def get_item(self, y):
        return y.int().item()

    def forward(self, x, y):
        return x[: self.get_item(y)]

example_args = (torch.randn(3, 2), torch.tensor(4))
tags = {"torch.escape-hatch"}
model = AssumeConstantResult()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]", y: "i64[]"):
            slice_1: "f32[3, 2]" = torch.ops.aten.slice.Tensor(x, 0, 0, 4);  x = None
            return (slice_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='slice_1'), target=None)])
Range constraints: {}
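
Because get_item is wrapped in assume_constant_result, the value 4 taken from the example input is burned into the slice above. A small sketch (not part of the original entry) of what that implies at runtime:

ep = torch.export.export(model, example_args)
# The slice bound stays fixed at the traced constant 4; a different y has no effect.
out = ep.module()(torch.randn(3, 2), torch.tensor(2))
print(out.shape)  # torch.Size([3, 2]) -- still sliced as x[:4]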

autograd_function

Note

Tags

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

class MyAutogradFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output + 1

class AutogradFunction(torch.nn.Module):
    """
    TorchDynamo does not keep track of backward() on autograd functions. We recommend
    using `allow_in_graph` to mitigate this problem.
    """

    def forward(self, x):
        return MyAutogradFunction.apply(x)

example_args = (torch.randn(3, 2),)
model = AutogradFunction()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            clone: "f32[3, 2]" = torch.ops.aten.clone.default(x);  x = None
            return (clone,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='clone'), target=None)])
Range constraints: {}

class_method

Note

Tags

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

class ClassMethod(torch.nn.Module):
    """
    Class methods are inlined during tracing.
    """

    @classmethod
    def method(cls, x):
        return x + 1

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        x = self.linear(x)
        return self.method(x) * self.__class__.method(x) * type(self).method(x)

example_args = (torch.randn(3, 4),)
model = ClassMethod()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_linear_weight: "f32[2, 4]", p_linear_bias: "f32[2]", x: "f32[3, 4]"):
            linear: "f32[3, 2]" = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias);  x = p_linear_weight = p_linear_bias = None

            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(linear, 1)
            add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(linear, 1)

            mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(add, add_1);  add = add_1 = None

            add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(linear, 1);  linear = None

            mul_1: "f32[3, 2]" = torch.ops.aten.mul.Tensor(mul, add_2);  mul = add_2 = None
            return (mul_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_weight'), target='linear.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_linear_bias'), target='linear.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='mul_1'), target=None)])
Range constraints: {}

cond_branch_class_method

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

from functorch.experimental.control_flow import cond

class MySubModule(torch.nn.Module):
    def foo(self, x):
        return x.cos()

    def forward(self, x):
        return self.foo(x)

class CondBranchClassMethod(torch.nn.Module):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
      - both branches must take the same args, which must also match the branch args passed to cond.
      - both branches must return a single tensor
      - returned tensors must have the same tensor metadata, e.g. shape and dtype
      - branch functions can be free functions, nested functions, lambdas, or class methods
      - branch functions cannot have closure variables
      - no inplace mutations on inputs or global variables


    This example demonstrates using a class method in cond().

    NOTE: If `pred` tests a dim with batch size < 2, it will be specialized.
    """

    def __init__(self) -> None:
        super().__init__()
        self.subm = MySubModule()

    def bar(self, x):
        return x.sin()

    def forward(self, x):
        return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])

example_args = (torch.randn(3),)
tags = {
    "torch.cond",
    "torch.dynamic-shape",
}
model = CondBranchClassMethod()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3]"):
            sin: "f32[3]" = torch.ops.aten.sin.default(x);  x = None
            return (sin,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='sin'), target=None)])
Range constraints: {}

cond_branch_nested_function

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

from functorch.experimental.control_flow import cond

class CondBranchNestedFunction(torch.nn.Module):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
      - both branches must take the same args, which must also match the branch args passed to cond.
      - both branches must return a single tensor
      - returned tensors must have the same tensor metadata, e.g. shape and dtype
      - branch functions can be free functions, nested functions, lambdas, or class methods
      - branch functions cannot have closure variables
      - no inplace mutations on inputs or global variables

    This example demonstrates using a nested function in cond().

    NOTE: If `pred` tests a dim with batch size < 2, it will be specialized.
    """

    def forward(self, x):
        def true_fn(x):
            def inner_true_fn(y):
                return x + y

            return inner_true_fn(x)

        def false_fn(x):
            def inner_false_fn(y):
                return x - y

            return inner_false_fn(x)

        return cond(x.shape[0] < 10, true_fn, false_fn, [x])

example_args = (torch.randn(3),)
tags = {
    "torch.cond",
    "torch.dynamic-shape",
}
model = CondBranchNestedFunction()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3]"):
            add: "f32[3]" = torch.ops.aten.add.Tensor(x, x);  x = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}

cond_branch_nonlocal_variables

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

from functorch.experimental.control_flow import cond

class CondBranchNonlocalVariables(torch.nn.Module):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
    - both branches must take the same args, which must also match the branch args passed to cond.
    - both branches must return a single tensor
    - returned tensors must have the same tensor metadata, e.g. shape and dtype
    - branch functions can be free functions, nested functions, lambdas, or class methods
    - branch functions cannot have closure variables
    - no inplace mutations on inputs or global variables

    This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions.

    The code below will not work because capturing closure variables is not supported.
    ```
    my_tensor_var = x + 100
    my_primitive_var = 3.14

    def true_fn(y):
        nonlocal my_tensor_var, my_primitive_var
        return y + my_tensor_var + my_primitive_var

    def false_fn(y):
        nonlocal my_tensor_var, my_primitive_var
        return y - my_tensor_var - my_primitive_var

    return cond(x.shape[0] > 5, true_fn, false_fn, [x])
    ```

    NOTE: If `pred` tests a dim with batch size < 2, it will be specialized.
    """

    def forward(self, x):
        my_tensor_var = x + 100
        my_primitive_var = 3.14

        def true_fn(x, y, z):
            return x + y + z

        def false_fn(x, y, z):
            return x - y - z

        return cond(
            x.shape[0] > 5,
            true_fn,
            false_fn,
            [x, my_tensor_var, torch.tensor(my_primitive_var)],
        )

example_args = (torch.randn(6),)
tags = {
    "torch.cond",
    "torch.dynamic-shape",
}
model = CondBranchNonlocalVariables()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, c_lifted_tensor_0: "f32[]", x: "f32[6]"):
            add: "f32[6]" = torch.ops.aten.add.Tensor(x, 100)

            lift_fresh_copy: "f32[]" = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0);  c_lifted_tensor_0 = None
            detach: "f32[]" = torch.ops.aten.detach.default(lift_fresh_copy);  lift_fresh_copy = None

            add_1: "f32[6]" = torch.ops.aten.add.Tensor(x, add);  x = add = None
            add_2: "f32[6]" = torch.ops.aten.add.Tensor(add_1, detach);  add_1 = detach = None
            return (add_2,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='c_lifted_tensor_0'), target='lifted_tensor_0', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None)])
Range constraints: {}

cond_closed_over_variable

Note

Tags: torch.cond, python.closure

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

from functorch.experimental.control_flow import cond

class CondClosedOverVariable(torch.nn.Module):
    """
    torch.cond() supports branches closed over arbitrary variables.
    """

    def forward(self, pred, x):
        def true_fn(val):
            return x * 2

        def false_fn(val):
            return x - 2

        return cond(pred, true_fn, false_fn, [x + 1])

example_args = (torch.tensor(True), torch.randn(3, 2))
tags = {"torch.cond", "python.closure"}
model = CondClosedOverVariable()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, pred: "b8[]", x: "f32[3, 2]"):
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            cond = torch.ops.higher_order.cond(pred, true_graph_0, false_graph_0, [x]);  pred = true_graph_0 = false_graph_0 = x = None
            getitem: "f32[3, 2]" = cond[0];  cond = None
            return (getitem,)

        class true_graph_0(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, 2);  x = None
                return (mul,)

        class false_graph_0(torch.nn.Module):
            def forward(self, x: "f32[3, 2]"):
                sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(x, 2);  x = None
                return (sub,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='pred'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}

cond_operands

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

from torch.export import Dim
from functorch.experimental.control_flow import cond

x = torch.randn(3, 2)
y = torch.randn(2)
dim0_x = Dim("dim0_x")

class CondOperands(torch.nn.Module):
    """
    The operands passed to cond() must be:
    - a list of tensors
    - matching the arguments of `true_fn` and `false_fn`

    NOTE: If `pred` tests a dim with batch size < 2, it will be specialized.
    """

    def forward(self, x, y):
        def true_fn(x, y):
            return x + y

        def false_fn(x, y):
            return x - y

        return cond(x.shape[0] > 2, true_fn, false_fn, [x, y])

example_args = (x, y)
tags = {
    "torch.cond",
    "torch.dynamic-shape",
}
extra_inputs = (torch.randn(2, 2), torch.randn(2))
dynamic_shapes = {"x": {0: dim0_x}, "y": None}
model = CondOperands()


torch.export.export(model, example_args, dynamic_shapes=dynamic_shapes)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s0, 2]", y: "f32[2]"):
            sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)

            gt: "Sym(s0 > 2)" = sym_size_int_1 > 2;  sym_size_int_1 = None

            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x, y]);  gt = true_graph_0 = false_graph_0 = x = y = None
            getitem: "f32[s0, 2]" = cond[0];  cond = None
            return (getitem,)

        class true_graph_0(torch.nn.Module):
            def forward(self, x: "f32[s0, 2]", y: "f32[2]"):
                add_3: "f32[s0, 2]" = torch.ops.aten.add.Tensor(x, y);  x = y = None
                return (add_3,)

        class false_graph_0(torch.nn.Module):
            def forward(self, x: "f32[s0, 2]", y: "f32[2]"):
                sub_1: "f32[s0, 2]" = torch.ops.aten.sub.Tensor(x, y);  x = y = None
                return (sub_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {s0: VR[0, int_oo]}
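
Because dim 0 of x is dynamic, the predicate stays symbolic and both branches survive as subgraphs, so the branch is chosen at runtime. A hedged usage sketch using the extra_inputs from the source above:

ep = torch.export.export(model, example_args, dynamic_shapes=dynamic_shapes)
a, b = torch.randn(2, 2), torch.randn(2)
# shape[0] == 2, so the predicate is False and false_graph_0 (x - y) executes.
print(torch.allclose(ep.module()(a, b), a - b))  # True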

cond_predicate

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

from functorch.experimental.control_flow import cond

class CondPredicate(torch.nn.Module):
    """
    The conditional statement (aka predicate) passed to cond() must be one of the following:
      - a torch.Tensor with a single element
      - a boolean expression

    NOTE: If `pred` tests a dim with batch size < 2, it will be specialized.
    """

    def forward(self, x):
        pred = x.dim() > 2 and x.shape[2] > 10

        return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])

example_args = (torch.randn(6, 4, 3),)
tags = {
    "torch.cond",
    "torch.dynamic-shape",
}
model = CondPredicate()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[6, 4, 3]"):
            sin: "f32[6, 4, 3]" = torch.ops.aten.sin.default(x);  x = None
            return (sin,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='sin'), target=None)])
Range constraints: {}

constrain_as_size_example

Note

Tags: torch.escape-hatch, torch.dynamic-value

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch


class ConstrainAsSizeExample(torch.nn.Module):
    """
    If the value is not known at tracing time, you can provide a hint so that we
    can trace further. Please look at the torch._check and torch._check_is_size APIs.
    torch._check_is_size is used for values that NEED to be used for constructing
    a tensor.
    """

    def forward(self, x):
        a = x.item()
        torch._check_is_size(a)
        torch._check(a <= 5)
        return torch.zeros((a, 5))


example_args = (torch.tensor(4),)
tags = {
    "torch.dynamic-value",
    "torch.escape-hatch",
}
model = ConstrainAsSizeExample()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]"):
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None

            sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item);  sym_constrain_range_for_size_default = None

            ge_3: "Sym(u0 >= 0)" = item >= 0
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u0 >= 0 on node 'ge_3'");  ge_3 = _assert_scalar_default = None
            le_1: "Sym(u0 <= 5)" = item <= 5
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 5 on node 'le_1'");  le_1 = _assert_scalar_default_1 = None

            zeros: "f32[u0, 5]" = torch.ops.aten.zeros.default([item, 5], device = device(type='cpu'), pin_memory = False);  item = None
            return (zeros,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='zeros'), target=None)])
Range constraints: {u0: VR[0, 5], u1: VR[0, 5], u2: VR[0, 5]}
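
The torch._check calls are compiled into the runtime assertions visible in the graph, so out-of-range inputs fail at execution time rather than silently misbehaving. A sketch of the expected behavior (not part of the original entry):

ep = torch.export.export(model, example_args)
print(ep.module()(torch.tensor(3)).shape)  # torch.Size([3, 5])
# ep.module()(torch.tensor(7))  # would trip the "u0 <= 5" runtime assertion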

constrain_as_value_example

Note

Tags: torch.escape-hatch, torch.dynamic-value

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch


class ConstrainAsValueExample(torch.nn.Module):
    """
    If the value is not known at tracing time, you can provide a hint so that we
    can trace further. Please look at the torch._check and torch._check_is_size APIs.
    torch._check is used for values that don't need to be used for constructing
    a tensor.
    """

    def forward(self, x, y):
        a = x.item()
        torch._check(a >= 0)
        torch._check(a <= 5)

        if a < 6:
            return y.sin()
        return y.cos()


example_args = (torch.tensor(4), torch.randn(5, 5))
tags = {
    "torch.dynamic-value",
    "torch.escape-hatch",
}
model = ConstrainAsValueExample()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "i64[]", y: "f32[5, 5]"):
            item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None
            ge_1: "Sym(u0 >= 0)" = item >= 0
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'");  ge_1 = _assert_scalar_default = None
            le_1: "Sym(u0 <= 5)" = item <= 5;  item = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 5 on node 'le_1'");  le_1 = _assert_scalar_default_1 = None

            sin: "f32[5, 5]" = torch.ops.aten.sin.default(y);  y = None
            return (sin,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='sin'), target=None)])
Range constraints: {u0: VR[0, 5], u1: VR[0, 5], u2: VR[0, 5]}

decorator

Note

Tags

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import functools

import torch

def test_decorator(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs) + 1

    return wrapper

class Decorator(torch.nn.Module):
    """
    Decorator calls are inlined into the exported function during tracing.
    """

    @test_decorator
    def forward(self, x, y):
        return x + y

example_args = (torch.randn(3, 2), torch.randn(3, 2))
model = Decorator()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]", y: "f32[3, 2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, y);  x = y = None

            add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(add, 1);  add = None
            return (add_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)])
Range constraints: {}

dictionary

Note

Tags: python.data-structure

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

class Dictionary(torch.nn.Module):
    """
    Dictionary structures are inlined and flattened during tracing.
    """

    def forward(self, x, y):
        elements = {}
        elements["x2"] = x * x
        y = y * elements["x2"]
        return {"y": y}

example_args = (torch.randn(3, 2), torch.tensor(4))
tags = {"python.data-structure"}
model = Dictionary()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]", y: "i64[]"):
            mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, x);  x = None

            mul_1: "f32[3, 2]" = torch.ops.aten.mul.Tensor(y, mul);  y = mul = None
            return (mul_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='mul_1'), target=None)])
Range constraints: {}

dynamic_shape_assert

Note

Tags: python.assert

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

class DynamicShapeAssert(torch.nn.Module):
    """
    A basic usage of python assertion.
    """

    def forward(self, x):
        # assertion with error message
        assert x.shape[0] > 2, f"{x.shape[0]} is not greater than 2"
        # assertion without error message
        assert x.shape[0] > 1
        return x

example_args = (torch.randn(3, 2),)
tags = {"python.assert"}
model = DynamicShapeAssert()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            return (x,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='x'), target=None)])
Range constraints: {}

dynamic_shape_constructor

Note

Tags: torch.dynamic-shape

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

class DynamicShapeConstructor(torch.nn.Module):
    """
    Tensor constructors should be captured with dynamic shape inputs rather
    than being baked in with static shape.
    """

    def forward(self, x):
        return torch.zeros(x.shape[0] * 2)

example_args = (torch.randn(3, 2),)
tags = {"torch.dynamic-shape"}
model = DynamicShapeConstructor()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            zeros: "f32[6]" = torch.ops.aten.zeros.default([6], device = device(type='cpu'), pin_memory = False)
            return (zeros,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='zeros'), target=None)])
Range constraints: {}

dynamic_shape_if_guard

Note

Tags: torch.dynamic-shape, python.control-flow

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

class DynamicShapeIfGuard(torch.nn.Module):
    """
    An `if` statement with a backed dynamic-shape predicate will be specialized into
    one particular branch and generate a guard. However, export will fail if the
    dimension is marked as dynamic through a higher-level API.
    """

    def forward(self, x):
        if x.shape[0] == 3:
            return x.cos()

        return x.sin()

example_args = (torch.randn(3, 2, 2),)
tags = {"torch.dynamic-shape", "python.control-flow"}
model = DynamicShapeIfGuard()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2, 2]"):
            cos: "f32[3, 2, 2]" = torch.ops.aten.cos.default(x);  x = None
            return (cos,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cos'), target=None)])
Range constraints: {}

dynamic_shape_map

Note

Tags: torch.dynamic-shape, torch.map

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

from functorch.experimental.control_flow import map

class DynamicShapeMap(torch.nn.Module):
    """
    functorch map() maps a function over the first tensor dimension.
    """

    def forward(self, xs, y):
        def body(x, y):
            return x + y

        return map(body, xs, y)

example_args = (torch.randn(3, 2), torch.randn(2))
tags = {"torch.dynamic-shape", "torch.map"}
model = DynamicShapeMap()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, xs: "f32[3, 2]", y: "f32[2]"):
            body_graph_0 = self.body_graph_0
            map_impl = torch.ops.higher_order.map_impl(body_graph_0, [xs], [y]);  body_graph_0 = xs = y = None
            getitem: "f32[3, 2]" = map_impl[0];  map_impl = None
            return (getitem,)

        class body_graph_0(torch.nn.Module):
            def forward(self, xs: "f32[2]", y: "f32[2]"):
                add: "f32[2]" = torch.ops.aten.add.Tensor(xs, y);  xs = y = None
                return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='xs'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}

dynamic_shape_slicing

Note

Tags: torch.dynamic-shape

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

class DynamicShapeSlicing(torch.nn.Module):
    """
    Slices with dynamic shape arguments should be captured into the graph
    rather than being baked in.
    """

    def forward(self, x):
        return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]

example_args = (torch.randn(3, 2),)
tags = {"torch.dynamic-shape"}
model = DynamicShapeSlicing()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            slice_1: "f32[1, 2]" = torch.ops.aten.slice.Tensor(x, 0, 0, 1);  x = None
            slice_2: "f32[1, 1]" = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 9223372036854775807, 2);  slice_1 = None
            return (slice_2,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='slice_2'), target=None)])
Range constraints: {}

dynamic_shape_view

Note

Tags: torch.dynamic-shape

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

class DynamicShapeView(torch.nn.Module):
    """
    Dynamic shapes should be propagated to view arguments instead of being
    baked into the exported graph.
    """

    def forward(self, x):
        new_x_shape = x.size()[:-1] + (2, 5)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1)

example_args = (torch.randn(10, 10),)
tags = {"torch.dynamic-shape"}
model = DynamicShapeView()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[10, 10]"):
            view: "f32[10, 2, 5]" = torch.ops.aten.view.default(x, [10, 2, 5]);  x = None

            permute: "f32[10, 5, 2]" = torch.ops.aten.permute.default(view, [0, 2, 1]);  view = None
            return (permute,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='permute'), target=None)])
Range constraints: {}

fn_with_kwargs

Note

Tags: python.data-structure

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

class FnWithKwargs(torch.nn.Module):
    """
    Keyword arguments are not supported at the moment.
    """

    def forward(self, pos0, tuple0, *myargs, mykw0, **mykwargs):
        out = pos0
        for arg in tuple0:
            out = out * arg
        for arg in myargs:
            out = out * arg
        out = out * mykw0
        out = out * mykwargs["input0"] * mykwargs["input1"]
        return out

example_args = (
    torch.randn(4),
    (torch.randn(4), torch.randn(4)),
    *[torch.randn(4), torch.randn(4)]
)
example_kwargs = {
    "mykw0": torch.randn(4),
    "input0": torch.randn(4),
    "input1": torch.randn(4),
}
tags = {"python.data-structure"}
model = FnWithKwargs()


torch.export.export(model, example_args, example_kwargs)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, pos0: "f32[4]", tuple0_0: "f32[4]", tuple0_1: "f32[4]", myargs_0: "f32[4]", myargs_1: "f32[4]", mykw0: "f32[4]", input0: "f32[4]", input1: "f32[4]"):
            mul: "f32[4]" = torch.ops.aten.mul.Tensor(pos0, tuple0_0);  pos0 = tuple0_0 = None
            mul_1: "f32[4]" = torch.ops.aten.mul.Tensor(mul, tuple0_1);  mul = tuple0_1 = None

            mul_2: "f32[4]" = torch.ops.aten.mul.Tensor(mul_1, myargs_0);  mul_1 = myargs_0 = None
            mul_3: "f32[4]" = torch.ops.aten.mul.Tensor(mul_2, myargs_1);  mul_2 = myargs_1 = None

            mul_4: "f32[4]" = torch.ops.aten.mul.Tensor(mul_3, mykw0);  mul_3 = mykw0 = None

            mul_5: "f32[4]" = torch.ops.aten.mul.Tensor(mul_4, input0);  mul_4 = input0 = None
            mul_6: "f32[4]" = torch.ops.aten.mul.Tensor(mul_5, input1);  mul_5 = input1 = None
            return (mul_6,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='pos0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='tuple0_0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='tuple0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='myargs_0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='myargs_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='mykw0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='input0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='input1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='mul_6'), target=None)])
Range constraints: {}
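
Despite what the docstring says, this case is listed as SUPPORTED: keyword arguments are pytree-flattened into plain positional graph inputs (tuple0_0, myargs_0, input0, and so on). A hedged sketch of calling the exported module with the original calling convention:

ep = torch.export.export(model, example_args, example_kwargs)
out = ep.module()(*example_args, **example_kwargs)
print(out.shape)  # torch.Size([4])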

list_contains

Note

Tags: torch.dynamic-shape, python.data-structure, python.assert

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

class ListContains(torch.nn.Module):
    """
    List containment relations can be checked against dynamic shapes or constants.
    """

    def forward(self, x):
        assert x.size(-1) in [6, 2]
        assert x.size(0) not in [4, 5, 6]
        assert "monkey" not in ["cow", "pig"]
        return x + x

example_args = (torch.randn(3, 2),)
tags = {"torch.dynamic-shape", "python.data-structure", "python.assert"}
model = ListContains()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, x);  x = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}

list_unpack

Note

Tags: python.control-flow, python.data-structure

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
from typing import List

import torch

class ListUnpack(torch.nn.Module):
    """
    Lists are treated as a static construct, therefore unpacking should be
    erased after tracing.
    """

    def forward(self, args: List[torch.Tensor]):
        """
        Lists are treated as a static construct, therefore unpacking should be
        erased after tracing.
        """
        x, *y = args
        return x + y[0]

example_args = ([torch.randn(3, 2), torch.tensor(4), torch.tensor(5)],)
tags = {"python.control-flow", "python.data-structure"}
model = ListUnpack()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, args_0: "f32[3, 2]", args_1: "i64[]", args_2: "i64[]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(args_0, args_1);  args_0 = args_1 = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='args_0'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='args_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='args_2'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}

nested_function

Note

Tags: python.closure

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

class NestedFunction(torch.nn.Module):
    """
    Nested functions are traced through. Side effects on global captures
    are not supported though.
    """

    def forward(self, a, b):
        x = a + b
        z = a - b

        def closure(y):
            nonlocal x
            x += 1
            return x * y + z

        return closure(x)

example_args = (torch.randn(3, 2), torch.randn(2))
tags = {"python.closure"}
model = NestedFunction()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, a: "f32[3, 2]", b: "f32[2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(a, b)

            sub: "f32[3, 2]" = torch.ops.aten.sub.Tensor(a, b);  a = b = None

            add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(add, 1);  add = None

            mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(add_1, add_1);  add_1 = None
            add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(mul, sub);  mul = sub = None
            return (add_2,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='a'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='b'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None)])
Range constraints: {}

null_context_manager

Note

Tags: python.context-manager

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import contextlib

import torch

class NullContextManager(torch.nn.Module):
    """
    Null context manager in Python will be traced out.
    """

    def forward(self, x):
        """
        Null context manager in Python will be traced out.
        """
        ctx = contextlib.nullcontext()
        with ctx:
            return x.sin() + x.cos()

example_args = (torch.randn(3, 2),)
tags = {"python.context-manager"}
model = NullContextManager()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            sin: "f32[3, 2]" = torch.ops.aten.sin.default(x)
            cos: "f32[3, 2]" = torch.ops.aten.cos.default(x);  x = None
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(sin, cos);  sin = cos = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}

pytree_flatten

Note

Tags

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

from torch.utils import _pytree as pytree

class PytreeFlatten(torch.nn.Module):
    """
    Pytree from PyTorch can be captured by TorchDynamo.
    """

    def forward(self, x):
        y, spec = pytree.tree_flatten(x)
        return y[0] + 1

example_args = ({1: torch.randn(3, 2), 2: torch.randn(3, 2)},)
model = PytreeFlatten()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x_0_1: "f32[3, 2]", x_0_2: "f32[3, 2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x_0_1, 1);  x_0_1 = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x_0_1'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x_0_2'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}

scalar_output

Note

Tags: torch.dynamic-shape

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

from torch.export import Dim

x = torch.randn(3, 2)
dim1_x = Dim("dim1_x")

class ScalarOutput(torch.nn.Module):
    """
    Returning scalar values from the graph is supported, in addition to Tensor
    outputs. Symbolic shapes are captured and rank is specialized.
    """
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        return x.shape[1] + 1

example_args = (x,)
tags = {"torch.dynamic-shape"}
dynamic_shapes = {"x": {1: dim1_x}}
model = ScalarOutput()


torch.export.export(model, example_args, dynamic_shapes=dynamic_shapes)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, s0]"):
            sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 1);  x = None

            add: "Sym(s0 + 1)" = sym_size_int_1 + 1;  sym_size_int_1 = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=SymIntArgument(name='add'), target=None)])
Range constraints: {s0: VR[0, int_oo]}
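
Since the output is a SymInt derived from the input shape, the exported module returns a plain Python integer at runtime. A sketch (assuming the model and dynamic_shapes defined above):

ep = torch.export.export(model, example_args, dynamic_shapes=dynamic_shapes)
print(ep.module()(torch.randn(3, 7)))  # 8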

specialized_attribute

Note

Tags

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
from enum import Enum

import torch

class Animal(Enum):
    COW = "moo"

class SpecializedAttribute(torch.nn.Module):
    """
    Model attributes are specialized.
    """

    def __init__(self) -> None:
        super().__init__()
        self.a = "moo"
        self.b = 4

    def forward(self, x):
        if self.a == Animal.COW.value:
            return x * x + self.b
        else:
            raise ValueError("bad")

example_args = (torch.randn(3, 2),)
model = SpecializedAttribute()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, x);  x = None
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(mul, 4);  mul = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}

static_for_loop

Note

Tags: python.control-flow

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

class StaticForLoop(torch.nn.Module):
    """
    A for loop with a constant number of iterations should be unrolled in the exported graph.
    """

    def forward(self, x):
        ret = []
        for i in range(10):  # constant
            ret.append(i + x)
        return ret

example_args = (torch.randn(3, 2),)
tags = {"python.control-flow"}
model = StaticForLoop()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 0)
            add_1: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 1)
            add_2: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 2)
            add_3: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 3)
            add_4: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 4)
            add_5: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 5)
            add_6: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 6)
            add_7: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 7)
            add_8: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 8)
            add_9: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 9);  x = None
            return (add, add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_2'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_3'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_4'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_5'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_6'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_7'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_8'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_9'), target=None)])
Range constraints: {}

static_if

Note

Tags: python.control-flow

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

class StaticIf(torch.nn.Module):
    """
    An `if` statement with a static predicate value should be traced through with the
    taken branch.
    """

    def forward(self, x):
        if len(x.shape) == 3:
            return x + torch.ones(1, 1, 1)

        return x

example_args = (torch.randn(3, 2, 2),)
tags = {"python.control-flow"}
model = StaticIf()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2, 2]"):
            ones: "f32[1, 1, 1]" = torch.ops.aten.ones.default([1, 1, 1], device = device(type='cpu'), pin_memory = False)
            add: "f32[3, 2, 2]" = torch.ops.aten.add.Tensor(x, ones);  x = ones = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}

tensor_setattr

Note

Tags: python.builtin

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch


class TensorSetattr(torch.nn.Module):
    """
    setattr() call onto tensors is not supported.
    """
    def forward(self, x, attr):
        setattr(x, attr, torch.randn(3, 2))
        return x + 4

example_args = (torch.randn(3, 2), "attr")
tags = {"python.builtin"}
model = TensorSetattr()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]", attr):
            add: "f32[3, 2]" = torch.ops.aten.add.Tensor(x, 4);  x = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=ConstantArgument(name='attr', value='attr'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}

type_reflection_method

Note

Tags: python.builtin

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch

class A:
    @classmethod
    def func(cls, x):
        return 1 + x

class TypeReflectionMethod(torch.nn.Module):
    """
    type() calls on custom objects followed by attribute accesses are not allowed
    due to their overly dynamic nature.
    """

    def forward(self, x):
        a = A()
        return type(a).func(x)


example_args = (torch.randn(3, 4),)
tags = {"python.builtin"}
model = TypeReflectionMethod()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 4]"):
            add: "f32[3, 4]" = torch.ops.aten.add.Tensor(x, 1);  x = None
            return (add,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}

user_input_mutation

Note

Tags: torch.mutation

Support Level: SUPPORTED

Original source code

# mypy: allow-untyped-defs
import torch


class UserInputMutation(torch.nn.Module):
    """
    Directly mutate user input in forward
    """

    def forward(self, x):
        x.mul_(2)
        return x.cos()


example_args = (torch.randn(3, 2),)
tags = {"torch.mutation"}
model = UserInputMutation()


torch.export.export(model, example_args)

Result

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 2]"):
            mul: "f32[3, 2]" = torch.ops.aten.mul.Tensor(x, 2);  x = None

            cos: "f32[3, 2]" = torch.ops.aten.cos.default(mul)
            return (mul, cos)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_INPUT_MUTATION: 6>, arg=TensorArgument(name='mul'), target='x'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='cos'), target=None)])
Range constraints: {}

Not Supported Yet

dynamic_shape_round

Note

Tags: torch.dynamic-shape, python.builtin

Support Level: NOT_SUPPORTED_YET

Original source code

# mypy: allow-untyped-defs
import torch

from torch._export.db.case import SupportLevel
from torch.export import Dim

class DynamicShapeRound(torch.nn.Module):
    """
    Calling round on dynamic shapes is not supported.
    """

    def forward(self, x):
        return x[: round(x.shape[0] / 2)]

x = torch.randn(3, 2)
dim0_x = Dim("dim0_x")
example_args = (x,)
tags = {"torch.dynamic-shape", "python.builtin"}
support_level = SupportLevel.NOT_SUPPORTED_YET
dynamic_shapes = {"x": {0: dim0_x}}
model = DynamicShapeRound()


torch.export.export(model, example_args, dynamic_shapes=dynamic_shapes)

Result

Unsupported: Constraints violated (dim0_x)! For more information, run with TORCH_LOGS="+dynamic".
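
A common rewrite, shown here only as a hedged sketch rather than an official recommendation, is to replace round() with integer arithmetic that stays symbolic, such as floor division. Note the semantics differ slightly (round uses banker's rounding; // always floors):

class DynamicShapeFloorDiv(torch.nn.Module):
    def forward(self, x):
        # x.shape[0] // 2 remains a symbolic integer, so the slice stays dynamic.
        return x[: x.shape[0] // 2]

torch.export.export(
    DynamicShapeFloorDiv(),
    (torch.randn(3, 2),),
    dynamic_shapes={"x": {0: Dim("dim0_x")}},
)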

model_attr_mutation

Note

Tags: python.object-model

Support Level: NOT_SUPPORTED_YET

Original source code

# mypy: allow-untyped-defs
import torch
from torch._export.db.case import SupportLevel


class ModelAttrMutation(torch.nn.Module):
    """
    Attribute mutation is not supported.
    """

    def __init__(self) -> None:
        super().__init__()
        self.attr_list = [torch.randn(3, 2), torch.randn(3, 2)]

    def recreate_list(self):
        return [torch.zeros(3, 2), torch.zeros(3, 2)]

    def forward(self, x):
        self.attr_list = self.recreate_list()
        return x.sum() + self.attr_list[0].sum()


example_args = (torch.randn(3, 2),)
tags = {"python.object-model"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = ModelAttrMutation()


torch.export.export(model, example_args)

Result

AssertionError: Mutating module attribute attr_list during export.
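
One workaround, sketched here under the assumption that the recreated list is only needed locally, is to keep the attribute read-only during forward and use a local variable instead of reassigning self.attr_list:

class ModelAttrNoMutation(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.attr_list = [torch.zeros(3, 2), torch.zeros(3, 2)]

    def forward(self, x):
        # Read the attribute; never rebind it inside forward.
        local_list = self.attr_list
        return x.sum() + local_list[0].sum()

torch.export.export(ModelAttrNoMutation(), (torch.randn(3, 2),))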

optional_input

Note

Tags: python.object-model

Support Level: NOT_SUPPORTED_YET

Original source code

# mypy: allow-untyped-defs
import torch
from torch._export.db.case import SupportLevel


class OptionalInput(torch.nn.Module):
    """
    Tracing through optional input is not supported yet
    """

    def forward(self, x, y=torch.randn(2, 3)):
        if y is not None:
            return x + y
        return x


example_args = (torch.randn(2, 3),)
tags = {"python.object-model"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = OptionalInput()


torch.export.export(model, example_args)

Result

Unsupported: Tracing through optional input is not supported yet
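
If the optional input will in fact be supplied at runtime, a workaround (a sketch, not from the original entry) is to pass it explicitly as an example input, so it is traced as an ordinary tensor argument and the `y is not None` check becomes statically True:

example_args_with_y = (torch.randn(2, 3), torch.randn(2, 3))
torch.export.export(model, example_args_with_y)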

unsupported_operator

Note

Tags: torch.operator

Support Level: NOT_SUPPORTED_YET

Original source code

# mypy: allow-untyped-defs
import torch
from torch._export.db.case import SupportLevel


class TorchSymMin(torch.nn.Module):
    """
    torch.sym_min operator is not supported in export.
    """

    def forward(self, x):
        return x.sum() + torch.sym_min(x.size(0), 100)


example_args = (torch.randn(3, 2),)
tags = {"torch.operator"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = TorchSymMin()


torch.export.export(model, example_args)

Result

Unsupported: torch.* op returned non-Tensor int call_function <function sym_min at 0x7f98c3497040>
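
Because this example is exported without dynamic shapes, x.size(0) is a plain Python int at trace time, so the builtin min, shown here as a hedged alternative sketch, is simply constant-folded instead of dispatching to torch.sym_min:

class BuiltinMin(torch.nn.Module):
    def forward(self, x):
        # min(3, 100) folds to the constant 3 during tracing.
        return x.sum() + min(x.size(0), 100)

torch.export.export(BuiltinMin(), (torch.randn(3, 2),))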
