快捷方式

序列化語義

本說明介紹瞭如何在 Python 中儲存和載入 PyTorch 張量和模組狀態,以及如何序列化 Python 模組以便在 C++ 中載入它們。

儲存和載入張量

torch.save()torch.load() 使您能夠輕鬆儲存和載入張量

>>> t = torch.tensor([1., 2.])
>>> torch.save(t, 'tensor.pt')
>>> torch.load('tensor.pt')
tensor([1., 2.])

按照慣例,PyTorch 檔案通常使用 '.pt' 或 '.pth' 副檔名。

torch.save()torch.load() 預設使用 Python 的 pickle,因此您也可以將多個張量儲存為元組、列表和字典等 Python 物件的一部分

>>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
>>> torch.save(d, 'tensor_dict.pt')
>>> torch.load('tensor_dict.pt')
{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}

包含 PyTorch 張量的自定義資料結構,如果該資料結構是可 pickle 化的,也可以被儲存。

儲存和載入張量保留檢視

儲存張量會保留它們的檢視關係

>>> numbers = torch.arange(1, 10)
>>> evens = numbers[1::2]
>>> torch.save([numbers, evens], 'tensors.pt')
>>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
>>> loaded_evens *= 2
>>> loaded_numbers
tensor([ 1,  4,  3,  8,  5, 12,  7, 16,  9])

在幕後,這些張量共享相同的“儲存 (storage)”。請參閱 張量檢視 (Tensor Views) 以瞭解有關檢視和儲存的更多資訊。

當 PyTorch 儲存張量時,它會單獨儲存它們的儲存物件 (storage objects) 和張量元資料。這是一個實現細節,將來可能會發生變化,但它通常可以節省空間,並讓 PyTorch 輕鬆重建載入的張量之間的檢視關係。例如,在上面的程式碼片段中,只有一個儲存被寫入 'tensors.pt'。

然而,在某些情況下,儲存當前的儲存物件可能是沒有必要的,並且會建立過大的檔案。在下面的程式碼片段中,一個遠大於所儲存張量的儲存被寫入到一個檔案

>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small, 'small.pt')
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
999

儲存到 'small.pt' 檔案中的,不是 small 張量中的五個值,而是它與 large 共享的儲存中的 999 個值被儲存和載入了。

當儲存的張量包含的元素少於其儲存物件中的元素時,可以透過先複製 (cloning) 張量來減小儲存檔案的大小。複製張量會生成一個新的張量,該張量擁有一個新的儲存物件,僅包含該張量中的值

>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small.clone(), 'small.pt')  # saves a clone of small
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
5

然而,由於複製的張量是相互獨立的,它們不具有原始張量之間的檢視關係。如果儲存小於其儲存物件的張量時,檔案大小和檢視關係都很重要,那麼必須在儲存之前仔細構建新的張量,以儘量減小其儲存物件的大小,但仍保留所需的檢視關係。

儲存和載入 torch.nn.Modules

另請參閱:教程:儲存和載入模組

在 PyTorch 中,模組的狀態通常使用“狀態字典 (state dict)”進行序列化。模組的狀態字典包含其所有引數和持久緩衝區 (persistent buffers)

>>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> list(bn.named_parameters())
[('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
 ('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]

>>> list(bn.named_buffers())
[('running_mean', tensor([0., 0., 0.])),
 ('running_var', tensor([1., 1., 1.])),
 ('num_batches_tracked', tensor(0))]

>>> bn.state_dict()
OrderedDict([('weight', tensor([1., 1., 1.])),
             ('bias', tensor([0., 0., 0.])),
             ('running_mean', tensor([0., 0., 0.])),
             ('running_var', tensor([1., 1., 1.])),
             ('num_batches_tracked', tensor(0))])

出於相容性原因,建議不要直接儲存模組,而是隻儲存其狀態字典。Python 模組甚至有一個函式 load_state_dict(),用於從狀態字典恢復其狀態

>>> torch.save(bn.state_dict(), 'bn.pt')
>>> bn_state_dict = torch.load('bn.pt')
>>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> new_bn.load_state_dict(bn_state_dict)
<All keys matched successfully>

請注意,狀態字典首先使用 torch.load() 從檔案中載入,然後使用 load_state_dict() 恢復狀態。

即使是自定義模組和包含其他模組的模組也具有狀態字典,並且可以使用此模式

# A module with two linear layers
>>> class MyModule(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

      def forward(self, input):
        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)

>>> m = MyModule()
>>> m.state_dict()
OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
                                   [-0.3289, 0.2827, 0.4588, 0.2031]])),
             ('l0.bias', tensor([ 0.0300, -0.1316])),
             ('l1.weight', tensor([[0.6533, 0.3413]])),
             ('l1.bias', tensor([-0.1112]))])

>>> torch.save(m.state_dict(), 'mymodule.pt')
>>> m_state_dict = torch.load('mymodule.pt')
>>> new_m = MyModule()
>>> new_m.load_state_dict(m_state_dict)
<All keys matched successfully>

torch.save 的序列化檔案格式

自 PyTorch 1.6.0 版本起,除非使用者設定 _use_new_zipfile_serialization=False,否則 torch.save 預設返回未壓縮的 ZIP64 歸檔檔案。

在此歸檔檔案中,檔案按如下順序排列

checkpoint.pth
├── data.pkl
├── byteorder  # added in PyTorch 2.1.0
├── data/
│   ├── 0
│   ├── 1
│   ├── 2
│   └── …
└── version
條目如下
  • data.pkl 是對傳遞給 torch.save 的物件進行 pickle 化的結果,其中不包含物件內部的 torch.Storage 物件

  • byteorder 包含一個字串,表示儲存時的 sys.byteorder(“little” 或 “big”)

  • data/ 包含物件中的所有儲存,其中每個儲存都是一個單獨的檔案

  • version 包含儲存時的版本號,可在載入時使用

儲存時,PyTorch 將確保每個檔案的本地檔案頭填充到 64 位元組的倍數偏移量,從而確保每個檔案的偏移量是 64 位元組對齊的。

注意

某些裝置(如 XLA)上的張量被序列化為 pickled 的 numpy 陣列。因此,它們的儲存不會被序列化。在這種情況下,檢查點中可能不存在 data/

weights_only=Truetorch.load

從 2.6 版本開始,如果未傳遞 pickle_module 引數,torch.load 將使用 weights_only=True

正如 torch.load() 文件中所討論的,weights_only=Truetorch.load 中使用的 unpickler 限制為僅執行普通 torch.Tensorsstate_dicts 以及其他一些原始型別所需的函式/構建類。此外,與 pickle 模組提供的預設 Unpickler 不同,weights_only Unpickler 在 unpickling 過程中不允許動態匯入任何內容。

如上所述,在使用 torch.save 時,儲存模組的 state_dict 是一個最佳實踐。如果載入包含 nn.Module 的舊檢查點,我們建議使用 weights_only=False。載入包含張量子類的檢查點時,很可能需要將某些函式/類新增到允許列表中,詳情見下文。

如果 weights_only Unpickler 在 pickle 檔案中遇到預設情況下未新增到允許列表的函式或類,您應該會看到類似以下的可操作錯誤

_pickle.UnpicklingError: Weights only load failed. This file can still be loaded,
to do so you have two options, do those steps only if you trust the source of the checkpoint.
    1. Re-running `torch.load` with `weights_only` set to `False` will likely succeed,
        but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
    2. Alternatively, to load with `weights_only=True` please check the recommended
       steps in the following error message.
       WeightsUnpickler error: Unsupported global: GLOBAL {__module__}.{__name__} was not an allowed global by
       default. Please use `torch.serialization.add_safe_globals([{__name__}])` or the
       `torch.serialization.safe_globals([{__name__}])` context manager to allowlist this global
       if you trust this class/function.

請按照錯誤訊息中的步驟,僅在您信任這些函式或類時,才將它們新增到允許列表中。

要獲取檢查點中所有尚未新增到允許列表的 GLOBAL(函式/類),您可以使用 torch.serialization.get_unsafe_globals_in_checkpoint(),它將返回一個形式為 {__module__}.{__name__} 的字串列表。如果您信任這些函式/類,您可以匯入它們,並按照錯誤訊息的指示,透過 torch.serialization.add_safe_globals() 或上下文管理器 torch.serialization.safe_globals 將它們新增到允許列表中。

要訪問使用者新增到允許列表的函式/類列表,您可以使用 torch.serialization.get_safe_globals();要清除當前列表,請參閱 torch.serialization.clear_safe_globals()

排除 weights_only 故障

獲取不安全的全域性變數

需要注意的是,torch.serialization.get_unsafe_globals_in_checkpoint() 是靜態分析檢查點,某些型別可能在 unpickling 過程中動態構建,因此不會被 torch.serialization.get_unsafe_globals_in_checkpoint() 報告。一個這樣的例子是 numpy 中的 dtypes。在 numpy < 1.25 中,將 torch.serialization.get_unsafe_globals_in_checkpoint() 報告的所有函式/類新增到允許列表後,您可能會看到類似以下的錯誤

WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtype[float32]'>

這可以透過 {add_}safe_globals([type(np.dtype(np.float32))]) 新增到允許列表。

numpy >=1.25 中,您會看到

WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtypes.Float32DType'>

這可以透過 {add_}safe_globals([np.dtypes.Float32DType]) 新增到允許列表。

環境變數

有兩個環境變數會影響 torch.load 的行為。如果您無法訪問 torch.load 呼叫點,這些變數會很有幫助。

  • TORCH_FORCE_WEIGHTS_ONLY_LOAD=1 將覆蓋所有 torch.load 呼叫點,使其使用 weights_only=True

  • TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 將使 torch.load 呼叫點僅在未將 weights_only 作為引數傳遞時才使用 weights_only=False

序列化 torch.nn.Modules 並在 C++ 中載入它們

另請參閱:教程:在 C++ 中載入 TorchScript 模型

ScriptModules 可以序列化為 TorchScript 程式,並使用 torch.jit.load() 載入。這種序列化編碼了模組的所有方法、子模組、引數和屬性,並且允許序列化的程式在 C++ 中載入(即無需 Python 環境)。

torch.jit.save()torch.save() 之間的區別可能不是立即可見的。torch.save() 使用 pickle 儲存 Python 物件。這對於原型開發、研究和訓練特別有用。torch.jit.save() 則將 ScriptModules 序列化為可以在 Python 或 C++ 中載入的格式。這在儲存和載入 C++ 模組或使用 C++ 執行在 Python 中訓練的模組時非常有用,這是部署 PyTorch 模型時的常見做法。

在 Python 中進行指令碼化、序列化和載入模組

>>> scripted_module = torch.jit.script(MyModule())
>>> torch.jit.save(scripted_module, 'mymodule.pt')
>>> torch.jit.load('mymodule.pt')
RecursiveScriptModule( original_name=MyModule
                      (l0): RecursiveScriptModule(original_name=Linear)
                      (l1): RecursiveScriptModule(original_name=Linear) )

跟蹤模組也可以使用 torch.jit.save() 儲存,但需要注意的是,只序列化跟蹤到的程式碼路徑。以下示例演示了這一點

# A module with control flow
>>> class ControlFlowModule(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

      def forward(self, input):
        if input.dim() > 1:
            return torch.tensor(0)

        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)

>>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt')
>>> loaded = torch.jit.load('controlflowmodule_traced.pt')
>>> loaded(torch.randn(2, 4)))
tensor([[-0.1571], [-0.3793]], grad_fn=<AddBackward0>)

>>> scripted_module = torch.jit.script(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt')
>>> loaded = torch.jit.load('controlflowmodule_scripted.pt')
>> loaded(torch.randn(2, 4))
tensor(0)

上面的模組有一個 if 語句,該語句未被跟蹤的輸入觸發,因此它不是跟蹤模組的一部分,也不會隨之序列化。然而,指令碼化的模組包含該 if 語句並隨之序列化。有關指令碼化和跟蹤的更多資訊,請參閱 TorchScript 文件

最後,在 C++ 中載入模組

>>> torch::jit::script::Module module;
>>> module = torch::jit::load('controlflowmodule_scripted.pt');

有關如何在 C++ 中使用 PyTorch 模組的詳細資訊,請參閱 PyTorch C++ API 文件

跨 PyTorch 版本儲存和載入 ScriptModules

PyTorch 團隊建議使用相同版本的 PyTorch 儲存和載入模組。較舊的 PyTorch 版本可能不支援較新的模組,而較新的版本可能已刪除或修改了舊的行為。這些更改在 PyTorch 的 釋出說明 中有明確描述,依賴已更改功能的模組可能需要更新才能繼續正常工作。在下面詳細介紹的有限情況下,PyTorch 將保留序列化 ScriptModules 的歷史行為,這樣它們就不需要更新。

torch.div 執行整數除法

在 PyTorch 1.5 及更早版本中,當給定兩個整數輸入時,torch.div() 會執行地板除法 (floor division)

# PyTorch 1.5 (and earlier)
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1)

然而,在 PyTorch 1.7 中,torch.div() 將始終對其輸入執行真除法 (true division),就像 Python 3 中的除法一樣

# PyTorch 1.7
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1.6667)

torch.div() 的行為在序列化的 ScriptModules 中被保留。也就是說,即使使用較新版本的 PyTorch 載入,使用 PyTorch 1.6 之前版本序列化的 ScriptModules 在給定兩個整數輸入時仍將看到 torch.div() 執行地板除法。然而,在 PyTorch 1.6 及更高版本上使用 torch.div() 並序列化的 ScriptModules 不能在更早的 PyTorch 版本中載入,因為這些更早的版本無法理解新行為。

torch.full 總是推斷為浮點型 dtype

在 PyTorch 1.5 及更早版本中,無論給定什麼填充值,torch.full() 總是返回一個浮點型張量

# PyTorch 1.5 and earlier
>>> torch.full((3,), 1)  # Note the integer fill value...
tensor([1., 1., 1.])     # ...but float tensor!

然而,在 PyTorch 1.7 中,torch.full() 將從填充值推斷返回張量的 dtype

# PyTorch 1.7
>>> torch.full((3,), 1)
tensor([1, 1, 1])

>>> torch.full((3,), True)
tensor([True, True, True])

>>> torch.full((3,), 1.)
tensor([1., 1., 1.])

>>> torch.full((3,), 1 + 1j)
tensor([1.+1.j, 1.+1.j, 1.+1.j])

torch.full() 的行為在序列化的 ScriptModules 中被保留。也就是說,即使給定 bool 或整數填充值,使用 PyTorch 1.6 之前版本序列化的 ScriptModules 預設仍將看到 torch.full 返回浮點型張量。然而,在 PyTorch 1.6 及更高版本上使用 torch.full() 並序列化的 ScriptModules 不能在更早的 PyTorch 版本中載入,因為這些更早的版本無法理解新行為。

實用函式

以下實用函式與序列化相關

torch.serialization.register_package(priority, tagger, deserializer)[源][源]

註冊用於標記和反序列化儲存物件的可呼叫物件,並關聯優先順序。標記 (tagging) 在儲存時將裝置與儲存物件關聯,而反序列化 (deserializing) 在載入時將儲存物件移動到適當的裝置。taggerdeserializer 按照其 priority 給定的順序執行,直到 tagger/deserializer 返回一個非 None 的值。

要覆蓋全域性登錄檔中某個裝置的反序列化行為,可以註冊一個優先順序高於現有 tagger 的 tagger。

此函式還可用於為新設備註冊 tagger 和 deserializer。

引數
返回

None

示例

>>> def ipu_tag(obj):
>>>     if obj.device.type == 'ipu':
>>>         return 'ipu'
>>> def ipu_deserialize(obj, location):
>>>     if location.startswith('ipu'):
>>>         ipu = getattr(torch, "ipu", None)
>>>         assert ipu is not None, "IPU device module is not loaded"
>>>         assert torch.ipu.is_available(), "ipu is not available"
>>>         return obj.ipu(location)
>>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
torch.serialization.get_crc32_options()[source][source]

獲取 torch.save() 是否計算併為每個記錄寫入 crc32。

預設為 True

返回型別

bool

torch.serialization.set_crc32_options(compute_crc32)[source][source]

設定 torch.save() 是否計算併為每個記錄寫入 crc32。

注意

將其設定為 False 可能會導致解壓 torch.save 輸出時因 CRC32 損壞而失敗或發出警告。但 torch.load 將能夠載入該檔案。

引數

compute_crc32 (bool) – 設定 crc32 計算標誌

torch.serialization.get_default_load_endianness()[source][source]

獲取載入檔案的回退位元組順序

如果儲存的檢查點中不存在位元組順序標記,則使用此位元組順序作為回退。預設情況下,它是“native”位元組順序。

返回

Optional[LoadEndianness]

返回型別

default_load_endian

torch.serialization.set_default_load_endianness(endianness)[source][source]

設定載入檔案的回退位元組順序

如果儲存的檢查點中不存在位元組順序標記,則使用此位元組順序作為回退。預設情況下,它是“native”位元組順序。

引數

endianness – 新的回退位元組順序

torch.serialization.get_default_mmap_options()[source][source]

獲取 torch.load() 並設定 mmap=True 時的預設 mmap 選項。

預設為 mmap.MAP_PRIVATE

返回

int

返回型別

default_mmap_options

torch.serialization.set_default_mmap_options(flags)[source][source]

上下文管理器或函式,用於設定 torch.load() 並設定 mmap=True 時的預設 mmap 選項為 flags。

目前,僅支援 mmap.MAP_PRIVATEmmap.MAP_SHARED。如果您需要新增任何其他選項,請提交 issue。

注意

此功能目前不支援 Windows。

引數

flags (int) – mmap.MAP_PRIVATEmmap.MAP_SHARED

torch.serialization.add_safe_globals(safe_globals)[source][source]

將給定的全域性變數標記為對 weights_only 載入安全。例如,新增到此列表中的函式可以在反序列化時被呼叫,類可以被例項化並設定狀態。

列表中的每個項可以是函式/類本身,也可以是形式為 (函式/類, 字串) 的元組,其中字串是函式/類的完整路徑。

在序列化格式中,每個函式都由其完整路徑 {__module__}.{__qualname__} 標識。呼叫此 API 時,您可以提供應該與檢查點中匹配的完整路徑,否則將使用預設的 {fn.__module__}.{fn.__qualname__}

引數

safe_globals (List[Union[Callable, Tuple[Callable, str]]]) – 要標記為安全的全域性變數列表

示例

>>> import tempfile
>>> class MyTensor(torch.Tensor):
...     pass
>>> t = MyTensor(torch.randn(2, 3))
>>> with tempfile.NamedTemporaryFile() as f:
...     torch.save(t, f.name)
# Running `torch.load(f.name, weights_only=True)` will fail with
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
...     torch.serialization.add_safe_globals([MyTensor])
...     torch.load(f.name, weights_only=True)
# MyTensor([[-0.5024, -1.8152, -0.5455],
#          [-0.8234,  2.0500, -0.3657]])
torch.serialization.clear_safe_globals()[source][source]

清除對 weights_only 載入安全的全域性變數列表。

torch.serialization.get_safe_globals()[source][source]

返回使用者新增的對 weights_only 載入安全的全域性變數列表。

返回型別

list[Union[Callable, tuple[Callable, str]]]

torch.serialization.get_unsafe_globals_in_checkpoint(f)[source][source]

返回一個字串列表,包含 torch.save 物件中對 weights_only 不安全的函式/類。

對於給定的函式或類 f,對應的字串將是 {f.__module__}.{f.__name__} 的形式。

此函式將返回檢查點中未標記為對 weights_only 安全的任何全域性變數(無論是透過 add_safe_globals()safe_globals 上下文,還是由 torch 預設列入白名單)。

注意

此函式將靜態反彙編檢查點中的 pickle 檔案。這意味著在反序列化期間動態推送到棧上的任何類將不包含在輸出中。

引數

f (Union[str, PathLike[str], IO[bytes]]) – 檔案類物件或包含透過 torch.save 儲存的檢查點物件的字串

返回

檢查點中未列入 weights_only 白名單的 pickle 全域性變數字串列表。

返回型別

list[str]

class torch.serialization.safe_globals(safe_globals)[source][source]

上下文管理器,將某些全域性變數新增為對 weights_only 載入安全。

引數

safe_globals (list[Union[Callable, tuple[Callable, str]]]) – 用於 weights_only 載入的全域性變數列表。

示例

>>> import tempfile
>>> class MyTensor(torch.Tensor):
...     pass
>>> t = MyTensor(torch.randn(2, 3))
>>> with tempfile.NamedTemporaryFile() as f:
...     torch.save(t, f.name)
# Running `torch.load(f.name, weights_only=True)` will fail with
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
...     with torch.serialization.safe_globals([MyTensor]):
...         torch.load(f.name, weights_only=True)
# MyTensor([[-0.5024, -1.8152, -0.5455],
#          [-0.8234,  2.0500, -0.3657]])
>>> assert torch.serialization.get_safe_globals() == []
class torch.serialization.skip_data(materialize_fake_tensors=False)[source][source]

上下文管理器,用於在 torch.save / torch.load 呼叫時跳過寫入/讀取儲存位元組。

對於儲存路徑,儲存仍將儲存,但通常寫入其位元組的空間將是空閒空間。隨後可以在單獨的過程中填充儲存位元組。

對於載入路徑,張量將按檢查點載入,但它們的儲存不會填充資料。

警告

skip_data 上下文管理器是早期原型,可能會發生變化。

引數

materialize_fake_tensors (bool) – 是否在儲存期間具體化 FakeTensors。這對於載入路徑是空操作。

示例

>>> import tempfile
>>> t = torch.randn(2, 3)
>>> with tempfile.NamedTemporaryFile() as f:
...     with torch.serialization.skip_data():
...         torch.save(t, f.name)
...     torch.load(f.name, weights_only=True)
tensor([[0., 0., 0.],
        [0., 0., 0.]])

配置

torch.utils.serialization.config 提供了一個全域性配置,可以控制 torch.savetorch.load 的行為。

torch.utils.serialization.config.save 包含控制 torch.save 行為的選項。

  • compute_crc32:是否計算並寫入 zip 檔案校驗和(預設值:True)。參見 set_crc32_options()

  • use_pinned_memory_for_d2h:對於傳遞給 torch.save 時位於加速器上的儲存,是否在 torch.save 內將儲存移動到 CPU 上的鎖定記憶體或可分頁記憶體(預設值:False(即可分頁))。

  • storage_alignment:在 torch.save 期間檢查點中儲存的對齊方式(以位元組為單位)。(預設值 64

torch.utils.serialization.config.load 包含控制 torch.load 行為的選項。

  • mmap:參見 torch.load()mmap 引數的文件。如果未顯式傳遞給 torch.load 呼叫,此配置將設定 mmap 對於 torch.load 的行為(預設值:False)。

  • endianness:參見 set_default_load_endianness()。(預設值:torch.serialization.LoadEndianness.NATIVE

  • mmap_flags:參見 set_default_mmap_options。(預設值:MAP_PRIVATE

  • calculate_storage_offsets:如果此配置設定為 True,則在使用 torch.load(mmap=True) 時將計算儲存的偏移量,而不是透過隨機讀取來讀取。這最大程度地減少了隨機讀取,當檔案透過網路載入時會很有幫助。(預設值:False

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源