快捷方式

graph

class torch.cuda.graph(cuda_graph, pool=None, stream=None, capture_error_mode='global')[source][source]

一個上下文管理器,用於將 CUDA 工作捕獲到 torch.cuda.CUDAGraph 物件中,以便稍後重放。

有關一般的介紹、詳細用法和限制,請參閱 CUDA Graphs

引數
  • cuda_graph (torch.cuda.CUDAGraph) – 用於捕獲的圖物件。

  • pool (可選) – 不透明令牌(由呼叫 graph_pool_handle()other_Graph_instance.pool() 返回),指示此圖的捕獲可以共享指定記憶體池的記憶體。請參閱 圖記憶體管理

  • stream (torch.cuda.Stream, 可選) – 如果提供,將被設定為上下文中的當前流。如果未提供,則 graph 會將其自身的內部輔助流設定為上下文中的當前流。

  • capture_error_mode (str, 可選) – 指定圖捕獲流的 cudaStreamCaptureMode。可以是 “global”、“thread_local” 或 “relaxed”。在 CUDA 圖捕獲期間,某些操作(例如 cudaMalloc)可能不安全。“global” 會對其他執行緒中的操作報錯,“thread_local” 只會對當前執行緒中的操作報錯,“relaxed” 則不會對操作報錯。除非您熟悉 cudaStreamCaptureMode,否則請勿更改此設定。

注意

為了有效的記憶體共享,如果您傳入了先前捕獲使用的 pool,並且先前的捕獲使用了顯式的 stream 引數,則您應該將相同的 stream 引數傳入本次捕獲。

警告

此 API 處於 Beta 階段,在未來版本中可能會發生變化。

文件

查閱 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源