torch.set_default_dtype¶
- torch.set_default_dtype(d, /)[源][源]¶
將預設浮點 dtype 設定為
d。支援浮點 dtype 作為輸入。其他 dtype 將導致 torch 引發異常。當 PyTorch 初始化時,其預設浮點 dtype 為 torch.float32,set_default_dtype(torch.float64) 的目的是為了促進 NumPy 風格的型別推斷。預設浮點 dtype 用於:
隱式確定預設複數 dtype。當預設浮點型別為 float16 時,預設複數 dtype 為 complex32。對於 float32,預設複數 dtype 為 complex64。對於 float64,它為 complex128。對於 bfloat16,將引發異常,因為 bfloat16 沒有對應的複數型別。
推斷使用 Python 浮點數或複數 Python 數字構造的張量的 dtype。參見下面的示例。
確定布林和整數張量與 Python 浮點數和複數 Python 數字之間型別提升的結果。
- 引數
d (
torch.dtype) – 作為預設的浮點 dtype。
示例
>>> # initial default for floating point is torch.float32 >>> # Python floats are interpreted as float32 >>> torch.tensor([1.2, 3]).dtype torch.float32 >>> # initial default for floating point is torch.complex64 >>> # Complex Python numbers are interpreted as complex64 >>> torch.tensor([1.2, 3j]).dtype torch.complex64
>>> torch.set_default_dtype(torch.float64) >>> # Python floats are now interpreted as float64 >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor torch.float64 >>> # Complex Python numbers are now interpreted as complex128 >>> torch.tensor([1.2, 3j]).dtype # a new complex tensor torch.complex128
>>> torch.set_default_dtype(torch.float16) >>> # Python floats are now interpreted as float16 >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor torch.float16 >>> # Complex Python numbers are now interpreted as complex128 >>> torch.tensor([1.2, 3j]).dtype # a new complex tensor torch.complex32