使用 scan
和 scan_layers
指南¶
本指南介绍如何在 PyTorch/XLA 中使用 scan
和 scan_layers
。
何时应该使用它们¶
如果你的模型具有许多同构(形状相同,逻辑相同)层(例如 LLM),你应该考虑使用 ``scan_layers` <https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan_layers.py>`_。这些模型编译速度可能很慢。scan_layers
是同构层循环的直接替代品,例如一批解码器层。scan_layers
跟踪第一层并为所有后续层重用编译结果,从而显著减少模型编译时间。
另一方面,``scan` <https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan.py>`_ 是一个较低级别的高阶操作,模仿 ``jax.lax.scan` <https://jax.net.cn/en/latest/_autosummary/jax.lax.scan.html>`_。它的主要目的是帮助在底层实现 scan_layers
。但是,如果你想编写某种循环逻辑,其中循环本身在编译器中具有第一类表示(具体来说,是 XLA While
操作),你可能会发现它很有用。
scan_layers
示例¶
通常,transformer 模型通过一系列同构解码器层传递输入嵌入,如下所示
def run_decoder_layers(self, hidden_states):
for decoder_layer in self.layers:
hidden_states = decoder_layer(hidden_states)
return hidden_states
当此函数被降低到 HLO 图时,for 循环被展开为操作的扁平列表,从而导致编译时间过长。为了减少编译时间,你可以将 for 循环替换为对 scan_layers
的调用,如 ``decoder_with_scan.py` </examples/scan/decoder_with_scan.py>`_ 中所示
def run_decoder_layers(self, hidden_states):
from torch_xla.experimental.scan_layers import scan_layers
return scan_layers(self.layers, hidden_states)
你可以通过从 pytorch/xla
源代码检出的根目录运行以下命令来训练此解码器模型。
python3 examples/train_decoder_only_base.py scan.decoder_with_scan.DecoderWithScan
scan
示例¶
``scan` <https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan.py>`_ 接受一个组合函数,并在张量的引导维度上应用该函数,同时携带状态
def scan(
fn: Callable[[Carry, X], tuple[Carry, Y]],
init: Carry,
xs: X,
) -> tuple[Carry, Y]:
...
你可以使用它来有效地循环遍历张量的引导维度。如果 xs
是单个张量,则此函数大致等于以下 Python 代码
def scan(fn, init, xs):
ys = []
carry = init
for i in len(range(xs.size(0))):
carry, y = fn(carry, xs[i])
ys.append(y)
return carry, torch.stack(ys, dim=0)
在底层,scan
的实现效率更高,方法是将循环降低到 XLA While
操作。这确保了 XLA 只编译循环的一次迭代。
``scan_examples.py` </examples/scan/scan_examples.py>`_ 包含一些示例代码,展示了如何使用 scan
。在该文件中,scan_example_cumsum
使用 scan
来实现累积和。scan_example_pytree
演示了如何将 PyTrees 传递给 scan
。
你可以使用以下命令运行示例
python3 examples/scan/scan_examples.py
输出应如下所示
Running example: scan_example_cumsum
Final sum: tensor([6.], device='xla:0')
History of sums tensor([[1.],
[3.],
[6.]], device='xla:0')
Running example: scan_example_pytree
Final carry: {'sum': tensor([15.], device='xla:0'), 'count': tensor([5.], device='xla:0')}
Means over time: tensor([[1.0000],
[1.5000],
[2.0000],
[2.5000],
[3.0000]], device='xla:0')
局限性¶
编译时间实验¶
为了演示编译时间节省,我们将在一个 TPU 芯片上训练一个带有许多层的简单解码器,分别使用 for 循环和 scan_layers
。
运行 for 循环实现
❯ python3 examples/train_decoder_only_base.py \
--hidden-size 256 \
--num-layers 50 \
--num-attention-heads 4 \
--num-key-value-heads 2 \
--intermediate-size 2048 \
--num-steps 5 \
--print-metrics
...
Metric: CompileTime
TotalSamples: 3
Accumulator: 02m57s694ms418.595us
ValueRate: 02s112ms586.097us / second
Rate: 0.054285 / second
Percentiles: 1%=023ms113.470us; 5%=023ms113.470us; 10%=023ms113.470us; 20%=023ms113.470us; 50%=54s644ms733.284us; 80%=01m03s028ms571.841us; 90%=01m03s028ms571.841us; 95%=01m03s028ms571.841us;
99%=01m03s028ms571.841us
运行
scan_layers
实现
❯ python3 examples/train_decoder_only_base.py \
scan.decoder_with_scan.DecoderWithScan \
--hidden-size 256 \
--num-layers 50 \
--num-attention-heads 4 \
--num-key-value-heads 2 \
--intermediate-size 2048 \
--num-steps 5 \
--print-metrics
...
Metric: CompileTime
TotalSamples: 3
Accumulator: 29s996ms941.409us
ValueRate: 02s529ms591.388us / second
Rate: 0.158152 / second
Percentiles: 1%=018ms636.571us; 5%=018ms636.571us; 10%=018ms636.571us; 20%=018ms636.571us; 50%=11s983ms003.171us; 80%=18s995ms301.667us; 90%=18s995ms301.667us; 95%=18s995ms301.667us;
99%=18s995ms301.667us
我们可以看到,通过切换到 scan_layers
,最大编译时间从 1m03s
降至 19s
。
参考资料¶
有关 scan
和 scan_layers
本身的设计,请参阅 https://github.com/pytorch/xla/issues/7253。
有关如何使用它们的详细信息,请参阅 ``scan` <https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan.py>`_ 和 ``scan_layers` <https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan_layers.py>`_ 的函数文档注释。