快捷方式

torch.cond

torch.cond(pred, true_fn, false_fn, operands=())[source]

有条件地应用 true_fnfalse_fn

警告

torch.cond 是 PyTorch 中的原型功能。它对输入和输出类型的支持有限,目前不支持训练。请期待未来 PyTorch 版本中更稳定的实现。阅读更多关于功能分类的信息: https://pytorch.ac.cn/blog/pytorch-feature-classification-changes/#prototype

cond 是结构化控制流操作符。也就是说,它类似于 Python 的 if 语句,但对 true_fnfalse_fnoperands 有限制,使其能够使用 torch.compile 和 torch.export 进行捕获。

假设满足对 cond 参数的约束,cond 等价于以下内容

def cond(pred, true_branch, false_branch, operands):
    if pred:
        return true_branch(*operands)
    else:
        return false_branch(*operands)
参数
  • pred (Union[bool, torch.Tensor]) – 一个布尔表达式或一个包含一个元素的张量,指示要应用哪个分支函数。

  • true_fn (Callable) – 一个可调用函数 (a -> b),它在正在跟踪的范围内。

  • false_fn (Callable) – 一个可调用函数 (a -> b),它在正在跟踪的范围内。true 分支和 false 分支必须具有一致的输入和输出,这意味着输入必须相同,输出必须具有相同的类型和形状。

  • operands (Tuple of possibly nested dict/list/tuple of torch.Tensor) – true/false 函数的输入元组。如果 true_fn/false_fn 不需要输入,则可以为空。默认为 ()。

返回类型

Any

示例

def true_fn(x: torch.Tensor):
    return x.cos()
def false_fn(x: torch.Tensor):
    return x.sin()
return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
限制
  • 条件语句(又名 pred)必须满足以下约束之一

    • 它是一个 torch.Tensor,只有一个元素,并且是 torch.bool 数据类型

    • 它是一个布尔表达式,例如 x.shape[0] > 10x.dim() > 1 and x.shape[1] > 10

  • 分支函数(又名 true_fn/false_fn)必须满足以下所有约束

    • 函数签名必须与 operands 匹配。

    • 函数必须返回具有相同元数据的张量,例如形状、数据类型等。

    • 函数不能对输入或全局变量进行原地修改。(注意:允许在分支中使用原地张量操作,例如 add_ 用于中间结果)

警告

时间限制

  • 分支的输出必须是单个张量。未来将支持张量 PyTree。

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得解答

查看资源