setup_torch_profiler¶
- torchtune.training.setup_torch_profiler(enabled: bool = False, cpu: bool = True, cuda: bool = True, xpu: bool = True, profile_memory: bool = False, with_stack: bool = False, record_shapes: bool = True, with_flops: bool = False, wait_steps: Optional[int] = None, warmup_steps: Optional[int] = None, active_steps: Optional[int] = None, num_cycles: Optional[int] = None, output_dir: Optional[str] = None) Tuple[profile, DictConfig] [源文件]¶
设置
profile
并返回包含设置后更新的性能分析器配置。性能分析器配置可以在配置中以
profiler
键的形式提供,其布局如下profiler: _component_: torchtune.training.setup_torch_profiler enabled: bool # Output directory of trace artifacts output_dir: str # torch.profiler.ProfilerActivity types to trace cpu: bool cuda: bool # Trace options profile_memory: bool with_stack: bool record_shapes: bool with_flops: bool # torch.profiler.schedule args wait_steps: int warmup_steps: int active_steps: int num_cycles: int
性能分析器计划根据优化器步骤进行更新(例如,如果
gradient_accumulation = 2
,则性能分析器每 2 个批次进行一次步骤更新)。如果配置缺少选项,将选择合理的默认值
如果未指定活动,性能分析器将默认为 CPU + CUDA
如果未指定计划,性能分析器将默认为
DEFAULT_SCHEDULE
某些选项(
with_stack
和record_shapes
)将根据其他选项的要求被覆盖(例如,profile_memory
需要with_stack
和record_shapes
)。
注意
启用性能分析器将导致训练速度降低。
设置
profile_memory: True
将生成大型跟踪文件。性能分析器计划是上下文相关的。在每次批次迭代时但在梯度累积范围之外调用
profiler.step()
将使性能分析器在每次前向/后向步骤进行step
。在每次批次迭代时但在梯度累积范围之内调用profiler.step()
将使性能分析器在每次优化器更新步骤进行step
,以便每个step
包含多个前向/后向传递。
- 参数:
enabled (bool) – 启用 pytorch 性能分析器。默认为 False。
cpu (bool) – 启用 CPU 性能分析。默认为 True。
cuda (bool) – 启用 CUDA 性能分析。默认为 True。
xpu (bool) – 启用 XPU 性能分析。默认为 True。
profile_memory (bool) – 分析内存使用。默认为 False。
with_stack (bool) – 分析堆栈。默认为 False。
record_shapes (bool) – 记录形状。默认为 True。
with_flops (bool) – 分析浮点运算 (flops)。默认为 False。
wait_steps (Optional[int]) – 等待步数。映射到
torch.profiler.schedule
的wait
参数。warmup_steps (Optional[int]) – 热身步数。映射到
torch.profiler.schedule
的warmup
参数。active_steps (Optional[int]) – 活动步数。映射到
torch.profiler.schedule
的active
参数。num_cycles (Optional[int]) – 性能分析循环次数。映射到
torch.profiler.schedule
的repeat
参数。output_dir (Optional[str]) – 跟踪文件输出路径。
- 返回:
Tuple[torch.profiler.profile, DictConfig]