get_dtype¶
- torchtune.training.get_dtype(dtype: Optional[str] = None, device: Optional[device] = None) dtype [源代码]¶
获取与给定精度字符串对应的 torch.dtype。如果未传入字符串,则默认为 torch.float32。
注意
如果使用 CUDA 设备请求 bf16 精度,我们将验证该设备是否确实支持 bf16 内核。如果不是,则会引发
RuntimeError
。- 参数:
dtype (Optional[str]) – 精度 dtype。默认:
None
,在这种情况下我们默认为 torch.float32device (Optional[torch.device]) – 用于训练的设备。仅支持 CUDA 和 CPU 设备。如果传入 CUDA 设备,则会进行额外的检查以确保该设备支持请求的精度。默认:
None
,在这种情况下,假设为 CUDA 设备。
- 引发:
ValueError – 如果库不支持精度
RuntimeError – 如果请求 bf16 精度,但此硬件上不可用。
- 返回值:
相应的 torch.dtype。
- 返回类型: