torch.load¶
- torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=True, mmap=None, **pickle_load_args)[原始碼][原始碼]¶
從檔案載入使用
torch.save()儲存的物件。torch.load()使用 Python 的反序列化功能(unpickling facilities),但對儲存(storages,張量的底層資料結構)進行特殊處理。它們首先在 CPU 上反序列化,然後被移動到儲存時所在的裝置。如果此操作失敗(例如,因為執行時系統不具備某些裝置),則會引發異常。但是,可以使用map_location引數將儲存動態地重新對映到另一組裝置。如果
map_location是一個可呼叫物件 (callable),它將為每個序列化的儲存 (storage) 呼叫一次,並傳入兩個引數:儲存物件和位置標籤。儲存物件引數是儲存的初始反序列化結果,位於 CPU 上。每個序列化的儲存都有一個與之關聯的位置標籤,用於標識它儲存時所在的裝置,此標籤是傳遞給map_location的第二個引數。內建的位置標籤包括用於 CPU 張量的'cpu'以及用於 CUDA 張量的'cuda:device_id'(例如'cuda:2')。map_location應該返回None或一個儲存物件。如果map_location返回一個儲存物件,它將用作最終的反序列化物件,該物件已經移動到正確的裝置。否則,torch.load()將回退到預設行為,就像沒有指定map_location一樣。如果
map_location是一個torch.device物件或包含裝置標籤的字串,它表示所有張量應該載入到的位置。否則,如果
map_location是一個 dict,它將用於將檔案中出現的位置標籤(鍵)重新對映到指定儲存放置位置的標籤(值)。使用者擴充套件可以使用
torch.serialization.register_package()註冊自己的位置標籤以及標記和反序列化方法。- 引數
f (Union[str, PathLike[str], IO[bytes]]) – 類檔案物件(必須實現
read(),readline(),tell()和seek()方法),或包含檔名的字串或 os.PathLike 物件map_location (Optional[Union[Callable[[Storage, str], Storage], device, str, dict[str, str]]]) – 函式、
torch.device物件、字串或 dict,指定如何重新對映儲存位置pickle_module (Optional[Any]) – 用於反序列化元資料和物件的模組(必須與序列化檔案時使用的
pickle_module匹配)weights_only (Optional[bool]) – 指示反序列化器是否應僅限於載入張量、基本型別、字典以及透過
torch.serialization.add_safe_globals()新增的任何型別。有關更多詳細資訊,請參見torch.load with weights_only=True。mmap (Optional[bool]) – 指示檔案是否應進行記憶體對映(mmap),而不是將所有儲存載入到記憶體中。通常,檔案中的張量儲存首先會從磁碟移動到 CPU 記憶體,然後根據儲存時的標籤或由
map_location指定的位置移動到最終裝置。如果最終位置是 CPU,則第二步是空操作。當設定mmap標誌時,第一步會將f進行記憶體對映,而不是將張量儲存從磁碟複製到 CPU 記憶體。pickle_load_args (Any) – (僅限 Python 3)傳遞給
pickle_module.load()和pickle_module.Unpickler()的可選關鍵字引數,例如errors=...。
- 返回型別
警告
torch.load()除非將weights_only引數設定為True,否則會隱式使用pickle模組,該模組已知存在安全風險。惡意構造的 pickle 資料在反序列化過程中可能執行任意程式碼。切勿在非安全模式下載入來自不受信任來源或可能被篡改的資料。僅載入您信任的資料。注意
當您對包含 GPU 張量的檔案呼叫
torch.load()時,這些張量將預設載入到 GPU。您可以透過呼叫torch.load(.., map_location='cpu')然後呼叫load_state_dict()來避免載入模型檢查點時 GPU 記憶體驟增。注意
預設情況下,我們將位元組字串解碼為
utf-8。這是為了避免在 Python 3 中載入 Python 2 儲存的檔案時常見的UnicodeDecodeError: 'ascii' codec can't decode byte 0x...錯誤。如果此預設設定不正確,您可以使用額外的encoding關鍵字引數來指定應如何載入這些物件,例如,encoding='latin1'會使用latin1編碼將其解碼為字串,而encoding='bytes'會將其保留為位元組陣列,稍後可以使用byte_array.decode(...)進行解碼。示例
>>> torch.load("tensors.pt", weights_only=True) # Load all tensors onto the CPU >>> torch.load( ... "tensors.pt", ... map_location=torch.device("cpu"), ... weights_only=True, ... ) # Load all tensors onto the CPU, using a function >>> torch.load( ... "tensors.pt", ... map_location=lambda storage, loc: storage, ... weights_only=True, ... ) # Load all tensors onto GPU 1 >>> torch.load( ... "tensors.pt", ... map_location=lambda storage, loc: storage.cuda(1), ... weights_only=True, ... ) # type: ignore[attr-defined] # Map tensors from GPU 1 to GPU 0 >>> torch.load( ... "tensors.pt", ... map_location={"cuda:1": "cuda:0"}, ... weights_only=True, ... ) # Load tensor from io.BytesIO object # Loading from a buffer setting weights_only=False, warning this can be unsafe >>> with open("tensor.pt", "rb") as f: ... buffer = io.BytesIO(f.read()) >>> torch.load(buffer, weights_only=False) # Load a module with 'ascii' encoding for unpickling # Loading from a module setting weights_only=False, warning this can be unsafe >>> torch.load("module.pt", encoding="ascii", weights_only=False)