torch.utils.module_tracker¶
此工具可用於跟蹤當前在 torch.nn.Module 層次結構中的位置。它可與其他跟蹤工具結合使用,以便輕鬆地將測量量與使用者友好的名稱關聯起來。目前,它特別用於 FlopCounterMode 中。
- class torch.utils.module_tracker.ModuleTracker[source][source]¶
ModuleTracker是一個上下文管理器,用於在執行期間跟蹤 nn.Module 層次結構,以便其他系統可以查詢當前正在執行哪個 Module(或其反向傳播)。您可以透過訪問此上下文管理器的
parents屬性來獲取當前正在執行的所有 Module 的集合,它們透過 fqn(完全限定名,也用作 state_dict 中的鍵)標識。您可以訪問is_bw屬性來了解當前是否正在執行反向傳播。注意,
parents永遠不為空,並且始終包含“Global”鍵。is_bw標誌在正向傳播後將保持True,直到執行另一個 Module。如果您需要它更精確,請提交一個 issue 來請求此功能。新增從 fqn 到模組例項的對映是可能的,但尚未完成,如果您需要此功能,請提交一個 issue 來請求。示例用法
mod = torch.nn.Linear(2, 2) with ModuleTracker() as tracker: # Access anything during the forward pass def my_linear(m1, m2, bias): print(f"Current modules: {tracker.parents}") return torch.mm(m1, m2.t()) + bias torch.nn.functional.linear = my_linear mod(torch.rand(2, 2))