• 文档 >
  • 使用 while_loop 优化内存利用率
快捷方式

使用 while_loop 优化内存利用率

while_loop

while_loop 替换纯 Python while 循环,PyTorch 通过 torch._higher_order_ops.while_loop 支持 while_loop。PyTorch/XLA 通过 XLA::Whiletorch._higher_order_ops.while_loop 提供实验性的 XLA 后端支持。

用法:

import torch_xla.experimental.fori_loop
from torch._higher_order_ops.while_loop import while_loop
result = while_loop(cond_fn, body_fn, init)
  • cond_fn:用户定义的条件函数。

  • body_fn:用户定义的循环体函数。

  • init:初始值(元组或列表)。

使用 while_loop 的简单示例:

# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.experimental.fori_loop
>>> from torch._higher_order_ops.while_loop import while_loop
>>> import torch_xla.core.xla_model as xm
>>>
>>> device = xm.xla_device()
>>>
>>> def cond_fn(iteri, x):
...   return iteri > 0
...
>>> def body_fn(iteri, x):
...   return iteri - 1, torch.add(x, 1)
...
>>> init_val = torch.tensor(3, device=device)
>>> iteri = torch.tensor(10, device=device)
>>> _, res = while_loop(cond_fn, body_fn, (iteri, init_val))
>>> res
FunctionalTensor(lvl=0, value=\
tensor(13, device='xla:0'))

对照组测试用例

为了更好地比较 Python while 循环while_loop 之间的差异,这里有一个名为纯 Python while 循环的测试用例,其逻辑类似:累加 1 十次

使用纯 Python while 循环的对照组示例

# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.core.xla_model as xm
>>>
>>> device = xm.xla_device()
>>>
>>> init_val = torch.tensor(1, device=device)
>>> iteri = torch.tensor(50, device=device)
>>>
>>> while iteri > 0:
...   init_val = init_val + 1
...   iteri -= 1
...
>>> init_val
tensor(51, device='xla:0')

PyTorch/XLA 将在 2.4 版本中包含 while_loop 支持并提供测试用例,对 fori_loop 的支持将在 2.4 版本之后添加。对于 while_loop,目前我们只应强制定义具有相同 输入输出(返回参数) 形状的 body_fn

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源