CUDAGraph¶
- class torch.cuda.CUDAGraph[source][source]¶
CUDA 图的包装器。
警告
此 API 处于 Beta 阶段,将来版本中可能会发生变化。
- capture_begin(pool=None, capture_error_mode='global')[source][source]¶
开始在当前流上捕获 CUDA 工作。
通常情况下,您不应自行调用
capture_begin
。请使用graph
或make_graphed_callables()
,它们在内部调用capture_begin
。- 参数
pool (可选) – 令牌(由
torch.cuda.graph_pool_handle()
或other_Graph_instance.pool()
返回),提示此图可能与指定的池共享内存。参见图内存管理。capture_error_mode (str, 可选) – 指定图捕获流的 cudaStreamCaptureMode。可以是 “global”、“thread_local” 或 “relaxed”。在 CUDA 图捕获期间,某些操作(例如 cudaMalloc)可能不安全。“global” 会对其他线程中的操作报错,“thread_local” 只会针对当前线程中的操作报错,“relaxed” 不会针对这些操作报错。除非您熟悉 cudaStreamCaptureMode,否则请勿更改此设置。
- capture_end()[source][source]¶
结束在当前流上的 CUDA 图捕获。
在
capture_end
之后,可以在此实例上调用replay
。通常情况下,您不应自行调用
capture_end
。请使用graph
或make_graphed_callables()
,它们在内部调用capture_end
。