get_memory_stats¶
- torchtune.training.get_memory_stats(device: device, reset_stats: bool = True) dict [source]¶
计算传入设备的内存摘要。如果
reset_stats
为True
,则还会重置 CUDA 的峰值内存跟踪。这对于获取有关峰值内存相对使用情况(例如,模型初始化期间、前向传播期间等的峰值内存)的数据,并优化训练各个部分的内存非常有用。- 参数:
device (torch.device) – 要获取内存摘要的设备。仅支持 CUDA 设备。
reset_stats (bool) – 是否重置 CUDA 的峰值内存跟踪。
- 返回:
一个字典,包含峰值内存活动量、峰值内存分配量和峰值内存预留量。此字典对于记录内存统计信息非常有用。
- 返回类型:
- 引发:
ValueError – 如果传入的设备不是 CUDA。