setup_torch_profiler¶
- torchtune.training.setup_torch_profiler(enabled: bool = False, cpu: bool = True, cuda: bool = True, profile_memory: bool = False, with_stack: bool = False, record_shapes: bool = True, with_flops: bool = False, wait_steps: Optional[int] = None, warmup_steps: Optional[int] = None, active_steps: Optional[int] = None, num_cycles: Optional[int] = None, output_dir: Optional[str] = None) Tuple[profile, DictConfig] [源代码]¶
设置
profile
并返回带有后设置更新的分析器配置。分析器配置可以在配置中的
profiler
键下提供,并具有以下布局profiler: _component_: torchtune.training.setup_torch_profiler enabled: bool # Output directory of trace artifacts output_dir: str # torch.profiler.ProfilerActivity types to trace cpu: bool cuda: bool # Trace options profile_memory: bool with_stack: bool record_shapes: bool with_flops: bool # torch.profiler.schedule args wait_steps: int warmup_steps: int active_steps: int num_cycles: int
分析器计划相对于优化器步骤进行更新(例如,如果
gradient_accumulation = 2
,则分析器将在每 2 个批次后进行更新)。如果配置中缺少选项,将选择合理的默认值。
如果未指定任何活动,分析器将默认为 CPU + CUDA。
如果未指定计划,分析器将默认为
DEFAULT_SCHEDULE
。某些选项将被覆盖(
with_stack
和record_shapes
),具体取决于其他选项的要求(例如,profile_memory
需要with_stack
和record_shapes
)。
注意
启用分析器会导致训练速度下降。
设置
profile_memory: True
将生成大型跟踪文件。分析器调度程序与上下文相关。在每个批次迭代中调用
profiler.step()
,但在梯度累积范围之外,将step
分析器每个正向/反向步骤。在每个批次迭代中调用profiler.step()
,但在梯度累积范围之内,将step
分析器每个优化器更新步骤,使得每个step
包含多个正向/反向传递。
- 参数:
enabled (bool) – 启用 PyTorch 分析器。默认值为 False。
cpu (bool) – 启用 CPU 分析。默认值为 True。
cuda (bool) – 启用 CUDA 分析。默认值为 True。
profile_memory (bool) – 分析内存使用情况。默认值为 False。
with_stack (bool) – 分析堆栈。默认值为 False。
record_shapes (bool) – 记录形状。默认值为 True。
with_flops (bool) – 分析 FLOPs。默认值为 False。
wait_steps (Optional[int]) – 等待时间(以步骤计)。映射到
wait
的torch.profiler.schedule
的 kwarg。warmup_steps (Optional[int]) – 预热时间(以步骤计)。映射到
warmup
的torch.profiler.schedule
的 kwarg。active_steps (Optional[int]) – 活动时间(以步骤计)。映射到
active
的torch.profiler.schedule
的 kwarg。num_cycles (Optional[int]) – 分析周期的数量。映射到
repeat
的torch.profiler.schedule
的 kwarg。output_dir (Optional[str]) – 跟踪文件输出路径。
- 返回值:
Tuple[torch.profiler.profile, DictConfig]