graph¶
- class torch.cuda.graph(cuda_graph, pool=None, stream=None, capture_error_mode='global')[source][source]¶
一個上下文管理器,用於將 CUDA 工作捕獲到
torch.cuda.CUDAGraph物件中,以便稍後重放。有關一般的介紹、詳細用法和限制,請參閱 CUDA Graphs。
- 引數
cuda_graph (torch.cuda.CUDAGraph) – 用於捕獲的圖物件。
pool (可選) – 不透明令牌(由呼叫
graph_pool_handle()或other_Graph_instance.pool()返回),指示此圖的捕獲可以共享指定記憶體池的記憶體。請參閱 圖記憶體管理。stream (torch.cuda.Stream, 可選) – 如果提供,將被設定為上下文中的當前流。如果未提供,則
graph會將其自身的內部輔助流設定為上下文中的當前流。capture_error_mode (str, 可選) – 指定圖捕獲流的 cudaStreamCaptureMode。可以是 “global”、“thread_local” 或 “relaxed”。在 CUDA 圖捕獲期間,某些操作(例如 cudaMalloc)可能不安全。“global” 會對其他執行緒中的操作報錯,“thread_local” 只會對當前執行緒中的操作報錯,“relaxed” 則不會對操作報錯。除非您熟悉 cudaStreamCaptureMode,否則請勿更改此設定。
注意
為了有效的記憶體共享,如果您傳入了先前捕獲使用的
pool,並且先前的捕獲使用了顯式的stream引數,則您應該將相同的stream引數傳入本次捕獲。警告
此 API 處於 Beta 階段,在未來版本中可能會發生變化。