set_default_dtype¶
- torchtune.training.set_default_dtype(dtype: dtype) Generator[None, None, None] [源代码]¶
用于设置 torch 默认数据类型的上下文管理器。
- 参数:
dtype (torch.dpython:type) – 上下文管理器中所需的默认数据类型。
- 返回值:
用于设置默认数据类型的上下文管理器。
- 返回类型:
ContextManager
示例
>>> with set_default_dtype(torch.bfloat16): >>> x = torch.tensor([1, 2, 3]) >>> x.dtype torch.bfloat16