create_feature_extractor¶
- torchvision.models.feature_extraction.create_feature_extractor(model: Module, return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, tracer_kwargs: Optional[Dict[str, Any]] = None, suppress_diff_warning: bool = False, concrete_args: Optional[Dict[str, Any]] = None) GraphModule[源]¶
建立一個新的圖模組,該模組將給定模型中的中間節點作為字典返回,其中使用者指定的鍵為字串,請求的輸出作為值。這透過使用 FX 重寫模型的計算圖來實現,從而將所需的節點作為輸出返回。所有未使用的節點及其相應的引數都將被移除。
所需的輸出節點必須指定為透過
.分隔的路徑,該路徑從頂層模組向下遍歷模組層次結構,直到葉操作或葉模組。有關此處使用的節點命名約定的更多詳細資訊,請參閱文件中的相關小標題。並非所有模型都可以透過 FX 進行跟蹤,儘管經過一些調整後它們可以協同工作。以下是一些(非詳盡的)技巧列表
如果您不需要跟蹤某個特定、有問題的子模組,可以透過將
leaf_modules列表作為tracer_kwargs之一傳遞來將其轉換為“葉模組”(參見下面的示例)。它不會被跟蹤,而是生成的圖將保留對該模組 forward 方法的引用。同樣,您可以透過將
autowrap_functions列表作為tracer_kwargs之一傳遞來將函式轉換為葉函式(參見下面的示例)。一些內建的 Python 函式可能會有問題。例如,
int在跟蹤期間會引發錯誤。您可以將它們包裝在您自己的函式中,然後將其作為tracer_kwargs之一傳遞給autowrap_functions。
有關 FX 的更多資訊,請參閱torch.fx 文件。
- 引數:
model (nn.Module) – 我們將從中提取特徵的模型
return_nodes (list or dict, optional) – 一個
List或Dict,包含將返回其啟用值的節點的名稱(或部分名稱 - 見上方註釋)。如果它是一個Dict,鍵是節點名稱,值是圖模組返回字典的使用者指定鍵。如果它是一個List,它被視為一個Dict,將節點規範字符串直接對映到輸出名稱。如果指定了train_return_nodes和eval_return_nodes,則不應指定此引數。train_return_nodes (list or dict, optional) – 類似於
return_nodes。如果在訓練模式下的返回節點與評估模式下的不同,則可以使用此引數。如果指定了此引數,則必須同時指定eval_return_nodes,並且不應指定return_nodes。eval_return_nodes (list or dict, optional) – 類似於
return_nodes。如果在訓練模式下的返回節點與評估模式下的不同,則可以使用此引數。如果指定了此引數,則必須同時指定train_return_nodes,並且不應指定 return_nodes。tracer_kwargs (dict, optional) –
NodePathTracer的關鍵字引數字典(NodePathTracer會將其傳遞給其父類 torch.fx.Tracer)。預設情況下,它將設定為包裝所有 torchvision 運算元並使其成為葉節點:{“autowrap_modules”: (math, torchvision.ops,),”leaf_modules”: _get_leaf_modules_for_ops(),} 警告:如果使用者提供了 tracer_kwargs,上述預設引數將被附加到使用者提供的字典中。suppress_diff_warning (bool, optional) – 當訓練和評估版本的圖之間存在差異時,是否抑制警告。預設為 False。
concrete_args (Optional[Dict[str, any]]) – 不應被視為 Proxies 的具體引數。根據PyTorch 文件,此引數的 API 可能無法保證。
示例
>>> # Feature extraction with resnet >>> model = torchvision.models.resnet18() >>> # extract layer1 and layer3, giving as names `feat1` and feat2` >>> model = create_feature_extractor( >>> model, {'layer1': 'feat1', 'layer3': 'feat2'}) >>> out = model(torch.rand(1, 3, 224, 224)) >>> print([(k, v.shape) for k, v in out.items()]) >>> [('feat1', torch.Size([1, 64, 56, 56])), >>> ('feat2', torch.Size([1, 256, 14, 14]))] >>> # Specifying leaf modules and leaf functions >>> def leaf_function(x): >>> # This would raise a TypeError if traced through >>> return int(x) >>> >>> class LeafModule(torch.nn.Module): >>> def forward(self, x): >>> # This would raise a TypeError if traced through >>> int(x.shape[0]) >>> return torch.nn.functional.relu(x + 4) >>> >>> class MyModule(torch.nn.Module): >>> def __init__(self): >>> super().__init__() >>> self.conv = torch.nn.Conv2d(3, 1, 3) >>> self.leaf_module = LeafModule() >>> >>> def forward(self, x): >>> leaf_function(x.shape[0]) >>> x = self.conv(x) >>> return self.leaf_module(x) >>> >>> model = create_feature_extractor( >>> MyModule(), return_nodes=['leaf_module'], >>> tracer_kwargs={'leaf_modules': [LeafModule], >>> 'autowrap_functions': [leaf_function]})