快捷方式

torch.cuda.make_graphed_callables

torch.cuda.make_graphed_callables(callables, sample_args, num_warmup_iters=3, allow_unused_input=False, pool=None)[源代码]

接受可调用对象(函数或 nn.Module)并返回其图形化版本。

每个图形化可调用对象的正向传递都在单个 autograd 节点内将其源可调用对象的正向 CUDA 工作作为 CUDA 图运行。

图形化可调用对象的正向传递还会将反向节点附加到 autograd 图中。在反向传播期间,此节点将可调用对象的反向工作作为 CUDA 图运行。

因此,每个图形化可调用对象都应该是其源可调用对象在启用 autograd 的训练循环中的直接替换。

有关详细的使用方法和约束条件,请参阅 部分网络捕获

如果传递多个可调用对象的元组,则它们的捕获将使用相同的内存池。有关何时适合这样做,请参阅 图形内存管理

参数
  • callables (torch.nn.ModulePython 函数元组 其中包含这些) – 要进行图形化的可调用对象或可调用对象。有关何时适合传递可调用对象的元组,请参阅 图形内存管理。如果传递可调用对象的元组,则元组中可调用对象的顺序必须与它们在实际工作负载中运行的顺序相同。

  • sample_args (元组 其中包含张量元组 其中包含张量的元组) – 每个可调用对象的示例参数。如果传递单个可调用对象,则 sample_args 必须是参数张量的单个元组。如果传递可调用对象的元组,则 sample_args 必须是参数张量元组的元组。

  • num_warmup_iters (int) – 预热迭代次数。目前,DataDistributedParallel 需要 11 次迭代进行预热。默认值:3

  • allow_unused_input (bool) – 如果为 False,则指定在计算输出时未使用(因此其梯度始终为零)的输入将导致错误。默认为 False。

  • pool (可选) – 令牌(由 graph_pool_handle()other_Graph_instance.pool() 返回),提示此图形可能与指示的池共享内存。有关何时适合这样做,请参阅 图形内存管理

注意

sample_args 中每个张量的 requires_grad 状态必须与训练循环中相应真实输入的预期状态匹配。

警告

此 API 处于测试阶段,可能会在将来的版本中发生更改。

警告

每个可调用对象的 sample_args 只能包含张量。不允许使用其他类型。

警告

返回的可调用对象不支持高阶微分(例如,双重反向传播)。

警告

在传递给 make_graphed_callables() 的任何 Module 中,只有参数可以是可训练的。缓冲区必须具有 requires_grad=False

警告

在通过 make_graphed_callables() 传递 torch.nn.Module 后,您不能添加或删除该模块的任何参数或缓冲区。

警告

传递给 torch.nn.Modulemake_graphed_callables() 的模块,在传递时不得在其上注册模块钩子。但是,在通过 make_graphed_callables() 后,允许在模块上注册钩子。

警告

运行图化可调用对象时,必须按其可调用对象的 sample_args 中出现的相同顺序和格式传递其参数。

警告

仅当禁用缓存时,make_graphed_callables() 中才支持自动混合精度。上下文管理器 torch.cuda.amp.autocast() 必须具有 cache_enabled=False

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源