get_dtype¶
- torchtune.training.get_dtype(dtype: Optional[str] = None, device: Optional[device] = None) dtype [源代码]¶
获取对应于给定精度字符串的 torch.dtype。如果没有传递字符串,默认为 torch.float32。
注意
如果使用 CUDA 设备请求 bf16 精度,我们将验证该设备是否确实支持 bf16 内核。如果不支持,则会引发
RuntimeError
。- 参数:
dtype (Optional[str]) – 精度 dtype。默认值:
None
,此时默认为 torch.float32device (Optional[torch.device]) – 训练中使用的设备。仅支持 CUDA 和 CPU 设备。如果传入 CUDA 设备,会进行额外检查以确保设备支持请求的精度。默认值:
None
,此时假定为 CUDA 设备。
- 抛出:
ValueError – 如果库不支持该精度
RuntimeError – 如果请求了 bf16 精度但在当前硬件上不可用。
- 返回:
对应的 torch.dtype。
- 返回类型: