torch.jit.fork¶
- torch.jit.fork(func, *args, **kwargs)[source][source]¶
创建一个异步任务来执行 func,并返回对执行结果值的引用。
fork 会立即返回,因此 func 的返回值可能尚未计算完成。要强制完成任务并访问返回值,请在 Future 上调用 torch.jit.wait。使用返回 T 的 func 调用的 fork 类型为 torch.jit.Future[T]。fork 调用可以任意嵌套,并且可以使用位置参数和关键字参数调用。异步执行仅在 TorchScript 中运行时才会发生。如果在纯 Python 中运行,fork 不会并行执行。fork 在跟踪时调用也不会并行执行,但是 fork 和 wait 调用将捕获在导出的 IR 图中。
警告
fork 任务将以非确定性的方式执行。我们建议仅为不修改其输入、模块属性或全局状态的纯函数生成并行 fork 任务。
- 参数
func (可调用对象 或 torch.nn.Module) – 将被调用的 Python 函数或 torch.nn.Module。如果在 TorchScript 中执行,它将异步执行,否则不会。 fork 的跟踪调用将被捕获在 IR 中。
*args – 调用 func 时使用的参数。
**kwargs – 调用 func 时使用的参数。
- 返回
对 func 执行的引用。值 T 只能通过强制完成 func 并通过 torch.jit.wait 来访问。
- 返回类型
torch.jit.Future[T]
示例(fork 一个自由函数)
import torch from torch import Tensor def foo(a : Tensor, b : int) -> Tensor: return a + b def bar(a): fut : torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2) return torch.jit.wait(fut) script_bar = torch.jit.script(bar) input = torch.tensor(2) # only the scripted version executes asynchronously assert script_bar(input) == bar(input) # trace is not run asynchronously, but fork is captured in IR graph = torch.jit.trace(bar, (input,)).graph assert "fork" in str(graph)
示例(fork 一个模块方法)
import torch from torch import Tensor class AddMod(torch.nn.Module): def forward(self, a: Tensor, b : int): return a + b class Mod(torch.nn.Module): def __init__(self) -> None: super(self).__init__() self.mod = AddMod() def forward(self, input): fut = torch.jit.fork(self.mod, a, b=2) return torch.jit.wait(fut) input = torch.tensor(2) mod = Mod() assert mod(input) == torch.jit.script(mod).forward(input)