torch.hub¶
Pytorch Hub 是一個預先訓練的模型儲存庫,旨在促進研究的可重複性。
發布模型¶
Pytorch Hub 支援透過新增一個簡單的 hubconf.py 檔案,將預先訓練的模型(模型定義和預先訓練的權重)發布到 GitHub 儲存庫;
hubconf.py 可以有多個入口點。每個入口點都被定義為一個 Python 函數(例如:您想要發布的預先訓練的模型)。
def entrypoint_name(*args, **kwargs):
    # args & kwargs are optional, for models which take positional/keyword arguments.
    ...
如何實作入口點?¶
以下程式碼片段指定了 resnet18 模型的入口點,如果我們在 pytorch/vision/hubconf.py 中擴展實作的話。在大多數情況下,在 hubconf.py 中匯入正確的函數就足夠了。在這裡,我們只想使用擴展版本作為範例來展示它是如何工作的。您可以在 pytorch/vision repo 中看到完整的腳本
dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18
# resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
    """ # This docstring shows up in hub.help()
    Resnet18 model
    pretrained (bool): kwargs, load pretrained weights into the model
    """
    # Call the model, load pretrained weights
    model = _resnet18(pretrained=pretrained, **kwargs)
    return model
- dependencies變數是載入模型所需的套件名稱清單。請注意,這可能與訓練模型所需的依賴項略有不同。
- args和- kwargs會被傳遞給真正的可呼叫函數。
- 函數的說明文件可用作說明訊息。它解釋了模型的功能以及允許的位置/關鍵字參數。強烈建議在這裡新增一些範例。 
- 入口點函數可以返回一個模型 (nn.module),也可以返回輔助工具,使使用者工作流程更順暢,例如標記器。 
- 以底線開頭的可呼叫函數被視為輔助函數,不會顯示在 - torch.hub.list()中。
- 預先訓練的權重可以儲存在 GitHub 儲存庫中,也可以透過 - torch.hub.load_state_dict_from_url()載入。如果小於 2GB,建議將其附加到 專案版本 並使用版本中的網址。在上面的範例中,- torchvision.models.resnet.resnet18處理- pretrained,或者您也可以將以下邏輯放在入口點定義中。
if pretrained:
    # For checkpoint saved in local GitHub repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
    dirname = os.path.dirname(__file__)
    checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
    state_dict = torch.load(checkpoint)
    model.load_state_dict(state_dict)
    # For checkpoint saved elsewhere
    checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
重要注意事項¶
- 發布的模型至少應該在一個分支/標籤中。它不能是一個隨機的提交。 
從 Hub 載入模型¶
Pytorch Hub 提供了方便的 API,透過 torch.hub.list() 探索 hub 中所有可用的模型,透過 torch.hub.help() 顯示說明文件和範例,並使用 torch.hub.load() 載入預先訓練的模型。
- torch.hub.list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True)[來源]¶
- 列出 - github指定的儲存庫中所有可用的可呼叫入口點。- 參數
- github (str) – 格式為「repo_owner/repo_name[:ref]」的字串,其中包含一個可選的 ref(標籤或分支)。如果未指定 - ref,則預設分支假設為- main(如果存在),否則為- master。例如:'pytorch/vision:0.10'
- force_reload (bool, 選用) – 是否要捨棄現有的快取並強制重新下載。預設值為 - False。
- skip_validation (bool, 選用) – 如果為 - False,torchhub 將會檢查- github參數所指定的 branch 或 commit 是否確實屬於該 repo 所有者。這會向 GitHub API 發出請求;您可以透過設定- GITHUB_TOKEN環境變數來指定非預設的 GitHub 權杖。預設值為- False。
- trust_repo (bool, str 或 None) – - "check"、- True、- False或- None。此參數是在 v1.12 中引入的,用於確保使用者只會執行來自他們信任的 repo 的程式碼。- 如果為 - False,則會出現提示詢問使用者是否應該信任該 repo。
- 如果為 - True,則會將該 repo 加入信任清單中,並且載入時不需要明確的確認。
- 如果為 - "check",則會根據快取中的信任 repo 清單檢查該 repo。如果該 repo 不在清單中,則行為會回到- trust_repo=False選項。
- 如果為 - None:這會產生警告,邀請使用者將- trust_repo設定為- False、- True或- "check"。這只是為了向後相容性而保留的,並將在 v2.0 中移除。
 - 預設值為 - None,並將在 v2.0 中最終更改為- "check"。
- verbose (bool, 選用) – 如果為 - False,則會隱藏關於命中本地快取的訊息。請注意,關於首次下載的訊息無法隱藏。預設值為- True。
 
- 回傳值
- 可用的可呼叫進入點 
- 回傳型別
 - 範例 - >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True) 
- torch.hub.help(github, model, force_reload=False, skip_validation=False, trust_repo=None)[原始碼]¶
- 顯示進入點 - model的說明字串。- 參數
- github (str) – 格式為 <repo_owner/repo_name[:ref]> 的字串,其中 ref 為選用 (標籤或分支)。如果未指定 - ref,則預設分支假設為- main(如果存在),否則為- master。範例:'pytorch/vision:0.10'
- model (str) – 在 repo 的 - hubconf.py中定義的進入點名稱字串
- force_reload (bool, 選用) – 是否要捨棄現有的快取並強制重新下載。預設值為 - False。
- skip_validation (bool, 選用) – 如果為 - False,torchhub 將會檢查- github參數所指定的 ref 是否確實屬於該 repo 所有者。這會向 GitHub API 發出請求;您可以透過設定- GITHUB_TOKEN環境變數來指定非預設的 GitHub 權杖。預設值為- False。
- trust_repo (bool, str 或 None) – - "check"、- True、- False或- None。此參數是在 v1.12 中引入的,用於確保使用者只會執行來自他們信任的 repo 的程式碼。- 如果為 - False,則會出現提示詢問使用者是否應該信任該 repo。
- 如果為 - True,則會將該 repo 加入信任清單中,並且載入時不需要明確的確認。
- 如果為 - "check",則會根據快取中的信任 repo 清單檢查該 repo。如果該 repo 不在清單中,則行為會回到- trust_repo=False選項。
- 如果為 - None:這會產生警告,邀請使用者將- trust_repo設定為- False、- True或- "check"。這只是為了向後相容性而保留的,並將在 v2.0 中移除。
 - 預設值為 - None,並將在 v2.0 中最終更改為- "check"。
 
 - 範例 - >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True)) 
- torch.hub.load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True, skip_validation=False, **kwargs)[原始碼]¶
- 從 github repo 或本地目錄載入模型。 - 注意:載入模型是典型的用例,但這也可以用於載入其他物件,例如 tokenizer、損失函數等。 - 如果 - source為 'github',則預期- repo_or_dir的格式為- repo_owner/repo_name[:ref],其中 ref 為選用 (標籤或分支)。- 如果 - source為 'local',則預期- repo_or_dir是本地目錄的路徑。- 參數
- repo_or_dir (str) – 如果 - source為 'github',則這應該對應於格式為- repo_owner/repo_name[:ref]的 github repo,其中 ref 為選用 (標籤或分支),例如 'pytorch/vision:0.10'。如果未指定- ref,則預設分支假設為- main(如果存在),否則為- master。如果- source為 'local',則這應該是指向本地目錄的路徑。
- model (str) – 在 repo/dir 的 - hubconf.py中定義的可呼叫 (進入點) 名稱。
- *args (選用) – 對應於可呼叫 - model的參數。
- source (str, 選用) – 'github' 或 'local'。指定如何解譯 - repo_or_dir。預設值為 'github'。
- trust_repo (bool, str 或 None) – - "check"、- True、- False或- None。此參數是在 v1.12 中引入的,用於確保使用者只會執行來自他們信任的 repo 的程式碼。- 如果為 - False,則會出現提示詢問使用者是否應該信任該 repo。
- 如果為 - True,則會將該 repo 加入信任清單中,並且載入時不需要明確的確認。
- 如果為 - "check",則會根據快取中的信任 repo 清單檢查該 repo。如果該 repo 不在清單中,則行為會回到- trust_repo=False選項。
- 如果為 - None:這會產生警告,邀請使用者將- trust_repo設定為- False、- True或- "check"。這只是為了向後相容性而保留的,並將在 v2.0 中移除。
 - 預設值為 - None,並將在 v2.0 中最終更改為- "check"。
- force_reload (bool, 選用) – 是否要強制無條件地重新下載 github repo。如果 - source = 'local',則沒有任何作用。預設值為- False。
- verbose (bool, 選用) – 如果為 - False,則會隱藏關於命中本地快取的訊息。請注意,關於首次下載的訊息無法隱藏。如果- source = 'local',則沒有任何作用。預設值為- True。
- skip_validation (bool, 選用) – 如果為 - False,torchhub 將會檢查- github參數所指定的 branch 或 commit 是否確實屬於該 repo 所有者。這會向 GitHub API 發出請求;您可以透過設定- GITHUB_TOKEN環境變數來指定非預設的 GitHub 權杖。預設值為- False。
- **kwargs (選用) – 對應於可呼叫 - model的關鍵字參數。
 
- 回傳值
- 使用指定的 - *args和- **kwargs呼叫- model可呼叫的輸出。
 - 範例 - >>> # from a github repo >>> repo = 'pytorch/vision' >>> model = torch.hub.load(repo, 'resnet50', weights='ResNet50_Weights.IMAGENET1K_V1') >>> # from a local directory >>> path = '/some/local/path/pytorch/vision' >>> model = torch.hub.load(path, 'resnet50', weights='ResNet50_Weights.DEFAULT') 
- torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)[原始碼]¶
- 將指定 URL 的物件下載到本地路徑。 - 參數
 - 範例 - >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file') 
- torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None, weights_only=False)[原始碼]¶
- 載入指定 URL 的 Torch 序列化物件。 - 如果下載的檔案是 zip 檔案,它會被自動解壓縮。 - 如果物件已經存在於 model_dir 中,它會被反序列化並返回。 - model_dir的預設值為- <hub_dir>/checkpoints,其中- hub_dir是由- get_dir()返回的目錄。- 參數
- url (str) – 要下載的物件的 URL 
- model_dir (str, 可選) – 儲存物件的目錄 
- map_location (可選) – 一個函數或字典,指定如何重新映射儲存位置(請參閱 torch.load) 
- progress (bool, 可選) – 是否在 stderr 顯示進度條。 預設值:True 
- check_hash (bool, 可選) – 如果為 True,則 URL 的檔名部分應遵循命名慣例 - filename-<sha256>.ext,其中- <sha256>是檔案內容的 SHA256 雜湊的前八位或更多位數。 雜湊用於確保唯一的名稱並驗證檔案的內容。 預設值:False
- file_name (str, 可選) – 下載檔案的名稱。 如果未設定,將使用 - url中的檔名。
- weights_only (bool, 可選) – 如果為 True,則只會載入權重,而不會載入複雜的序列化物件。 建議用於不受信任的來源。 有關更多詳細信息,請參閱 - load()。
 
- 回傳型別
 - 範例 - >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth') 
執行載入的模型:¶
請注意,torch.hub.load() 中的 *args 和 **kwargs 用於**實例化**模型。 載入模型後,如何才能知道可以用模型做些什麼? 建議的工作流程如下:
- dir(model)查看模型的所有可用方法。
- help(model.foo)檢查- model.foo執行時需要的參數
為了幫助用戶在不來回參考文檔的情況下進行探索,我們強烈建議資源庫擁有者使函數幫助訊息清晰簡潔。 包含一個最小的工作範例也很有幫助。
我的下載模型儲存在哪裡?¶
位置按以下順序使用:
- 呼叫 - hub.set_dir(<PATH_TO_HUB_DIR>)
- $TORCH_HOME/hub,如果設定了環境變數- TORCH_HOME。
- $XDG_CACHE_HOME/torch/hub,如果設定了環境變數- XDG_CACHE_HOME。
- ~/.cache/torch/hub
快取邏輯¶
預設情況下,我們不會在載入檔案後清理檔案。 如果快取已存在於 get_dir() 返回的目錄中,Hub 預設會使用該快取。
用戶可以通過呼叫 hub.load(..., force_reload=True) 來強制重新載入。 這將刪除現有的 GitHub 資料夾和下載的權重,重新初始化一個新的下載。 當發佈到同一個分支的更新時,這很有用,用戶可以保持最新版本。
已知限制:¶
Torch hub 的工作原理是導入套件,就好像它已經被安裝一樣。 在 Python 中導入會產生一些副作用。 例如,您可以在 Python 快取 sys.modules 和 sys.path_importer_cache 中看到新項目,這是正常的 Python 行為。 這也意味著,如果不同的資源庫具有相同的子套件名稱(通常是 model 子套件),則在從不同的資源庫導入不同的模型時,可能會出現導入錯誤。 解決這些導入錯誤的一種方法是從 sys.modules 字典中刪除有問題的子套件;更多詳細信息可以在 這個 GitHub 問題 中找到。
這裡值得一提的一個已知限制是:用戶**不能**在**同一個 Python 進程**中載入同一個資源庫的兩個不同分支。 這就像在 Python 中安裝兩個同名套件一樣,這是不好的。 如果您真的這樣做,快取可能會加入進來,給您帶來驚喜。 當然,在不同的進程中載入它們是完全可以的。