控制流 - Cond¶
torch.cond 是一个结构化的控制流运算符。它可用于指定类似 if-else 的控制流,从逻辑上看可以实现如下。
def cond(
pred: Union[bool, torch.Tensor],
true_fn: Callable,
false_fn: Callable,
operands: Tuple[torch.Tensor]
):
if pred:
return true_fn(*operands)
else:
return false_fn(*operands)
它的独特之处在于能够表达数据依赖的控制流:它降低到一个条件运算符 (torch.ops.higher_order.cond),该运算符保留谓词、真函数和假函数。这为编写和部署模型提供了极大的灵活性,这些模型可以根据输入或张量运算中间输出的值或形状改变模型架构。
警告
torch.cond 是 PyTorch 中的原型功能。它对输入和输出类型的支持有限,目前不支持训练。请期待 PyTorch 未来版本中更稳定的实现。阅读有关功能分类的更多信息:https://pytorch.ac.cn/blog/pytorch-feature-classification-changes/#prototype
示例¶
下面是一个使用 cond 根据输入形状进行分支的示例
import torch
def true_fn(x: torch.Tensor):
return x.cos() + x.sin()
def false_fn(x: torch.Tensor):
return x.sin()
class DynamicShapeCondPredicate(torch.nn.Module):
"""
A basic usage of cond based on dynamic shape predicate.
"""
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
def true_fn(x: torch.Tensor):
return x.cos()
def false_fn(x: torch.Tensor):
return x.sin()
return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))
dyn_shape_mod = DynamicShapeCondPredicate()
我们可以急切地运行模型,并期望结果根据输入形状而变化
inp = torch.randn(3)
inp2 = torch.randn(5)
assert torch.equal(dyn_shape_mod(inp), false_fn(inp))
assert torch.equal(dyn_shape_mod(inp2), true_fn(inp2))
我们可以导出模型以进行进一步的转换和部署
inp = torch.randn(4, 3)
dim_batch = torch.export.Dim("batch", min=2)
ep = torch.export.export(DynamicShapeCondPredicate(), (inp,), {}, dynamic_shapes={"x": {0: dim_batch}})
print(ep)
这将给出如下所示的导出程序
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
gt: Sym(s0 > 4) = sym_size > 4; sym_size = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
return (conditional,)
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None
return add
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
return sin
请注意,torch.cond 被降低到 torch.ops.higher_order.cond,它的谓词变成了输入形状上的符号表达式,分支函数变成了顶层图模块的两个子图属性。
以下另一个示例展示了如何表达数据依赖的控制流
class DataDependentCondPredicate(torch.nn.Module):
"""
A basic usage of cond based on data dependent predicate.
"""
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.cond(x.sum() > 4.0, true_fn, false_fn, (x,))
导出后得到的导出程序
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
sum_1: f32[] = torch.ops.aten.sum.default(arg0_1)
gt: b8[] = torch.ops.aten.gt.Scalar(sum_1, 4.0); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
conditional: f32[s0, 3] = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
return (conditional,)
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
cos: f32[s0, 3] = torch.ops.aten.cos.default(arg0_1)
sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
add: f32[s0, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None
return add
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 3]):
sin: f32[s0, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
return sin
torch.ops.higher_order.cond 的不变式¶
对于 torch.ops.higher_order.cond 有几个有用的不变式
- 对于谓词
谓词的动态性得以保留(例如上面示例中的 gt)
如果用户程序中的谓词是常量(例如 Python 布尔常量),则运算符的 pred 将是常量。
- 对于分支
输入和输出签名将是扁平化的元组。
它们是 torch.fx.GraphModule。
原始函数中的闭包变成了显式输入。没有闭包。
不允许对输入或全局变量进行变异。
- 对于操作数
它也将是扁平化的元组。
用户程序中 torch.cond 的嵌套变成了嵌套的图模块。
API 参考¶
- torch._higher_order_ops.cond.cond(pred, true_fn, false_fn, operands)¶
有条件地应用 true_fn 或 false_fn。
警告
torch.cond 是 PyTorch 中的原型功能。它对输入和输出类型的支持有限,目前不支持训练。请期待 PyTorch 未来版本中更稳定的实现。阅读有关功能分类的更多信息:https://pytorch.ac.cn/blog/pytorch-feature-classification-changes/#prototype
cond 是结构化的控制流运算符。也就是说,它类似于 Python if 语句,但对 true_fn、false_fn 和 operands 施加了限制,使其能够通过 torch.compile 和 torch.export 进行捕获。
假设满足对 cond 参数的约束,则 cond 等效于以下内容
def cond(pred, true_branch, false_branch, operands): if pred: return true_branch(*operands) else: return false_branch(*operands)
- 参数
pred (Union[bool, torch.Tensor]) – 布尔表达式或具有一个元素的张量,指示要应用哪个分支函数。
true_fn (Callable) – 可调用函数 (a -> b),该函数位于正在追踪的范围之内。
false_fn (Callable) – 可调用函数 (a -> b),该函数位于正在追踪的范围之内。真分支和假分支必须具有一致的输入和输出,这意味着输入必须相同,而输出必须具有相同的类型和形状。
operands (Tuple of possibly nested dict/list/tuple of torch.Tensor) – 传递给真/假函数的输入元组。
示例
def true_fn(x: torch.Tensor): return x.cos() def false_fn(x: torch.Tensor): return x.sin() return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
- 限制
条件语句(也称为 pred)必须满足以下约束之一
它是一个具有一个元素且 torch.bool 数据类型的 torch.Tensor
它是一个布尔表达式,例如 x.shape[0] > 10 或 x.dim() > 1 and x.shape[1] > 10
分支函数(也称为 true_fn/false_fn)必须满足以下所有约束
函数签名必须与操作数匹配。
函数必须返回具有相同元数据的张量,例如形状、数据类型等。
函数不能对输入或全局变量进行就地变异。(注意:允许在分支中对中间结果进行就地张量运算,例如 add_)
警告
时间限制
分支的输出必须是单个张量。张量 Pytree 将在将来得到支持。