快捷方式

graph

class torch.cuda.graph(cuda_graph, pool=None, stream=None, capture_error_mode='global')[source][source]

一个上下文管理器,用于将 CUDA 工作捕获到 torch.cuda.CUDAGraph 对象中,以便稍后重放。

有关一般的介绍、详细用法和限制,请参阅 CUDA Graphs

参数
  • cuda_graph (torch.cuda.CUDAGraph) – 用于捕获的图对象。

  • pool (可选) – 不透明令牌(由调用 graph_pool_handle()other_Graph_instance.pool() 返回),指示此图的捕获可以共享指定内存池的内存。请参阅 图内存管理

  • stream (torch.cuda.Stream, 可选) – 如果提供,将被设置为上下文中的当前流。如果未提供,则 graph 会将其自身的内部辅助流设置为上下文中的当前流。

  • capture_error_mode (str, 可选) – 指定图捕获流的 cudaStreamCaptureMode。可以是 “global”、“thread_local” 或 “relaxed”。在 CUDA 图捕获期间,某些操作(例如 cudaMalloc)可能不安全。“global” 会对其他线程中的操作报错,“thread_local” 只会对当前线程中的操作报错,“relaxed” 则不会对操作报错。除非您熟悉 cudaStreamCaptureMode,否则请勿更改此设置。

注意

为了有效的内存共享,如果您传入了先前捕获使用的 pool,并且先前的捕获使用了显式的 stream 参数,则您应该将相同的 stream 参数传入本次捕获。

警告

此 API 处于 Beta 阶段,在未来版本中可能会发生变化。

文档

查阅 PyTorch 的全面开发者文档

查看文档

教程

获取针对初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源