torch.set_float32_matmul_precision¶
- torch.set_float32_matmul_precision(precision)[原始碼][原始碼]¶
設定 float32 矩陣乘法的內部精度。
以較低精度執行 float32 矩陣乘法可以顯著提升效能,在某些程式中,精度的損失影響微乎其微。
支援三種設定:
“highest”,float32 矩陣乘法使用 float32 資料型別(尾數位 24 位,其中 23 位顯式儲存)進行內部計算。
“high”,如果可用的快速矩陣乘法演算法支援,float32 矩陣乘法要麼使用 TensorFloat32 資料型別(尾數位 10 位顯式儲存),要麼將每個 float32 數字視為兩個 bfloat16 數字之和(尾數位約 16 位,其中 14 位顯式儲存)。否則,float32 矩陣乘法將按照“highest”精度進行計算。有關 bfloat16 方法的更多資訊,請參見下文。
“medium”,如果內部使用 bfloat16 資料型別的快速矩陣乘法演算法可用,float32 矩陣乘法將使用 bfloat16 資料型別(尾數位 8 位,其中 7 位顯式儲存)進行內部計算。否則,float32 矩陣乘法將按照“high”精度進行計算。
使用“high”精度時,float32 乘法可能會使用基於 bfloat16 的演算法,該演算法比簡單地截斷到較少尾數位(例如 TensorFloat32 的 10 位,bfloat16 顯式儲存的 7 位)更復雜。有關此演算法的完整描述,請參閱 [Henry2019]。在此簡要解釋一下,第一步是意識到我們可以將單個 float32 數字完美地編碼為三個 bfloat16 數字之和(因為 float32 有 23 個尾數位,而 bfloat16 有 7 個顯式儲存位,並且兩者具有相同的指數位數)。這意味著兩個 float32 數字的乘積可以精確地表示為九個 bfloat16 數字乘積之和。然後,我們可以透過丟棄其中一些乘積來權衡精度和速度。“high”精度演算法特別只保留了三個最重要的乘積,這方便地排除了涉及任一輸入最後 8 個尾數位的所有乘積。這意味著我們可以將輸入表示為兩個 bfloat16 數字之和,而不是三個。由於 bfloat16 乘加融合 (FMA) 指令通常比 float32 指令快 10 倍以上,因此使用 bfloat16 精度進行三次乘法和兩次加法比使用 float32 精度進行一次乘法更快。
注意
這不會改變 float32 矩陣乘法的輸出資料型別(dtype),它控制的是矩陣乘法的內部計算方式。
注意
這不會改變卷積操作的精度。其他標誌,例如 torch.backends.cudnn.allow_tf32,可能會控制卷積操作的精度。
注意
當前,此標誌僅影響一種原生裝置型別:CUDA。如果設定為“high”或“medium”,則在計算 float32 矩陣乘法時將使用 TensorFloat32 資料型別,這等同於設定 torch.backends.cuda.matmul.allow_tf32 = True。當設定為“highest”(預設值)時,內部計算使用 float32 資料型別,這等同於設定 torch.backends.cuda.matmul.allow_tf32 = False。
- 引數
precision (str) – 可以設定為“highest”(預設)、“high”或“medium”(參見上文)。