torch.nn.utils.convert_conv3d_weight_memory_format¶
- torch.nn.utils.convert_conv3d_weight_memory_format(module, memory_format)[原始碼][原始碼]¶
將
nn.Conv3d.weight的memory_format轉換為指定的memory_format。此轉換遞迴應用於巢狀的nn.Module,包括module本身。請注意,它只改變 memory_format,而不改變每個維度的語義。此函式用於促使計算採用 NHWC 核,這在計算能力 >= 7.0 的 CUDA 裝置上為 fp16 資料提供了顯著的速度提升。注意
呼叫
model.to(memory_format=torch.channels_last_3d)比 utility 函式convert_conv3d_weight_memory_format更具侵略性。任何帶有 4d 權重的層都會受到model.to的影響,而這些層不一定能從轉換為指定的memory_format中受益。我們有信心的一點是 cuDNN 中卷積的 NDHWC(channels_last_3d) 轉換,因為即使在必須對輸入張量應用排列的情況下,在 NDHWC 中運行卷積也是有益的。因此,我們的策略是僅將卷積的權重轉換為 channels_last_3d。這確保了:1. 將使用快速卷積核,其好處可能超過排列的開銷(如果輸入不是相同格式)。2. 不會對不從 memory_format 轉換中受益的層應用不必要的排列。
最優情況是,卷積層之間的層與 channels last 格式相容。輸入張量在遇到第一個卷積層時會被排列為 channels last 格式並保持該記憶體格式。因此,後續的卷積不需要對其輸入張量進行排列。
在卷積層之間存在 channels last 不相容層的情況下,我們需要將該層的輸入張量重新排列回 contiguous 格式。輸入張量將以 contiguous 格式透過剩餘的層,並在遇到另一個卷積層時被排列為 channels last 格式。將該排列傳播到較早的層沒有意義,因為大多數層對
memory_format相當不敏感。當 PyTorch 支援排列融合時,這種說法可能會改變,因為可能存在比緊接在卷積之前更好的排列融合位置。
- 引數
module (nn.Module) –
nn.Conv3d和nn.ConvTranspose3d或容器nn.Modulememory_format – 使用者指定的
memory_format,例如torch.channels_last或torch.contiguous_format
- 返回值
更新了
nn.Conv3d的原始模組
示例
>>> input = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda") >>> model = nn.Sequential( >>> nn.Conv3d(8, 4, 3)).cuda().half() >>> # This is identical to: >>> # nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d) >>> model = nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d) >>> out = model(input)