set_default_dtype¶
- torchtune.training.set_default_dtype(dtype: dtype) Generator[None, None, None] [源代码]¶
设置 torch 默认 dtype 的上下文管理器。
- 参数:
dtype (torch.dtype) – 在上下文管理器中期望的默认 dtype。
- 返回:
用于设置默认 dtype 的上下文管理器。
- 返回类型:
ContextManager
示例
>>> with set_default_dtype(torch.bfloat16): >>> x = torch.tensor([1, 2, 3]) >>> x.dtype torch.bfloat16