序列化語義¶
本說明描述了如何在 Python 中保存和加載 PyTorch 張量和模組狀態,以及如何序列化 Python 模組以便在 C++ 中加載它們。
目錄
保存和加載張量¶
torch.save() 和 torch.load() 允許您輕鬆地保存和加載張量
>>> t = torch.tensor([1., 2.])
>>> torch.save(t, 'tensor.pt')
>>> torch.load('tensor.pt')
tensor([1., 2.])
按照慣例,PyTorch 文件通常使用“.pt”或“.pth”擴展名編寫。
torch.save() 和 torch.load() 默認使用 Python 的 pickle,因此您也可以將多個張量保存為 Python 對象的一部分,例如元組、列表和字典
>>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
>>> torch.save(d, 'tensor_dict.pt')
>>> torch.load('tensor_dict.pt')
{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}
如果自訂數據結構是可 pickle 的,則也可以保存包含 PyTorch 張量的自訂數據結構。
保存和加載張量會保留視圖¶
保存張量會保留它們的視圖關係
>>> numbers = torch.arange(1, 10)
>>> evens = numbers[1::2]
>>> torch.save([numbers, evens], 'tensors.pt')
>>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
>>> loaded_evens *= 2
>>> loaded_numbers
tensor([ 1,  4,  3,  8,  5, 12,  7, 16,  9])
在幕後,這些張量共享相同的“存儲”。有關視圖和存儲的更多信息,請參見 張量視圖。
當 PyTorch 保存張量時,它會分別保存它們的存儲對象和張量元數據。這是一個將來可能會改變的實現細節,但它通常可以節省空間,並讓 PyTorch 輕鬆地重建加載的張量之間的視圖關係。例如,在上面的代碼段中,只有一個存儲被寫入“tensors.pt”。
然而,在某些情況下,保存當前的存儲對象可能是不必要的,並且會創建過大的文件。在下面的代碼段中,一個比保存的張量大得多的存儲被寫入文件
>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small, 'small.pt')
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
999
沒有將 small 張量中的五個值保存到“small.pt”,而是保存並加載了它與 large 共享的存儲中的 999 個值。
當保存的張量元素少於其存儲對象時,可以通過首先克隆張量來減小保存的文件的大小。克隆張量會生成一個新的張量,其中包含一個新的存儲對象,該對象僅包含張量中的值
>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small.clone(), 'small.pt')  # saves a clone of small
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
5
然而,由於克隆的張量彼此獨立,因此它們不具有原始張量所具有的任何視圖關係。如果在保存小於其存儲對象的張量時,文件大小和視圖關係都很重要,那麼在保存之前,必須小心地構造新的張量,以最小化其存儲對象的大小,但仍然具有所需的視圖關係。
保存和加載 torch.nn.Modules¶
另見:教學:保存和加載模組
在 PyTorch 中,模組的狀態經常使用“狀態字典”進行序列化。模組的狀態字典包含其所有參數和持久緩衝區
>>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> list(bn.named_parameters())
[('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
 ('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]
>>> list(bn.named_buffers())
[('running_mean', tensor([0., 0., 0.])),
 ('running_var', tensor([1., 1., 1.])),
 ('num_batches_tracked', tensor(0))]
>>> bn.state_dict()
OrderedDict([('weight', tensor([1., 1., 1.])),
             ('bias', tensor([0., 0., 0.])),
             ('running_mean', tensor([0., 0., 0.])),
             ('running_var', tensor([1., 1., 1.])),
             ('num_batches_tracked', tensor(0))])
出於兼容性原因,建議不要直接保存模組,而是僅保存其狀態字典。Python 模組甚至有一個函數 load_state_dict(),用於從狀態字典恢復其狀態
>>> torch.save(bn.state_dict(), 'bn.pt')
>>> bn_state_dict = torch.load('bn.pt')
>>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> new_bn.load_state_dict(bn_state_dict)
<All keys matched successfully>
請注意,狀態字典首先使用 torch.load() 從其文件中加載,然後使用 load_state_dict() 恢復狀態。
即使是自訂模組和包含其他模組的模組也具有狀態字典,並且可以使用這種模式
# A module with two linear layers
>>> class MyModule(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)
      def forward(self, input):
        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)
>>> m = MyModule()
>>> m.state_dict()
OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
                                   [-0.3289, 0.2827, 0.4588, 0.2031]])),
             ('l0.bias', tensor([ 0.0300, -0.1316])),
             ('l1.weight', tensor([[0.6533, 0.3413]])),
             ('l1.bias', tensor([-0.1112]))])
>>> torch.save(m.state_dict(), 'mymodule.pt')
>>> m_state_dict = torch.load('mymodule.pt')
>>> new_m = MyModule()
>>> new_m.load_state_dict(m_state_dict)
<All keys matched successfully>
`torch.save` 的序列化文件格式¶
從 PyTorch 1.6.0 開始,除非用戶設置 _use_new_zipfile_serialization=False,否則 torch.save 默認返回未壓縮的 ZIP64 存檔。
在此存檔中,文件的順序如下
checkpoint.pth
├── data.pkl
├── byteorder  # added in PyTorch 2.1.0
├── data/
│   ├── 0
│   ├── 1
│   ├── 2
│   └── …
└── version
- 條目如下
- data.pkl是將傳遞給- torch.save的對象(不包括它包含的- torch.Storage對象)進行 pickle 的結果
- byteorder包含一個字符串,其中包含保存時的- sys.byteorder(“little”或“big”)
- data/包含對象中的所有存儲,其中每個存儲都是一個單獨的文件
- version包含儲存時的版本號,可在載入時使用
 
儲存時,PyTorch 會確保每個檔案的本地檔案標頭都被填充到 64 位元組的倍數偏移量,確保每個檔案的偏移量都是 64 位元組對齊。
備註
某些裝置(例如 XLA)上的張量會被序列化為 pickled 的 numpy 陣列。 因此,它們的儲存體不會被序列化。 在這些情況下,檢查點中可能不存在 data/。
序列化 torch.nn.Modules 並在 C++ 中載入¶
另請參閱:教學課程:在 C++ 中載入 TorchScript 模型
ScriptModules 可以序列化為 TorchScript 程式,並使用 torch.jit.load() 載入。 此序列化編碼了所有模組的方法、子模組、參數和屬性,並且允許在 C++ 中載入序列化程式(即,沒有 Python)。
torch.jit.save() 和 torch.save() 之間的區別可能並非一目了然。 torch.save() 使用 pickle 儲存 Python 物件。 這對於原型設計、研究和訓練特別有用。 另一方面,torch.jit.save() 將 ScriptModules 序列化為可以在 Python 或 C++ 中載入的格式。 這在儲存和載入 C++ 模組時,或者在使用 C++ 運行在 Python 中訓練的模組時非常有用,這是部署 PyTorch 模型時的常見做法。
要在 Python 中腳本化、序列化和載入模組
>>> scripted_module = torch.jit.script(MyModule())
>>> torch.jit.save(scripted_module, 'mymodule.pt')
>>> torch.jit.load('mymodule.pt')
RecursiveScriptModule( original_name=MyModule
                      (l0): RecursiveScriptModule(original_name=Linear)
                      (l1): RecursiveScriptModule(original_name=Linear) )
追蹤的模組也可以使用 torch.jit.save() 儲存,但需要注意的是,只有追蹤的程式碼路徑會被序列化。 以下範例說明了這一點
# A module with control flow
>>> class ControlFlowModule(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)
      def forward(self, input):
        if input.dim() > 1:
            return torch.tensor(0)
        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)
>>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt')
>>> loaded = torch.jit.load('controlflowmodule_traced.pt')
>>> loaded(torch.randn(2, 4)))
tensor([[-0.1571], [-0.3793]], grad_fn=<AddBackward0>)
>>> scripted_module = torch.jit.script(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt')
>>> loaded = torch.jit.load('controlflowmodule_scripted.pt')
>> loaded(torch.randn(2, 4))
tensor(0)
上述模組有一個 if 語句,該語句不會被追蹤的輸入觸發,因此不是追蹤模組的一部分,也不會與其一起序列化。 但是,腳本化的模組包含 if 語句,並且與其一起序列化。 有關腳本化和追蹤的更多資訊,請參閱 TorchScript 文件。
最後,要在 C++ 中載入模組
>>> torch::jit::script::Module module;
>>> module = torch::jit::load('controlflowmodule_scripted.pt');
有關如何在 C++ 中使用 PyTorch 模組的詳細資訊,請參閱 PyTorch C++ API 文件。
跨 PyTorch 版本儲存和載入 ScriptModules¶
PyTorch 團隊建議使用相同版本的 PyTorch 儲存和載入模組。 舊版本的 PyTorch 可能不支援較新的模組,而較新版本可能已移除或修改了舊的行為。 這些變更在 PyTorch 的 發行說明 中有明確說明,並且依賴於已變更功能的模組可能需要更新才能繼續正常運作。 在有限的情況下(如下所述),PyTorch 將保留序列化 ScriptModules 的歷史行為,因此它們不需要更新。
torch.div 執行整數除法¶
在 PyTorch 1.5 和更早版本中,當給定兩個整數輸入時,torch.div() 會執行向下取整除法
# PyTorch 1.5 (and earlier)
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1)
但是,在 PyTorch 1.7 中,torch.div() 將始終執行其輸入的真實除法,就像 Python 3 中的除法一樣
# PyTorch 1.7
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1.6667)
torch.div() 的行為在序列化 ScriptModules 中得以保留。 也就是說,使用 1.6 之前版本的 PyTorch 序列化的 ScriptModules 將繼續在給定兩個整數輸入時看到 torch.div() 執行向下取整除法,即使使用較新版本的 PyTorch 載入也是如此。 但是,使用 torch.div() 並在 PyTorch 1.6 及更高版本上序列化的 ScriptModules 無法在早期版本的 PyTorch 中載入,因為這些早期版本不理解新的行為。
torch.full 總是推斷出 float dtype¶
在 PyTorch 1.5 和更早版本中,torch.full() 總是返回一個 float 張量,而不管給定它的填充值是什麼
# PyTorch 1.5 and earlier
>>> torch.full((3,), 1)  # Note the integer fill value...
tensor([1., 1., 1.])     # ...but float tensor!
但是,在 PyTorch 1.7 中,torch.full() 將從填充值推斷返回張量的 dtype
# PyTorch 1.7
>>> torch.full((3,), 1)
tensor([1, 1, 1])
>>> torch.full((3,), True)
tensor([True, True, True])
>>> torch.full((3,), 1.)
tensor([1., 1., 1.])
>>> torch.full((3,), 1 + 1j)
tensor([1.+1.j, 1.+1.j, 1.+1.j])
torch.full() 的行為在序列化 ScriptModules 中得以保留。 也就是說,使用 1.6 之前版本的 PyTorch 序列化的 ScriptModules 將繼續看到 torch.full 預設返回 float 張量,即使給定 bool 或整數填充值也是如此。 但是,使用 torch.full() 並在 PyTorch 1.6 及更高版本上序列化的 ScriptModules 無法在早期版本的 PyTorch 中載入,因為這些早期版本不理解新的行為。
實用函數¶
以下實用函數與序列化相關
- torch.serialization.register_package(priority, tagger, deserializer)[原始碼]¶
- 註冊可調用函數,用於標記和反序列化儲存體物件以及關聯的優先級。 標記在儲存時將裝置與儲存體物件關聯,而反序列化則在載入時將儲存體物件移動到適當的裝置。 - tagger和- deserializer按照其- priority給定的順序運行,直到標記器/反序列化器返回的值不是 None。- 要覆蓋全局註冊表中裝置的反序列化行為,可以使用比現有標記器更高的優先級註冊標記器。 - 此函數還可以用於為新裝置註冊標記器和反序列化器。 - 參數
- priority (int) – 表示與標記器和反序列化器關聯的優先級,其中較低的值表示較高的優先級。 
- tagger (Callable[[Union[Storage, TypedStorage, UntypedStorage]], Optional[str]]) – 可調用函數,接收儲存體物件並返回其標記的裝置作為字串或 None。 
- deserializer (Callable[[Union[Storage, TypedStorage, UntypedStorage]], str], Optional[Union[Storage, TypedStorage, UntypedStorage]]]) – 可調用函數,接收儲存體物件和裝置字串,並在適當的裝置上返回儲存體物件或 None。 
 
- 返回
- 無 
 - 範例 - >>> def ipu_tag(obj): >>> if obj.device.type == 'ipu': >>> return 'ipu' >>> def ipu_deserialize(obj, location): >>> if location.startswith('ipu'): >>> ipu = getattr(torch, "ipu", None) >>> assert ipu is not None, "IPU device module is not loaded" >>> assert torch.ipu.is_available(), "ipu is not available" >>> return obj.ipu(location) >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize) 
- torch.serialization.get_default_load_endianness()[原始碼]¶
- 取得載入檔案的後備位元組順序 - 如果儲存的檢查點中不存在位元組順序標記,則使用此位元組順序作為後備。 預設情況下,它是「原生」位元組順序。 - 返回
- Optional[LoadEndianness] 
- 返回類型
- default_load_endian 
 
- torch.serialization.set_default_load_endianness(endianness)[原始碼]¶
- 設定載入檔案的後備位元組順序 - 如果儲存的檢查點中不存在位元組順序標記,則使用此位元組順序作為後備。 預設情況下,它是「原生」位元組順序。 - 參數
- endianness – 新的後備位元組順序 
 
- torch.serialization.get_default_mmap_options()[source]¶
- 取得 - mmap=True時- torch.load()的預設 mmap 選項。- 預設為 - mmap.MAP_PRIVATE。- 返回
- int 
- 返回類型
- default_mmap_options 
 
- torch.serialization.set_default_mmap_options(flags)[source]¶
- 將 - mmap=True時- torch.load()的預設 mmap 選項設定為 flags。- 目前僅支援 - mmap.MAP_PRIVATE或- mmap.MAP_SHARED。如果您需要新增任何其他選項,請提出 issue。- 備註 - Windows 目前不支援此功能。 - 參數
- flags (int) – - mmap.MAP_PRIVATE或- mmap.MAP_SHARED
 
- torch.serialization.add_safe_globals(safe_globals)[source]¶
- 將指定的全域變數標記為對 - weights_only載入是安全的。例如,新增到此清單中的函數可以在反序列化期間被呼叫,類別可以被實例化並設定狀態。- 參數
- safe_globals (List[Any]) – 要標記為安全的全域變數清單 
 - 範例 - >>> import tempfile >>> class MyTensor(torch.Tensor): ... pass >>> t = MyTensor(torch.randn(2, 3)) >>> with tempfile.NamedTemporaryFile() as f: ... torch.save(t, f.name) # Running `torch.load(f.name, weights_only=True)` will fail with # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default. # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint. ... torch.serialization.add_safe_globals([MyTensor]) ... torch.load(f.name, weights_only=True) # MyTensor([[-0.5024, -1.8152, -0.5455], # [-0.8234, 2.0500, -0.3657]])