get_graph_node_names¶
- torchvision.models.feature_extraction.get_graph_node_names(model: Module, tracer_kwargs: Optional[Dict[str, Any]] = None, suppress_diff_warning: bool = False, concrete_args: Optional[Dict[str, Any]] = None) Tuple[List[str], List[str]][原始碼]¶
用於按執行順序返回節點名稱的開發實用工具。參閱關於節點名稱的說明,位於
create_feature_extractor()下方。有助於檢視哪些節點名稱可用於特徵提取。節點名稱不能直接從模型程式碼中輕鬆讀取的原因有兩個並非所有子模組都會被追蹤。`torch.nn` 中的所有模組都屬於此類。
表示重複應用同一操作或葉子模組的節點會帶有一個 `_{counter}` 字尾。
模型會被追蹤兩次:一次在訓練模式下,一次在評估模式下。兩次追蹤得到的節點名稱列表都會被返回。
有關此處使用的節點命名約定的更多詳細資訊,請參閱 相關副標題 在 文件中。
- 引數:
model (nn.Module) – 我們想要列印節點名稱的模型
tracer_kwargs (dict, 可選) – `NodePathTracer` 的關鍵字引數字典(它們最終會傳遞給 torch.fx.Tracer)。預設情況下,它將被設定為包裝所有 torchvision 操作並使其成為葉子節點:{“autowrap_modules”: (math, torchvision.ops,),”leaf_modules”: _get_leaf_modules_for_ops(),} 警告:如果使用者提供了 `tracer_kwargs`,上述預設引數將附加到使用者提供的字典中。
suppress_diff_warning (bool, 可選) – 當圖的訓練版本和評估版本存在差異時,是否抑制警告。預設為 False。
concrete_args (Optional[Dict[str, any]]) – 不應被視為 Proxies 的具體引數。根據 Pytorch 文件,此引數的 API 可能無法保證。
- 返回:
一個列表,包含在訓練模式下追蹤模型得到的節點名稱;另一個列表,包含在評估模式下追蹤模型得到的節點名稱。
- 返回型別:
示例
>>> model = torchvision.models.resnet18() >>> train_nodes, eval_nodes = get_graph_node_names(model)