graph¶
- class torch.cuda.graph(cuda_graph, pool=None, stream=None, capture_error_mode='global')[源代码]¶
上下文管理器,将 CUDA 工作捕获到
torch.cuda.CUDAGraph
对象中以供稍后重放。有关一般介绍、详细使用和约束,请参阅 CUDA 图。
- 参数
cuda_graph (torch.cuda.CUDAGraph) – 用于捕获的图对象。
pool (可选) – 不透明令牌(由调用
graph_pool_handle()
或other_Graph_instance.pool()
返回),提示此图的捕获可能会与指定池共享内存。请参阅 图内存管理。stream (torch.cuda.Stream, 可选) – 如果提供,将被设置为上下文中的当前流。如果未提供,
graph
会将其自己的内部侧流设置为上下文中的当前流。capture_error_mode (str, 可选) – 指定图捕获流的 cudaStreamCaptureMode。可以是“global”、“thread_local”或“relaxed”。在 cuda 图捕获期间,某些操作(例如 cudaMalloc)可能是危险的。“global”将在其他线程中的操作上出错,“thread_local”只会针对当前线程中的操作出错,“relaxed”不会针对操作出错。除非您熟悉 cudaStreamCaptureMode,否则请勿更改此设置。
注意
为了有效地共享内存,如果您传递了先前捕获使用的
pool
,并且先前捕获使用了显式的stream
参数,您应该将相同的stream
参数传递给此捕获。警告
此 API 处于测试阶段,可能会在将来的版本中发生更改。