• 文件 >
  • TensorDict 在分散式設定中
快捷方式

TensorDict 在分散式設定中

TensorDict 可用於分散式設定,以便在不同節點之間傳遞張量。如果兩個節點可以訪問共享物理儲存,則可以使用記憶體對映張量高效地在正在執行的不同程序之間傳遞資料。本文件提供了一些關於如何在分散式 RPC 環境中實現這一目標的詳細資訊。有關分散式 RPC 的更多詳細資訊,請查閱 官方 PyTorch 文件

建立記憶體對映 TensorDict

記憶體對映張量(和陣列)的一大優勢在於它們可以儲存大量資料,並允許快速訪問資料切片,而無需將整個檔案讀入記憶體。TensorDict 在記憶體對映陣列和 torch.Tensor 類之間提供了介面,該介面名為 MemmapTensorMemmapTensor 例項可以儲存在 TensorDict 物件中,從而允許 tensordict 表示儲存在磁碟上的大型資料集,並且可以輕鬆地跨節點進行批處理訪問。

建立記憶體對映 tensordict 非常簡單,可以透過 (1) 用記憶體對映張量填充 TensorDict 或 (2) 呼叫 tensordict.memmap_() 將其放置到物理儲存上。可以透過查詢 tensordict.is_memmap() 輕鬆檢查 tensordict 是否已放置在物理儲存上。

建立記憶體對映張量本身可以透過多種方式完成。首先,可以簡單地建立一個空張量

>>> shape = torch.Size([3, 4, 5])
>>> tensor = Memmaptensor(*shape, prefix="/tmp")
>>> tensor[:2] = torch.randn(2, 4, 5)

prefix 屬性指示臨時檔案應儲存的位置。至關重要的是,張量必須儲存在每個節點都可以訪問的目錄中!

另一種選擇是表示磁碟上的現有張量

>>> tensor = torch.randn(3)
>>> tensor = Memmaptensor(tensor, prefix="/tmp")

前一種方法將在張量很大或不適合記憶體時更受歡迎:它適用於非常大的張量,並作為跨節點的通用儲存。例如,可以建立一個數據集,該資料集可以由單個或不同節點輕鬆訪問,比每個檔案都必須獨立載入到記憶體中要快得多

在磁碟上建立空資料集
>>> dataset = TensorDict({
...      "images": MemmapTensor(50000, 480, 480, 3),
...      "masks": MemmapTensor(50000, 480, 480, 3, dtype=torch.bool),
...      "labels": MemmapTensor(50000, 1, dtype=torch.uint8),
... }, batch_size=[50000], device="cpu")
>>> idx = [1, 5020, 34572, 11200]
>>> batch = dataset[idx].clone()
TensorDict(
    fields={
        images: Tensor(torch.Size([4, 480, 480, 3]), dtype=torch.float32),
        labels: Tensor(torch.Size([4, 1]), dtype=torch.uint8),
        masks: Tensor(torch.Size([4, 480, 480, 3]), dtype=torch.bool)},
    batch_size=torch.Size([4]),
    device=cpu,
    is_shared=False)

請注意,我們已經指明瞭 MemmapTensor 的裝置。這種語法糖允許在需要時直接將查詢到的張量載入到裝置上。

需要考慮的另一個問題是,目前 MemmapTensor 不相容 autograd 操作。

跨節點操作記憶體對映張量

我們提供了一個簡單的分散式指令碼示例,其中一個程序建立一個記憶體對映張量,並將其引用傳送給另一個負責更新它的工作程序。您可以在 benchmark 目錄中找到此示例。

簡而言之,我們的目標是展示當節點可以訪問共享物理儲存時,如何處理大型張量的讀寫操作。步驟包括

  • 在磁碟上建立空張量;

  • 設定要執行的本地和遠端操作;

  • 使用 RPC 在工作程序之間傳遞命令以讀寫共享資料。

這個示例首先編寫一個函式,該函式使用一個填充了 1 的張量在特定索引處更新 TensorDict 例項

>>> def fill_tensordict(tensordict, idx):
...     tensordict[idx] = TensorDict(
...         {"memmap": torch.ones(5, 640, 640, 3, dtype=torch.uint8)}, [5]
...     )
...     return tensordict
>>> fill_tensordict_cp = CloudpickleWrapper(fill_tensordict)

CloudpickleWrapper 確保函式是可序列化的。接下來,我們建立一個相當大尺寸的 tensordict,以說明如果必須透過常規 tensorpipe 傳遞,這將很難在工作程序之間傳遞。

>>> tensordict = TensorDict(
...     {"memmap": MemmapTensor(1000, 640, 640, 3, dtype=torch.uint8, prefix="/tmp/")}, [1000]
... )

最後,仍在主節點上,我們在遠端節點上呼叫該函式,然後檢查資料是否已寫入所需位置

>>> idx = [4, 5, 6, 7, 998]
>>> t0 = time.time()
>>> out = rpc.rpc_sync(
...     worker_info,
...     fill_tensordict_cp,
...     args=(tensordict, idx),
... )
>>> print("time elapsed:", time.time() - t0)
>>> print("check all ones", out["memmap"][idx, :1, :1, :1].clone())

儘管呼叫 rpc.rpc_sync 涉及傳遞整個 tensordict,更新此物件的特定索引並將其返回給原始工作程序,但此程式碼段的執行速度非常快(如果事先已傳遞記憶體位置的引用,速度更快,請參閱 torchrl 的分散式重放緩衝區文件瞭解更多資訊)。

該指令碼包含本文件目的之外的其他 RPC 配置步驟。

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發人員的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源