torch.cond¶
- torch.cond(pred, true_fn, false_fn, operands)¶
有条件地应用 true_fn 或 false_fn。
警告
torch.cond 是 PyTorch 中的原型功能。它对输入和输出类型的支持有限,目前不支持训练。敬请期待 PyTorch 未来版本中更稳定的实现。有关功能分类的更多信息,请阅读:https://pytorch.ac.cn/blog/pytorch-feature-classification-changes/#prototype
cond 是结构化控制流运算符。也就是说,它类似于 Python 中的 if 语句,但对 true_fn、false_fn 和 operands 有一些限制,这些限制使它能够使用 torch.compile 和 torch.export 进行捕获。
假设满足了对 cond 参数的约束,则 cond 等效于以下内容
def cond(pred, true_branch, false_branch, operands): if pred: return true_branch(*operands) else: return false_branch(*operands)
- 参数
pred (Union[bool, torch.Tensor]) – 布尔表达式或具有一个元素的张量,指示要应用哪个分支函数。
true_fn (Callable) – 可调用函数 (a -> b),位于正在跟踪的范围内。
false_fn (Callable) – 可调用函数 (a -> b),位于正在跟踪的范围内。真分支和假分支必须具有一致的输入和输出,这意味着输入必须相同,并且输出必须具有相同的类型和形状。
operands (Tuple of 可能嵌套的 dict/list/tuple of torch.Tensor) – true/false 函数的输入元组。
示例
def true_fn(x: torch.Tensor): return x.cos() def false_fn(x: torch.Tensor): return x.sin() return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
- 限制
条件语句(即 pred)必须满足以下约束之一
它是 torch.Tensor,只有一个元素,并且数据类型为 torch.bool
它是布尔表达式,例如 x.shape[0] > 10 或 x.dim() > 1 and x.shape[1] > 10
分支函数(即 true_fn/false_fn)必须满足以下所有约束
函数签名必须与 operands 匹配。
函数必须返回具有相同元数据的张量,例如形状、数据类型等。
函数不能对输入或全局变量进行就地修改。(注意:分支中允许对中间结果使用就地张量运算,例如 add_)
警告
时间限制
分支的**输出**必须是**单个张量**。将来会支持张量 Pytree。