快捷方式

CudaGraphModule

tensordict.nn.CudaGraphModule(module: Callable[[Union[List[Tensor], TensorDictBase]], None], warmup: int = 2, in_keys: Optional[List[NestedKey]] = None, out_keys: Optional[List[NestedKey]] = None)

一個用於 PyTorch 可呼叫物件的 cudagraph 包裝器。

CudaGraphModule 是一個包裝類,它為 PyTorch 可呼叫物件提供了用於 CUDA graphs 的使用者友好介面。

警告

CudaGraphModule 是一個原型功能,其 API 限制未來可能會發生變化。

此類提供了使用者友好的 cudagraphs 介面,允許在 GPU 上快速執行操作,且沒有 CPU 開銷。它對函式的輸入進行基本檢查,並提供類似於 nn.Module 的 API 來執行

警告

此模組要求被包裝的函式滿足一些要求。使用者有責任確保所有這些要求都得到滿足。

  • 該函式不能有動態控制流。例如,以下程式碼片段將無法被包裝在 CudaGraphModule

    >>> def func(x):
    ...     if x.norm() > 1:
    ...         return x + 1
    ...     else:
    ...         return x - 1
    

    幸運的是,PyTorch 在大多數情況下都提供瞭解決方案

    >>> def func(x):
    ...     return torch.where(x.norm() > 1, x + 1, x - 1)
    
  • 該函式必須執行可以使用相同緩衝區精確重新執行的程式碼。這意味著不支援動態形狀(輸入或程式碼執行期間形狀發生變化)。換句話說,輸入必須具有常量形狀。

  • 函式的輸出必須是 detached(分離)的。如果需要呼叫最佳化器,請將其放在輸入函式中。例如,以下函式是一個有效的運算子

    >>> def func(x, y):
    ...     optim.zero_grad()
    ...     loss_val = loss_fn(x, y)
    ...     loss_val.backward()
    ...     optim.step()
    ...     return loss_val.detach()
    
  • 輸入不應該可微分。如果您需要使用 nn.Parameters(或通常是可微分的張量),只需編寫一個函式,將它們用作全域性值,而不是作為輸入傳遞

    >>> x = nn.Parameter(torch.randn(()))
    >>> optim = Adam([x], lr=1)
    >>> def func(): # right
    ...     optim.zero_grad()
    ...     (x+1).backward()
    ...     optim.step()
    >>> def func(x): # wrong
    ...     optim.zero_grad()
    ...     (x+1).backward()
    ...     optim.step()
    
  • 作為張量或 tensordict 的 args 和 kwargs 可以改變(前提是裝置和形狀匹配),但非張量的 args 和 kwargs 不應改變。例如,如果函式接收一個字串輸入,並且輸入在任何時候發生變化,模組將靜默地使用捕獲 cudagraph 時使用的字串執行程式碼。唯一支援的關鍵字引數是輸入為 tensordict 時的 tensordict_out

  • 如果模組是 TensorDictModuleBase 例項,並且輸出 id 與輸入 id 匹配,則在呼叫 CudaGraphModule 期間會保留此 identity。在所有其他情況下,無論其元素是否與其中一個輸入匹配,輸出都將被克隆。

警告

CudaGraphModule 的設計不是 Module,以避免收集輸入模組的引數並將其傳遞給最佳化器。

引數
  • module (可呼叫物件) – 一個函式,接收張量(或 tensordict)作為輸入,並輸出一個 PyTreeable 的張量集合。如果提供了 tensordict,模組也可以使用關鍵字引數執行(參見下面的示例)。

  • warmup (int, 可選) – 模組被編譯時的 warmup 步驟數(編譯後的模組在被 cudagraphs 捕獲之前應該執行幾次)。所有模組的預設值為 2

  • in_keys (NestedKeys 列表) –

    輸入鍵,如果模組以 TensorDict 作為輸入。如果 module.in_keys 存在,則預設為該值,否則為 None

    注意

    如果提供了 in_keys 但為空,則假定模組接收 tensordict 作為輸入。這足以讓 CudaGraphModule 知道該函式應被視為 TensorDictModule,但關鍵字引數將不會被分發。請參見下面的示例。

  • out_keys (NestedKeys 列表) – 輸出鍵,如果模組以 TensorDict 作為輸入並輸出 TensorDict。如果 module.out_keys 存在,則預設為該值,否則為 None

示例

>>> # Wrap a simple function
>>> def func(x):
...     return x + 1
>>> func = CudaGraphModule(func)
>>> x = torch.rand((), device='cuda')
>>> out = func(x)
>>> assert isinstance(out, torch.Tensor)
>>> assert out == x+1
>>> # Wrap a tensordict module
>>> func = TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["y"])
>>> func = CudaGraphModule(func)
>>> # This can be called either with a TensorDict or regular keyword arguments alike
>>> y = func(x=x)
>>> td = TensorDict(x=x)
>>> td = func(td)

文件

訪問 PyTorch 的完整開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源