快捷方式

create_feature_extractor

torchvision.models.feature_extraction.create_feature_extractor(model: Module, return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, tracer_kwargs: Optional[Dict[str, Any]] = None, suppress_diff_warning: bool = False, concrete_args: Optional[Dict[str, Any]] = None) GraphModule[源]

建立一個新的圖模組,該模組將給定模型中的中間節點作為字典返回,其中使用者指定的鍵為字串,請求的輸出作為值。這透過使用 FX 重寫模型的計算圖來實現,從而將所需的節點作為輸出返回。所有未使用的節點及其相應的引數都將被移除。

所需的輸出節點必須指定為透過 . 分隔的路徑,該路徑從頂層模組向下遍歷模組層次結構,直到葉操作或葉模組。有關此處使用的節點命名約定的更多詳細資訊,請參閱文件中的相關小標題

並非所有模型都可以透過 FX 進行跟蹤,儘管經過一些調整後它們可以協同工作。以下是一些(非詳盡的)技巧列表

  • 如果您不需要跟蹤某個特定、有問題的子模組,可以透過將 leaf_modules 列表作為 tracer_kwargs 之一傳遞來將其轉換為“葉模組”(參見下面的示例)。它不會被跟蹤,而是生成的圖將保留對該模組 forward 方法的引用。

  • 同樣,您可以透過將 autowrap_functions 列表作為 tracer_kwargs 之一傳遞來將函式轉換為葉函式(參見下面的示例)。

  • 一些內建的 Python 函式可能會有問題。例如,int 在跟蹤期間會引發錯誤。您可以將它們包裝在您自己的函式中,然後將其作為 tracer_kwargs 之一傳遞給 autowrap_functions

有關 FX 的更多資訊,請參閱torch.fx 文件

引數:
  • model (nn.Module) – 我們將從中提取特徵的模型

  • return_nodes (list or dict, optional) – 一個 ListDict,包含將返回其啟用值的節點的名稱(或部分名稱 - 見上方註釋)。如果它是一個 Dict,鍵是節點名稱,值是圖模組返回字典的使用者指定鍵。如果它是一個 List,它被視為一個 Dict,將節點規範字符串直接對映到輸出名稱。如果指定了 train_return_nodeseval_return_nodes,則不應指定此引數。

  • train_return_nodes (list or dict, optional) – 類似於 return_nodes。如果在訓練模式下的返回節點與評估模式下的不同,則可以使用此引數。如果指定了此引數,則必須同時指定 eval_return_nodes,並且不應指定 return_nodes

  • eval_return_nodes (list or dict, optional) – 類似於 return_nodes。如果在訓練模式下的返回節點與評估模式下的不同,則可以使用此引數。如果指定了此引數,則必須同時指定 train_return_nodes,並且不應指定 return_nodes

  • tracer_kwargs (dict, optional) – NodePathTracer 的關鍵字引數字典(NodePathTracer 會將其傳遞給其父類 torch.fx.Tracer)。預設情況下,它將設定為包裝所有 torchvision 運算元並使其成為葉節點:{“autowrap_modules”: (math, torchvision.ops,),”leaf_modules”: _get_leaf_modules_for_ops(),} 警告:如果使用者提供了 tracer_kwargs,上述預設引數將被附加到使用者提供的字典中。

  • suppress_diff_warning (bool, optional) – 當訓練和評估版本的圖之間存在差異時,是否抑制警告。預設為 False。

  • concrete_args (Optional[Dict[str, any]]) – 不應被視為 Proxies 的具體引數。根據PyTorch 文件,此引數的 API 可能無法保證。

示例

>>> # Feature extraction with resnet
>>> model = torchvision.models.resnet18()
>>> # extract layer1 and layer3, giving as names `feat1` and feat2`
>>> model = create_feature_extractor(
>>>     model, {'layer1': 'feat1', 'layer3': 'feat2'})
>>> out = model(torch.rand(1, 3, 224, 224))
>>> print([(k, v.shape) for k, v in out.items()])
>>>     [('feat1', torch.Size([1, 64, 56, 56])),
>>>      ('feat2', torch.Size([1, 256, 14, 14]))]

>>> # Specifying leaf modules and leaf functions
>>> def leaf_function(x):
>>>     # This would raise a TypeError if traced through
>>>     return int(x)
>>>
>>> class LeafModule(torch.nn.Module):
>>>     def forward(self, x):
>>>         # This would raise a TypeError if traced through
>>>         int(x.shape[0])
>>>         return torch.nn.functional.relu(x + 4)
>>>
>>> class MyModule(torch.nn.Module):
>>>     def __init__(self):
>>>         super().__init__()
>>>         self.conv = torch.nn.Conv2d(3, 1, 3)
>>>         self.leaf_module = LeafModule()
>>>
>>>     def forward(self, x):
>>>         leaf_function(x.shape[0])
>>>         x = self.conv(x)
>>>         return self.leaf_module(x)
>>>
>>> model = create_feature_extractor(
>>>     MyModule(), return_nodes=['leaf_module'],
>>>     tracer_kwargs={'leaf_modules': [LeafModule],
>>>                    'autowrap_functions': [leaf_function]})

文件

訪問 PyTorch 全面的開發者文件

檢視文件

教程

獲取面向初學者和高階開發者的深入教程

檢視教程

資源

查詢開發資源並獲得問題解答

檢視資源