快捷方式

get_dtype

torchtune.training.get_dtype(dtype: Optional[str] = None, device: Optional[device] = None) dtype[源代码]

获取与给定精度字符串对应的 torch.dtype。如果未传入字符串,则默认为 torch.float32。

注意

如果使用 CUDA 设备请求 bf16 精度,我们将验证该设备是否确实支持 bf16 内核。如果不是,则会引发 RuntimeError

参数:
  • dtype (Optional[str]) – 精度 dtype。默认:None,在这种情况下我们默认为 torch.float32

  • device (Optional[torch.device]) – 用于训练的设备。仅支持 CUDA 和 CPU 设备。如果传入 CUDA 设备,则会进行额外的检查以确保该设备支持请求的精度。默认:None,在这种情况下,假设为 CUDA 设备。

引发:
  • ValueError – 如果库不支持精度

  • RuntimeError – 如果请求 bf16 精度,但此硬件上不可用。

返回值:

相应的 torch.dtype。

返回类型:

torch.dtype

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源