快捷方式

torch.load

torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=True, mmap=None, **pickle_load_args)[原始碼][原始碼]

從檔案載入使用 torch.save() 儲存的物件。

torch.load() 使用 Python 的反序列化功能(unpickling facilities),但對儲存(storages,張量的底層資料結構)進行特殊處理。它們首先在 CPU 上反序列化,然後被移動到儲存時所在的裝置。如果此操作失敗(例如,因為執行時系統不具備某些裝置),則會引發異常。但是,可以使用 map_location 引數將儲存動態地重新對映到另一組裝置。

如果 map_location 是一個可呼叫物件 (callable),它將為每個序列化的儲存 (storage) 呼叫一次,並傳入兩個引數:儲存物件和位置標籤。儲存物件引數是儲存的初始反序列化結果,位於 CPU 上。每個序列化的儲存都有一個與之關聯的位置標籤,用於標識它儲存時所在的裝置,此標籤是傳遞給 map_location 的第二個引數。內建的位置標籤包括用於 CPU 張量的 'cpu' 以及用於 CUDA 張量的 'cuda:device_id'(例如 'cuda:2')。map_location 應該返回 None 或一個儲存物件。如果 map_location 返回一個儲存物件,它將用作最終的反序列化物件,該物件已經移動到正確的裝置。否則,torch.load() 將回退到預設行為,就像沒有指定 map_location 一樣。

如果 map_location 是一個 torch.device 物件或包含裝置標籤的字串,它表示所有張量應該載入到的位置。

否則,如果 map_location 是一個 dict,它將用於將檔案中出現的位置標籤(鍵)重新對映到指定儲存放置位置的標籤(值)。

使用者擴充套件可以使用 torch.serialization.register_package() 註冊自己的位置標籤以及標記和反序列化方法。

引數
  • f (Union[str, PathLike[str], IO[bytes]]) – 類檔案物件(必須實現 read(), readline(), tell()seek() 方法),或包含檔名的字串或 os.PathLike 物件

  • map_location (Optional[Union[Callable[[Storage, str], Storage], device, str, dict[str, str]]]) – 函式、torch.device 物件、字串或 dict,指定如何重新對映儲存位置

  • pickle_module (Optional[Any]) – 用於反序列化元資料和物件的模組(必須與序列化檔案時使用的 pickle_module 匹配)

  • weights_only (Optional[bool]) – 指示反序列化器是否應僅限於載入張量、基本型別、字典以及透過 torch.serialization.add_safe_globals() 新增的任何型別。有關更多詳細資訊,請參見torch.load with weights_only=True

  • mmap (Optional[bool]) – 指示檔案是否應進行記憶體對映(mmap),而不是將所有儲存載入到記憶體中。通常,檔案中的張量儲存首先會從磁碟移動到 CPU 記憶體,然後根據儲存時的標籤或由 map_location 指定的位置移動到最終裝置。如果最終位置是 CPU,則第二步是空操作。當設定 mmap 標誌時,第一步會將 f 進行記憶體對映,而不是將張量儲存從磁碟複製到 CPU 記憶體。

  • pickle_load_args (Any) – (僅限 Python 3)傳遞給 pickle_module.load()pickle_module.Unpickler() 的可選關鍵字引數,例如 errors=...

返回型別

Any

警告

torch.load() 除非將 weights_only 引數設定為 True,否則會隱式使用 pickle 模組,該模組已知存在安全風險。惡意構造的 pickle 資料在反序列化過程中可能執行任意程式碼。切勿在非安全模式下載入來自不受信任來源或可能被篡改的資料。僅載入您信任的資料。

注意

當您對包含 GPU 張量的檔案呼叫 torch.load() 時,這些張量將預設載入到 GPU。您可以透過呼叫 torch.load(.., map_location='cpu') 然後呼叫 load_state_dict() 來避免載入模型檢查點時 GPU 記憶體驟增。

注意

預設情況下,我們將位元組字串解碼為 utf-8。這是為了避免在 Python 3 中載入 Python 2 儲存的檔案時常見的 UnicodeDecodeError: 'ascii' codec can't decode byte 0x... 錯誤。如果此預設設定不正確,您可以使用額外的 encoding 關鍵字引數來指定應如何載入這些物件,例如,encoding='latin1' 會使用 latin1 編碼將其解碼為字串,而 encoding='bytes' 會將其保留為位元組陣列,稍後可以使用 byte_array.decode(...) 進行解碼。

示例

>>> torch.load("tensors.pt", weights_only=True)
# Load all tensors onto the CPU
>>> torch.load(
...     "tensors.pt",
...     map_location=torch.device("cpu"),
...     weights_only=True,
... )
# Load all tensors onto the CPU, using a function
>>> torch.load(
...     "tensors.pt",
...     map_location=lambda storage, loc: storage,
...     weights_only=True,
... )
# Load all tensors onto GPU 1
>>> torch.load(
...     "tensors.pt",
...     map_location=lambda storage, loc: storage.cuda(1),
...     weights_only=True,
... )  # type: ignore[attr-defined]
# Map tensors from GPU 1 to GPU 0
>>> torch.load(
...     "tensors.pt",
...     map_location={"cuda:1": "cuda:0"},
...     weights_only=True,
... )
# Load tensor from io.BytesIO object
# Loading from a buffer setting weights_only=False, warning this can be unsafe
>>> with open("tensor.pt", "rb") as f:
...     buffer = io.BytesIO(f.read())
>>> torch.load(buffer, weights_only=False)
# Load a module with 'ascii' encoding for unpickling
# Loading from a module setting weights_only=False, warning this can be unsafe
>>> torch.load("module.pt", encoding="ascii", weights_only=False)

文件

訪問 PyTorch 的全面開發者文件

檢視文件

教程

獲取針對初學者和高階開發者的深度教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源