快捷方式

tensordict 包

The TensorDict class simplifies the process of passing multiple tensors from module to module by packing them in a dictionary-like object that inherits features from regular pytorch tensors.

TensorDictBase()

TensorDictBase 是 TensorDict 的抽象父類,TensorDict 是一種 torch.Tensor 資料容器。

TensorDict([source, batch_size, device, ...])

張量的批次字典。

LazyStackedTensorDict(*tensordicts[, ...])

TensorDict 的惰性堆疊。

PersistentTensorDict(*[, batch_size, ...])

持久化 TensorDict 實現。

TensorDictParams([parameters, no_convert, lock])

帶有引數暴露功能的 TensorDictBase 包裝器。

get_defaults_to_none([set_to_none])

返回 get 預設值的狀態。

建構函式和處理程式

該庫提供了一些方法來與 numpy 結構化陣列、namedtuple 或 h5 檔案等其他資料結構進行互動。該庫還提供了專門的函式來操作 tensordict,例如 saveloadstackcat

cat(input[, dim, out])

沿著給定維度將 tensordict 連線成單個 tensordict。

default_is_leaf(cls)

如果一個型別不是張量集合(tensordict 或 tensorclass),則返回 True

from_any(obj, *[, auto_batch_size, ...])

將任意物件轉換為 TensorDict。

from_consolidated(filename)

從合併檔案中重建 tensordict。

from_dict(d, *[, auto_batch_size, ...])

將字典轉換為 TensorDict。

from_h5(h5_file, *[, auto_batch_size, ...])

將 HDF5 檔案轉換為 TensorDict。

from_module(module[, as_module, lock, ...])

將模組的引數和緩衝區複製到 tensordict 中。

from_modules(*modules[, as_module, lock, ...])

透過 vmap 獲取多個模組的引數,用於整合學習/期望應用的特性。

from_namedtuple(named_tuple, *[, ...])

將 namedtuple 轉換為 TensorDict。

from_pytree(pytree, *[, batch_size, ...])

將 pytree 轉換為 TensorDict 例項。

from_struct_array(struct_array, *[, ...])

將結構化 numpy 陣列轉換為 TensorDict。

from_tuple(obj, *[, auto_batch_size, ...])

將元組轉換為 TensorDict。

fromkeys(keys[, value])

從鍵列表和單個值建立 tensordict。

is_batchedtensor(arg0)

is_leaf_nontensor(cls)

如果一個型別不是張量集合(tensordict 或 tensorclass)或不是張量,則返回 True

lazy_stack(input[, dim, out])

建立 tensordict 的惰性堆疊。

load(prefix[, device, non_blocking, out])

從磁碟載入 tensordict。

load_memmap(prefix[, device, non_blocking, out])

從磁碟載入記憶體對映的 tensordict。

maybe_dense_stack(input[, dim, out])

嘗試對 tensordict 進行密集堆疊,並在需要時回退到惰性堆疊。

memmap(data[, prefix, copy_existing, ...])

將所有張量寫入新 tensordict 中的相應記憶體對映張量。

save(data[, prefix, copy_existing, ...])

將 tensordict 儲存到磁碟。

stack(input[, dim, out])

沿著給定維度將 tensordict 堆疊成單個 tensordict。

將 TensorDict 用作上下文管理器

TensorDict 可以在需要執行某個操作然後撤銷該操作的情況下用作上下文管理器。這包括臨時鎖定/解鎖 tensordict

>>> data.lock_()  # data.set will result in an exception
>>> with data.unlock_():
...     data.set("key", value)
>>> assert data.is_locked()

或使用包含模型引數和緩衝區的 TensorDict 例項執行函式呼叫

>>> params = TensorDict.from_module(module).clone()
>>> params.zero_()
>>> with params.to_module(module):
...     y = module(x)

在第一個示例中,我們可以修改 tensordict data,因為我們臨時解鎖了它。在第二個示例中,我們使用 params tensordict 例項中包含的引數和緩衝區填充模組,並在呼叫完成後重置原始引數。

記憶體對映張量

tensordict 提供了 MemoryMappedTensor 原語,它允許您方便地處理儲存在物理記憶體中的張量。MemoryMappedTensor 的主要優點包括易於構建(無需處理張量的儲存)、處理不適合記憶體的大塊連續資料的能力、跨程序的高效序列化/反序列化以及儲存張量的高效索引。

如果所有工作程序都可以訪問相同的儲存(多程序和分散式設定中均如此),傳遞 MemoryMappedTensor 僅需傳遞磁碟上檔案的引用以及用於重建它的一堆額外元資料。只要索引記憶體對映張量的儲存資料指標與原始資料指標相同,情況也是如此。

索引記憶體對映張量比從磁碟載入多個獨立檔案快得多,並且不需要將整個陣列內容載入到記憶體中。但是,PyTorch 張量的物理儲存應該沒有區別

>>> my_images = MemoryMappedTensor.empty((1_000_000, 3, 480, 480), dtype=torch.unint8)
>>> mini_batch = my_images[:10]  # just reads the first 10 images of the dataset

MemoryMappedTensor(source, *[, dtype, ...])

記憶體對映張量。

逐點操作

Tensordict 支援各種逐點操作,允許您對其內部儲存的張量執行元素級計算。這些操作與常規 PyTorch 張量上的操作類似。

支援的操作

目前支援以下逐點操作

  • 左加和右加 (+)

  • 左減和右減 (-)

  • 左乘和右乘 (*)

  • 左除和右除 (/)

  • 左乘方 (**)

還支援許多其他操作,例如 clamp()sqrt() 等。

執行逐點操作

您可以在兩個 Tensordict 之間或在 Tensordict 與張量/標量值之間執行逐點操作。

示例 1:Tensordict-Tensordict 操作

>>> import torch
>>> from tensordict import TensorDict
>>> td1 = TensorDict(
...     a=torch.randn(3, 4),
...     b=torch.zeros(3, 4, 5),
...     c=torch.ones(3, 4, 5, 6),
...     batch_size=(3, 4),
... )
>>> td2 = TensorDict(
...     a=torch.randn(3, 4),
...     b=torch.zeros(3, 4, 5),
...     c=torch.ones(3, 4, 5, 6),
...     batch_size=(3, 4),
... )
>>> result = td1 * td2

在此示例中,* 運算子被逐元素地應用於 td1 和 td2 中對應的張量。

示例 2:Tensordict-張量操作

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict(
...     a=torch.randn(3, 4),
...     b=torch.zeros(3, 4, 5),
...     c=torch.ones(3, 4, 5, 6),
...     batch_size=(3, 4),
... )
>>> tensor = torch.randn(4)
>>> result = td * tensor

在這裡,* 運算子被逐元素地應用於 td 中的每個張量和提供的張量。該張量會被廣播以匹配 Tensordict 中每個張量的形狀。

示例 3:Tensordict-標量操作

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict(
...     a=torch.randn(3, 4),
...     b=torch.zeros(3, 4, 5),
...     c=torch.ones(3, 4, 5, 6),
...     batch_size=(3, 4),
... )
>>> scalar = 2.0
>>> result = td * scalar

在這種情況下,* 運算子被逐元素地應用於 td 中的每個張量和提供的標量。

廣播規則

當在 Tensordict 和張量/標量之間執行逐點操作時,張量/標量會被廣播以匹配 Tensordict 中每個張量的形狀:張量在左側被廣播以匹配 tensordict 的形狀,然後在右側單獨廣播以匹配張量的形狀。如果將 TensorDict 視為單個張量例項,這遵循 PyTorch 中使用的標準廣播規則。

例如,如果您有一個包含形狀為 (3, 4) 的張量的 Tensordict,並將其乘以形狀為 (4,) 的張量,該張量在應用操作之前將被廣播為形狀 (3, 4)。如果 tensordict 包含一個形狀為 (3, 4, 5) 的張量,用於乘法的張量在該乘法中將在右側廣播為 (3, 4, 5)

如果在多個 tensordict 之間執行逐點操作且它們的批處理大小不同,它們將被廣播到公共形狀。

逐點操作的效率

如果可能,將使用 torch._foreach_<op> 融合核函式來加速逐點操作的計算。

處理缺失條目

當在兩個 Tensordict 之間執行逐點操作時,它們必須具有相同的鍵。某些操作,如 add(),具有一個 default 關鍵字引數,可用於處理具有獨佔條目的 tensordict。如果 default=None(預設值),則兩個 Tensordict 必須具有完全匹配的鍵集。如果 default="intersection",則僅考慮相交的鍵集,而忽略其他鍵。在所有其他情況下,default 將用於操作兩側所有缺失的條目。

工具函式

utils.expand_as_right(tensor, dest)

在右側擴充套件張量以匹配另一個張量的形狀。

utils.expand_right(tensor, shape)

在右側擴充套件張量以匹配所需的形狀。

utils.isin(input, reference, key[, dim])

測試輸入中 dim 維度上 key 的每個元素是否存在於參考中。

utils.remove_duplicates(input, key[, dim, ...])

移除指定維度上 key 中重複的索引。

is_batchedtensor(arg0)

is_tensor_collection(datatype)

檢查資料物件或型別是否是來自 tensordict 庫的張量容器。

make_tensordict([input_dict, batch_size, ...])

返回從關鍵字引數或輸入字典建立的 TensorDict。

merge_tensordicts(*tensordicts[, callback_exist])

合併 tensordict。

pad(tensordict, pad_size[, value])

使用常量值沿批次維度填充 tensordict 中的所有張量,並返回一個新的 tensordict。

pad_sequence(list_of_tensordicts[, pad_dim, ...])

填充 tensordict 列表,以便將它們以連續格式堆疊在一起。

dense_stack_tds(td_list[, dim])

密集堆疊具有相同結構的 TensorDictBase 物件列表(或 LazyStackedTensorDict)。

set_lazy_legacy(mode)

將某些方法的行為設定為惰性轉換。

lazy_legacy([allow_none])

如果對選定方法使用惰性表示,則返回 True

parse_tensor_dict_string(s)

將 TensorDict repr 解析為 TensorDict。

文件

訪問 PyTorch 的完整開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得解答

檢視資源