歡迎來到 TensorDict 文件!¶
TensorDict 是一個類似字典的類,繼承了張量的屬性,例如索引、形狀操作、轉換到裝置等。
你可以直接從 PyPI 安裝 tensordict(更多安裝說明請參閱下面的專門章節)
$ pip install tensordict
TensorDict 的主要目的是透過抽象定製化操作,使程式碼庫更具可讀性和模組化。
>>> for i, tensordict in enumerate(dataset):
... # the model reads and writes tensordicts
... tensordict = model(tensordict)
... loss = loss_module(tensordict)
... loss.backward()
... optimizer.step()
... optimizer.zero_grad()
透過這種抽象級別,可以針對高度異構的任務複用訓練迴圈。訓練迴圈的每個獨立步驟(資料收集和轉換、模型預測、損失計算等)都可以根據當前用例進行定製,而不會影響其他步驟。例如,上述示例可以輕鬆地用於分類和分割任務,以及許多其他任務。
安裝¶
Tensordict 的釋出與 PyTorch 同步,因此請確保始終使用最新版本的 PyTorch 來體驗庫的最新功能(儘管核心功能保證向後相容 pytorch>=1.13)。每夜構建版本可以透過以下方式安裝
$ pip install tensordict-nightly
或者如果你願意為庫做貢獻,也可以透過 git clone 獲取
$ cd path/to/root
$ git clone https://github.com/pytorch/tensordict
$ cd tensordict
$ python setup.py develop