torch.utils.module_tracker¶
此工具可用於追蹤 torch.nn.Module 階層內部的目前位置。它可以在其他追蹤工具中使用,以便輕鬆地將測量量與使用者友善的名稱關聯起來。這在今天的 FlopCounterMode 中特別有用。
- class torch.utils.module_tracker.ModuleTracker[原始碼]¶
- ModuleTracker是一個上下文管理器,可在執行期間追蹤 nn.Module 階層,以便其他系統可以查詢目前正在執行哪個模組(或其反向正在執行)。- 您可以存取此上下文管理器上的 - parents屬性,以透過其 fqn(完整限定名稱,也用作 state_dict 中的鍵)取得目前正在執行的所有模組的集合。您可以存取- is_bw屬性,以瞭解目前是否在反向執行中。- 請注意, - parents永遠不會是空的,並且始終包含「Global」鍵。在正向執行之後,直到執行另一個模組之前,- is_bw旗標將保持- True。如果您需要它更準確,請提交一個問題來請求此功能。新增從 fqn 到模組實例的映射是可能的,但尚未完成,如果您需要,請提交一個問題來請求此功能。- 使用範例 - mod = torch.nn.Linear(2, 2) with ModuleTracker() as tracker: # Access anything during the forward pass def my_linear(m1, m2, bias): print(f"Current modules: {tracker.parents}") return torch.mm(m1, m2.t()) + bias torch.nn.functional.linear = my_linear mod(torch.rand(2, 2))