get_dtype¶
- torchtune.training.get_dtype(dtype: Optional[str] = None, device: Optional[device] = None) dtype [源代码]¶
获取与给定精度字符串对应的 torch.dtype。如果未传递字符串,我们将默认使用 torch.float32。
注意
如果请求使用 CUDA 设备的 bf16 精度,我们将验证设备是否确实支持 bf16 内核。如果不支持,则会引发
RuntimeError
。- 参数:
dtype (Optional[str]) – 精度 dtype。默认值:
None
,在这种情况下,我们默认使用 torch.float32device (Optional[torch.device]) – 训练中使用的设备。仅支持 CUDA 和 CPU 设备。如果传入 CUDA 设备,则会进行额外的检查以确保设备支持所请求的精度。默认值:
None
,在这种情况下,假定为 CUDA 设备。
- 引发:
ValueError – 如果库不支持该精度
RuntimeError – 如果请求 bf16 精度,但此硬件上不可用。
- 返回:
对应的 torch.dtype。
- 返回类型: