快捷方式

get_dtype

torchtune.training.get_dtype(dtype: Optional[str] = None, device: Optional[device] = None) dtype[源代码]

获取对应于给定精度字符串的 torch.dtype。如果没有传递字符串,默认为 torch.float32。

注意

如果使用 CUDA 设备请求 bf16 精度,我们将验证该设备是否确实支持 bf16 内核。如果不支持,则会引发 RuntimeError

参数:
  • dtype (Optional[str]) – 精度 dtype。默认值:None,此时默认为 torch.float32

  • device (Optional[torch.device]) – 训练中使用的设备。仅支持 CUDA 和 CPU 设备。如果传入 CUDA 设备,会进行额外检查以确保设备支持请求的精度。默认值:None,此时假定为 CUDA 设备。

抛出:
  • ValueError – 如果库不支持该精度

  • RuntimeError – 如果请求了 bf16 精度但在当前硬件上不可用。

返回:

对应的 torch.dtype。

返回类型:

torch.dtype

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得问题解答

查看资源