快捷方式

序列化語義

本說明描述了如何在 Python 中保存和加載 PyTorch 張量和模組狀態,以及如何序列化 Python 模組以便在 C++ 中加載它們。

保存和加載張量

torch.save()torch.load() 允許您輕鬆地保存和加載張量

>>> t = torch.tensor([1., 2.])
>>> torch.save(t, 'tensor.pt')
>>> torch.load('tensor.pt')
tensor([1., 2.])

按照慣例,PyTorch 文件通常使用“.pt”或“.pth”擴展名編寫。

torch.save()torch.load() 默認使用 Python 的 pickle,因此您也可以將多個張量保存為 Python 對象的一部分,例如元組、列表和字典

>>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
>>> torch.save(d, 'tensor_dict.pt')
>>> torch.load('tensor_dict.pt')
{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}

如果自訂數據結構是可 pickle 的,則也可以保存包含 PyTorch 張量的自訂數據結構。

保存和加載張量會保留視圖

保存張量會保留它們的視圖關係

>>> numbers = torch.arange(1, 10)
>>> evens = numbers[1::2]
>>> torch.save([numbers, evens], 'tensors.pt')
>>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
>>> loaded_evens *= 2
>>> loaded_numbers
tensor([ 1,  4,  3,  8,  5, 12,  7, 16,  9])

在幕後,這些張量共享相同的“存儲”。有關視圖和存儲的更多信息,請參見 張量視圖

當 PyTorch 保存張量時,它會分別保存它們的存儲對象和張量元數據。這是一個將來可能會改變的實現細節,但它通常可以節省空間,並讓 PyTorch 輕鬆地重建加載的張量之間的視圖關係。例如,在上面的代碼段中,只有一個存儲被寫入“tensors.pt”。

然而,在某些情況下,保存當前的存儲對象可能是不必要的,並且會創建過大的文件。在下面的代碼段中,一個比保存的張量大得多的存儲被寫入文件

>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small, 'small.pt')
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
999

沒有將 small 張量中的五個值保存到“small.pt”,而是保存並加載了它與 large 共享的存儲中的 999 個值。

當保存的張量元素少於其存儲對象時,可以通過首先克隆張量來減小保存的文件的大小。克隆張量會生成一個新的張量,其中包含一個新的存儲對象,該對象僅包含張量中的值

>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small.clone(), 'small.pt')  # saves a clone of small
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
5

然而,由於克隆的張量彼此獨立,因此它們不具有原始張量所具有的任何視圖關係。如果在保存小於其存儲對象的張量時,文件大小和視圖關係都很重要,那麼在保存之前,必須小心地構造新的張量,以最小化其存儲對象的大小,但仍然具有所需的視圖關係。

保存和加載 torch.nn.Modules

另見:教學:保存和加載模組

在 PyTorch 中,模組的狀態經常使用“狀態字典”進行序列化。模組的狀態字典包含其所有參數和持久緩衝區

>>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> list(bn.named_parameters())
[('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
 ('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]

>>> list(bn.named_buffers())
[('running_mean', tensor([0., 0., 0.])),
 ('running_var', tensor([1., 1., 1.])),
 ('num_batches_tracked', tensor(0))]

>>> bn.state_dict()
OrderedDict([('weight', tensor([1., 1., 1.])),
             ('bias', tensor([0., 0., 0.])),
             ('running_mean', tensor([0., 0., 0.])),
             ('running_var', tensor([1., 1., 1.])),
             ('num_batches_tracked', tensor(0))])

出於兼容性原因,建議不要直接保存模組,而是僅保存其狀態字典。Python 模組甚至有一個函數 load_state_dict(),用於從狀態字典恢復其狀態

>>> torch.save(bn.state_dict(), 'bn.pt')
>>> bn_state_dict = torch.load('bn.pt')
>>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> new_bn.load_state_dict(bn_state_dict)
<All keys matched successfully>

請注意,狀態字典首先使用 torch.load() 從其文件中加載,然後使用 load_state_dict() 恢復狀態。

即使是自訂模組和包含其他模組的模組也具有狀態字典,並且可以使用這種模式

# A module with two linear layers
>>> class MyModule(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

      def forward(self, input):
        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)

>>> m = MyModule()
>>> m.state_dict()
OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
                                   [-0.3289, 0.2827, 0.4588, 0.2031]])),
             ('l0.bias', tensor([ 0.0300, -0.1316])),
             ('l1.weight', tensor([[0.6533, 0.3413]])),
             ('l1.bias', tensor([-0.1112]))])

>>> torch.save(m.state_dict(), 'mymodule.pt')
>>> m_state_dict = torch.load('mymodule.pt')
>>> new_m = MyModule()
>>> new_m.load_state_dict(m_state_dict)
<All keys matched successfully>

`torch.save` 的序列化文件格式

從 PyTorch 1.6.0 開始,除非用戶設置 _use_new_zipfile_serialization=False,否則 torch.save 默認返回未壓縮的 ZIP64 存檔。

在此存檔中,文件的順序如下

checkpoint.pth
├── data.pkl
├── byteorder  # added in PyTorch 2.1.0
├── data/
│   ├── 0
│   ├── 1
│   ├── 2
│   └── …
└── version
條目如下
  • data.pkl 是將傳遞給 torch.save 的對象(不包括它包含的 torch.Storage 對象)進行 pickle 的結果

  • byteorder 包含一個字符串,其中包含保存時的 sys.byteorder(“little”或“big”)

  • data/ 包含對象中的所有存儲,其中每個存儲都是一個單獨的文件

  • version 包含儲存時的版本號,可在載入時使用

儲存時,PyTorch 會確保每個檔案的本地檔案標頭都被填充到 64 位元組的倍數偏移量,確保每個檔案的偏移量都是 64 位元組對齊。

備註

某些裝置(例如 XLA)上的張量會被序列化為 pickled 的 numpy 陣列。 因此,它們的儲存體不會被序列化。 在這些情況下,檢查點中可能不存在 data/

序列化 torch.nn.Modules 並在 C++ 中載入

另請參閱:教學課程:在 C++ 中載入 TorchScript 模型

ScriptModules 可以序列化為 TorchScript 程式,並使用 torch.jit.load() 載入。 此序列化編碼了所有模組的方法、子模組、參數和屬性,並且允許在 C++ 中載入序列化程式(即,沒有 Python)。

torch.jit.save()torch.save() 之間的區別可能並非一目了然。 torch.save() 使用 pickle 儲存 Python 物件。 這對於原型設計、研究和訓練特別有用。 另一方面,torch.jit.save() 將 ScriptModules 序列化為可以在 Python 或 C++ 中載入的格式。 這在儲存和載入 C++ 模組時,或者在使用 C++ 運行在 Python 中訓練的模組時非常有用,這是部署 PyTorch 模型時的常見做法。

要在 Python 中腳本化、序列化和載入模組

>>> scripted_module = torch.jit.script(MyModule())
>>> torch.jit.save(scripted_module, 'mymodule.pt')
>>> torch.jit.load('mymodule.pt')
RecursiveScriptModule( original_name=MyModule
                      (l0): RecursiveScriptModule(original_name=Linear)
                      (l1): RecursiveScriptModule(original_name=Linear) )

追蹤的模組也可以使用 torch.jit.save() 儲存,但需要注意的是,只有追蹤的程式碼路徑會被序列化。 以下範例說明了這一點

# A module with control flow
>>> class ControlFlowModule(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

      def forward(self, input):
        if input.dim() > 1:
            return torch.tensor(0)

        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)

>>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt')
>>> loaded = torch.jit.load('controlflowmodule_traced.pt')
>>> loaded(torch.randn(2, 4)))
tensor([[-0.1571], [-0.3793]], grad_fn=<AddBackward0>)

>>> scripted_module = torch.jit.script(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt')
>>> loaded = torch.jit.load('controlflowmodule_scripted.pt')
>> loaded(torch.randn(2, 4))
tensor(0)

上述模組有一個 if 語句,該語句不會被追蹤的輸入觸發,因此不是追蹤模組的一部分,也不會與其一起序列化。 但是,腳本化的模組包含 if 語句,並且與其一起序列化。 有關腳本化和追蹤的更多資訊,請參閱 TorchScript 文件

最後,要在 C++ 中載入模組

>>> torch::jit::script::Module module;
>>> module = torch::jit::load('controlflowmodule_scripted.pt');

有關如何在 C++ 中使用 PyTorch 模組的詳細資訊,請參閱 PyTorch C++ API 文件

跨 PyTorch 版本儲存和載入 ScriptModules

PyTorch 團隊建議使用相同版本的 PyTorch 儲存和載入模組。 舊版本的 PyTorch 可能不支援較新的模組,而較新版本可能已移除或修改了舊的行為。 這些變更在 PyTorch 的 發行說明 中有明確說明,並且依賴於已變更功能的模組可能需要更新才能繼續正常運作。 在有限的情況下(如下所述),PyTorch 將保留序列化 ScriptModules 的歷史行為,因此它們不需要更新。

torch.div 執行整數除法

在 PyTorch 1.5 和更早版本中,當給定兩個整數輸入時,torch.div() 會執行向下取整除法

# PyTorch 1.5 (and earlier)
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1)

但是,在 PyTorch 1.7 中,torch.div() 將始終執行其輸入的真實除法,就像 Python 3 中的除法一樣

# PyTorch 1.7
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1.6667)

torch.div() 的行為在序列化 ScriptModules 中得以保留。 也就是說,使用 1.6 之前版本的 PyTorch 序列化的 ScriptModules 將繼續在給定兩個整數輸入時看到 torch.div() 執行向下取整除法,即使使用較新版本的 PyTorch 載入也是如此。 但是,使用 torch.div() 並在 PyTorch 1.6 及更高版本上序列化的 ScriptModules 無法在早期版本的 PyTorch 中載入,因為這些早期版本不理解新的行為。

torch.full 總是推斷出 float dtype

在 PyTorch 1.5 和更早版本中,torch.full() 總是返回一個 float 張量,而不管給定它的填充值是什麼

# PyTorch 1.5 and earlier
>>> torch.full((3,), 1)  # Note the integer fill value...
tensor([1., 1., 1.])     # ...but float tensor!

但是,在 PyTorch 1.7 中,torch.full() 將從填充值推斷返回張量的 dtype

# PyTorch 1.7
>>> torch.full((3,), 1)
tensor([1, 1, 1])

>>> torch.full((3,), True)
tensor([True, True, True])

>>> torch.full((3,), 1.)
tensor([1., 1., 1.])

>>> torch.full((3,), 1 + 1j)
tensor([1.+1.j, 1.+1.j, 1.+1.j])

torch.full() 的行為在序列化 ScriptModules 中得以保留。 也就是說,使用 1.6 之前版本的 PyTorch 序列化的 ScriptModules 將繼續看到 torch.full 預設返回 float 張量,即使給定 bool 或整數填充值也是如此。 但是,使用 torch.full() 並在 PyTorch 1.6 及更高版本上序列化的 ScriptModules 無法在早期版本的 PyTorch 中載入,因為這些早期版本不理解新的行為。

實用函數

以下實用函數與序列化相關

torch.serialization.register_package(priority, tagger, deserializer)[原始碼]

註冊可調用函數,用於標記和反序列化儲存體物件以及關聯的優先級。 標記在儲存時將裝置與儲存體物件關聯,而反序列化則在載入時將儲存體物件移動到適當的裝置。 taggerdeserializer 按照其 priority 給定的順序運行,直到標記器/反序列化器返回的值不是 None

要覆蓋全局註冊表中裝置的反序列化行為,可以使用比現有標記器更高的優先級註冊標記器。

此函數還可以用於為新裝置註冊標記器和反序列化器。

參數
返回

範例

>>> def ipu_tag(obj):
>>>     if obj.device.type == 'ipu':
>>>         return 'ipu'
>>> def ipu_deserialize(obj, location):
>>>     if location.startswith('ipu'):
>>>         ipu = getattr(torch, "ipu", None)
>>>         assert ipu is not None, "IPU device module is not loaded"
>>>         assert torch.ipu.is_available(), "ipu is not available"
>>>         return obj.ipu(location)
>>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
torch.serialization.get_default_load_endianness()[原始碼]

取得載入檔案的後備位元組順序

如果儲存的檢查點中不存在位元組順序標記,則使用此位元組順序作為後備。 預設情況下,它是「原生」位元組順序。

返回

Optional[LoadEndianness]

返回類型

default_load_endian

torch.serialization.set_default_load_endianness(endianness)[原始碼]

設定載入檔案的後備位元組順序

如果儲存的檢查點中不存在位元組順序標記,則使用此位元組順序作為後備。 預設情況下,它是「原生」位元組順序。

參數

endianness – 新的後備位元組順序

torch.serialization.get_default_mmap_options()[source]

取得 mmap=Truetorch.load() 的預設 mmap 選項。

預設為 mmap.MAP_PRIVATE

返回

int

返回類型

default_mmap_options

torch.serialization.set_default_mmap_options(flags)[source]

mmap=Truetorch.load() 的預設 mmap 選項設定為 flags。

目前僅支援 mmap.MAP_PRIVATEmmap.MAP_SHARED。如果您需要新增任何其他選項,請提出 issue。

備註

Windows 目前不支援此功能。

參數

flags (int) – mmap.MAP_PRIVATEmmap.MAP_SHARED

torch.serialization.add_safe_globals(safe_globals)[source]

將指定的全域變數標記為對 weights_only 載入是安全的。例如,新增到此清單中的函數可以在反序列化期間被呼叫,類別可以被實例化並設定狀態。

參數

safe_globals (List[Any]) – 要標記為安全的全域變數清單

範例

>>> import tempfile
>>> class MyTensor(torch.Tensor):
...     pass
>>> t = MyTensor(torch.randn(2, 3))
>>> with tempfile.NamedTemporaryFile() as f:
...     torch.save(t, f.name)
# Running `torch.load(f.name, weights_only=True)` will fail with
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
...     torch.serialization.add_safe_globals([MyTensor])
...     torch.load(f.name, weights_only=True)
# MyTensor([[-0.5024, -1.8152, -0.5455],
#          [-0.8234,  2.0500, -0.3657]])
torch.serialization.clear_safe_globals()[source]

清除對 weights_only 載入安全的全域變數清單。

torch.serialization.get_safe_globals()[source]

傳回使用者新增的、對 weights_only 載入安全的全域變數清單。

返回類型

List[Any]

文件

取得 PyTorch 的完整開發者文件

查看文件

教學課程

取得適用於初學者和進階開發者的深入教學課程

查看教學課程

資源

尋找開發資源並獲得問題解答

查看資源