使用 while_loop
优化内存利用率¶
while_loop
¶
while_loop
替换纯 Python while
循环,PyTorch 通过 torch._higher_order_ops.while_loop 支持 while_loop
。PyTorch/XLA 通过 XLA::While
为 torch._higher_order_ops.while_loop
提供实验性的 XLA 后端支持。
用法:¶
import torch_xla.experimental.fori_loop
from torch._higher_order_ops.while_loop import while_loop
result = while_loop(cond_fn, body_fn, init)
cond_fn
:用户定义的条件函数。body_fn
:用户定义的循环体函数。init
:初始值(元组或列表)。
使用 while_loop
的简单示例:¶
# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.experimental.fori_loop
>>> from torch._higher_order_ops.while_loop import while_loop
>>> import torch_xla.core.xla_model as xm
>>>
>>> device = xm.xla_device()
>>>
>>> def cond_fn(iteri, x):
... return iteri > 0
...
>>> def body_fn(iteri, x):
... return iteri - 1, torch.add(x, 1)
...
>>> init_val = torch.tensor(3, device=device)
>>> iteri = torch.tensor(10, device=device)
>>> _, res = while_loop(cond_fn, body_fn, (iteri, init_val))
>>> res
FunctionalTensor(lvl=0, value=\
tensor(13, device='xla:0'))
对照组测试用例¶
为了更好地比较 纯 Python while 循环
和 while_loop
之间的差异,这里有一个名为纯 Python while
循环的测试用例,其逻辑类似:累加 1 十次
使用纯 Python while
循环的对照组示例¶
# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.core.xla_model as xm
>>>
>>> device = xm.xla_device()
>>>
>>> init_val = torch.tensor(1, device=device)
>>> iteri = torch.tensor(50, device=device)
>>>
>>> while iteri > 0:
... init_val = init_val + 1
... iteri -= 1
...
>>> init_val
tensor(51, device='xla:0')
PyTorch/XLA 将在 2.4 版本中包含 while_loop
支持并提供测试用例,对 fori_loop
的支持将在 2.4 版本之后添加。对于 while_loop
,目前我们只应强制定义具有相同 输入
和 输出(返回参数)
形状的 body_fn