• 文件 >
  • 使用 TensorDict 簡化 PyTorch 記憶體管理
快捷方式

使用 TensorDict 簡化 PyTorch 記憶體管理

作者: Tom Begley

在本教程中,您將學習如何控制 TensorDict 的內容在記憶體中的儲存位置,可以透過將這些內容傳送到裝置,或者利用記憶體對映。

裝置

建立 TensorDict 時,可以使用 device 關鍵字引數指定裝置。如果設定了 device,則 TensorDict 的所有條目都將放置在該裝置上。如果未設定 device,則 TensorDict 中的條目不必在同一裝置上。

在此示例中,我們使用 device="cuda:0" 例項化一個 TensorDict。當我們列印其內容時,可以看到它們已被移至裝置。

>>> import torch
>>> from tensordict import TensorDict
>>> tensordict = TensorDict({"a": torch.rand(10)}, [10], device="cuda:0")
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=cuda:0,
    is_shared=True)

如果 TensorDict 的裝置不是 None,新條目也會被移至該裝置。

>>> tensordict["b"] = torch.rand(10, 10)
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
        b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=cuda:0,
    is_shared=True)

您可以使用 device 屬性檢查 TensorDict 的當前裝置。

>>> print(tensordict.device)
cuda:0

TensorDict 的內容可以像 PyTorch 張量一樣傳送到裝置,使用 TensorDict.cuda()TensorDict.device(device),其中 device 是所需的裝置。

>>> tensordict.to(torch.device("cpu"))
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)
>>> tensordict.cuda()
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
        b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=cuda:0,
    is_shared=True)

TensorDict.device 方法要求將有效的裝置作為引數傳入。如果您想從 TensorDict 中移除裝置以允許包含不同裝置的數值,應使用 TensorDict.clear_device 方法。

>>> tensordict.clear_device()
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
        b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

記憶體對映張量

tensordict 提供了一個類 MemoryMappedTensor,它允許我們將張量的內容儲存在磁碟上,同時仍然支援快速索引和批次載入內容。請參閱 ImageNet 教程以瞭解實際示例。

要將 TensorDict 轉換為記憶體對映張量的集合,請使用 TensorDict.memmap_

tensordict = TensorDict({"a": torch.rand(10), "b": {"c": torch.rand(10)}}, [10])
tensordict.memmap_()

print(tensordict)
TensorDict(
    fields={
        a: MemoryMappedTensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: MemoryMappedTensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)

或者,可以使用 TensorDict.memmap_like 方法。這將建立一個具有相同結構的新 TensorDict,其值為 MemoryMappedTensor,但它不會將原始張量的內容複製到記憶體對映張量中。這允許您建立記憶體對映的 TensorDict,然後緩慢填充它,因此通常應優先於 memmap_ 使用此方法。

tensordict = TensorDict({"a": torch.rand(10), "b": {"c": torch.rand(10)}}, [10])
mm_tensordict = tensordict.memmap_like()

print(mm_tensordict["a"].contiguous())
MemoryMappedTensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

預設情況下,TensorDict 的內容將儲存到磁碟上的臨時位置,但是如果您想控制儲存位置,可以使用關鍵字引數 prefix="/path/to/root"

TensorDict 的內容儲存在一個模仿 TensorDict 本身結構的目錄結構中。張量的內容儲存在 NumPy 記憶體對映檔案中,元資料儲存在相關的 PyTorch 儲存檔案中。例如,上面的 TensorDict 會儲存如下:

├── a.memmap
├── a.meta.pt
├── b
│ ├── c.memmap
│ ├── c.meta.pt
│ └── meta.pt
└── meta.pt

指令碼總執行時間: (0 分 0.004 秒)

由 Sphinx-Gallery 生成的圖集

文件

訪問 PyTorch 的完整開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源