torch.testing¶
- torch.testing.assert_close(actual, expected, *, allow_subclasses=True, rtol=None, atol=None, equal_nan=False, check_device=True, check_dtype=True, check_layout=True, check_stride=False, msg=None)[原始碼]¶
- 斷言 - actual和- expected相近。- 如果 - actual和- expected是跨步的、非量化的、實數值的和有限的,則如果滿足以下條件,則認為它們相近:- 非有限值( - -inf和- inf)僅在相等時才視為相近。只有在- equal_nan為- True時,- NaN才被視為彼此相等。- 此外,只有當它們具有相同的: - 裝置(如果- check_device為- True),
- dtype(如果- check_dtype為- True),
- layout(如果- check_layout為- True),以及
- stride(如果 - check_stride為- True)時,才視為相近。
 - 如果 - actual或- expected是中繼張量,則只會執行屬性檢查。- 如果 - actual和- expected是稀疏的(具有 COO、CSR、CSC、BSR 或 BSC 佈局),則會分別檢查其跨步成員。索引,即 COO 的- indices、CSR 和 BSR 的- crow_indices和- col_indices,或 CSC 和 BSC 佈局的- ccol_indices和- row_indices,始終會檢查是否相等,而值則會根據上述定義檢查是否相近。- 如果 - actual和- expected已被量化,則當它們具有相同的- qscheme()並且- dequantize()的結果根據上述定義接近時,它們被視為接近。- actual和- expected可以是- Tensor或任何可以使用- torch.as_tensor()從中構建- torch.Tensor的類張量或類標量。除了 Python 標量之外,輸入類型必須直接相關。此外,- actual和- expected可以是- Sequence或- Mapping,在這種情況下,如果它們的結構匹配並且根據上述定義,它們的所有元素都被視為接近,則它們被視為接近。- 備註 - Python 標量是類型關係要求的例外,因為它們的 - type(),即- int、- float和- complex,等同於類張量的- dtype。因此,可以檢查不同類型的 Python 標量,但需要- check_dtype=False。- 參數
- actual (任何) – 實際輸入。 
- expected (任何) – 預期輸入。 
- allow_subclasses (布林值) – 如果為 - True(預設值),並且除了 Python 標量之外,允許直接相關類型的輸入。否則,需要類型相等。
- rtol (可選[浮點數]) – 相對容忍度。如果指定,則還必須指定 - atol。如果省略,則使用下表根據- dtype選擇預設值。
- atol (可選[浮點數]) – 絕對容忍度。如果指定,則還必須指定 - rtol。如果省略,則使用下表根據- dtype選擇預設值。
- check_device (布林值) – 如果為 - True(預設值),則斷言相應的張量在同一個- 裝置上。如果禁用此檢查,則不同- 裝置上的張量將在比較之前移至 CPU。
- check_dtype (布林值) – 如果為 - True(預設值),則斷言相應的張量具有相同的- dtype。如果禁用此檢查,則具有不同- dtype的張量將在比較之前提升為通用- dtype(根據- torch.promote_types())。
- check_layout (布林值) – 如果為 - True(預設值),則斷言相應的張量具有相同的- 佈局。如果禁用此檢查,則具有不同- 佈局的張量將在比較之前轉換為跨步張量。
- check_stride (布林值) – 如果為 - True且相應的張量是跨步的,則斷言它們具有相同的跨度。
- msg (可選[聯集[字串, 可呼叫[[字串], 字串]]]) – 在比較期間發生錯誤時要使用的可選錯誤消息。也可以作為可呼叫對象傳遞,在這種情況下,它將使用生成的消息進行呼叫,並且應返回新的消息。 
 
- 引發
- ValueError – 如果無法從輸入構建 - torch.Tensor。
- ValueError – 如果僅指定 - rtol或- atol。
- AssertionError – 如果相應的輸入不是 Python 標量並且沒有直接相關。 
- AssertionError – 如果 - allow_subclasses為- False,但相應的輸入不是 Python 標量並且具有不同的類型。
- AssertionError – 如果輸入是 - Sequence,但它們的長度不匹配。
- AssertionError – 如果輸入是 - Mapping,但它們的鍵集不匹配。
- AssertionError – 如果相應的張量沒有相同的 - 形狀。
- AssertionError – 如果 - check_layout為- True,但相應的張量沒有相同的- 佈局。
- AssertionError – 如果相應的張量中只有一個被量化。 
- AssertionError – 如果相應的張量已被量化,但具有不同的 - qscheme()。
- AssertionError – 如果 - check_device為- True,但相應的張量不在同一個- 裝置上。
- AssertionError – 如果 - check_dtype為- True,但相應的張量沒有相同的- dtype。
- AssertionError – 如果 - check_stride為- True,但相應的跨步張量沒有相同的跨度。
- AssertionError – 如果相應張量的值根據上述定義不接近。 
 
 - 下表顯示了不同 - dtype的預設- rtol和- atol。如果- dtype不匹配,則使用兩個容忍度中的最大值。- dtype- rtol- atol- float16- 1e-3- 1e-5- bfloat16- 1.6e-2- 1e-5- float32- 1.3e-6- 1e-5- float64- 1e-7- 1e-7- complex32- 1e-3- 1e-5- complex64- 1.3e-6- 1e-5- complex128- 1e-7- 1e-7- quint8- 1.3e-6- 1e-5- quint2x4- 1.3e-6- 1e-5- quint4x2- 1.3e-6- 1e-5- qint8- 1.3e-6- 1e-5- qint32- 1.3e-6- 1e-5- 其他 - 0.0- 0.0- 備註 - assert_close()具備高度可配置性和嚴格的默認設定。建議使用者使用- partial()來調整以符合其使用案例。例如,如果需要進行相等性檢查,可以定義一個- assert_equal,它默認對每個- dtype使用零容忍度- >>> import functools >>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0) >>> assert_equal(1e-9, 1e-10) Traceback (most recent call last): ... AssertionError: Scalars are not equal! Expected 1e-10 but got 1e-09. Absolute difference: 9.000000000000001e-10 Relative difference: 9.0 - 範例 - >>> # tensor to tensor comparison >>> expected = torch.tensor([1e0, 1e-1, 1e-2]) >>> actual = torch.acos(torch.cos(expected)) >>> torch.testing.assert_close(actual, expected) - >>> # scalar to scalar comparison >>> import math >>> expected = math.sqrt(2.0) >>> actual = 2.0 / math.sqrt(2.0) >>> torch.testing.assert_close(actual, expected) - >>> # numpy array to numpy array comparison >>> import numpy as np >>> expected = np.array([1e0, 1e-1, 1e-2]) >>> actual = np.arccos(np.cos(expected)) >>> torch.testing.assert_close(actual, expected) - >>> # sequence to sequence comparison >>> import numpy as np >>> # The types of the sequences do not have to match. They only have to have the same >>> # length and their elements have to match. >>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)] >>> actual = tuple(expected) >>> torch.testing.assert_close(actual, expected) - >>> # mapping to mapping comparison >>> from collections import OrderedDict >>> import numpy as np >>> foo = torch.tensor(1.0) >>> bar = 2.0 >>> baz = np.array(3.0) >>> # The types and a possible ordering of mappings do not have to match. They only >>> # have to have the same set of keys and their elements have to match. >>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)]) >>> actual = {"baz": baz, "bar": bar, "foo": foo} >>> torch.testing.assert_close(actual, expected) - >>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = expected.clone() >>> # By default, directly related instances can be compared >>> torch.testing.assert_close(torch.nn.Parameter(actual), expected) >>> # This check can be made more strict with allow_subclasses=False >>> torch.testing.assert_close( ... torch.nn.Parameter(actual), expected, allow_subclasses=False ... ) Traceback (most recent call last): ... TypeError: No comparison pair was able to handle inputs of type <class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'>. >>> # If the inputs are not directly related, they are never considered close >>> torch.testing.assert_close(actual.numpy(), expected) Traceback (most recent call last): ... TypeError: No comparison pair was able to handle inputs of type <class 'numpy.ndarray'> and <class 'torch.Tensor'>. >>> # Exceptions to these rules are Python scalars. They can be checked regardless of >>> # their type if check_dtype=False. >>> torch.testing.assert_close(1.0, 1, check_dtype=False) - >>> # NaN != NaN by default. >>> expected = torch.tensor(float("Nan")) >>> actual = expected.clone() >>> torch.testing.assert_close(actual, expected) Traceback (most recent call last): ... AssertionError: Scalars are not close! Expected nan but got nan. Absolute difference: nan (up to 1e-05 allowed) Relative difference: nan (up to 1.3e-06 allowed) >>> torch.testing.assert_close(actual, expected, equal_nan=True) - >>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = torch.tensor([1.0, 4.0, 5.0]) >>> # The default error message can be overwritten. >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!") Traceback (most recent call last): ... AssertionError: Argh, the tensors are not close! >>> # If msg is a callable, it can be used to augment the generated message with >>> # extra information >>> torch.testing.assert_close( ... actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter" ... ) Traceback (most recent call last): ... AssertionError: Header Tensor-likes are not close! Mismatched elements: 2 / 3 (66.7%) Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed) Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed) Footer 
- torch.testing.make_tensor(*shape, dtype, device, low=None, high=None, requires_grad=False, noncontiguous=False, exclude_zero=False, memory_format=None)[source]¶
- 建立具有給定 - shape、- device和- dtype的張量,並填充從- [low, high)均勻繪製的值。- 如果指定了 - low或- high並且它們超出了- dtype可表示的有限值範圍,則它們將分別被限制為最低或最高可表示的有限值。如果為- None,則下表描述了- low和- high的默認值,它們取決於- dtype。- dtype- 低- 高- 布林類型 - 0- 2- 無符號整數類型 - 0- 10- 有符號整數類型 - -9- 10- 浮點數類型 - -9- 9- 複數類型 - -9- 9- 參數
- shape (Tuple[int, ...]) – 單個整數或定義輸出張量形狀的整數序列。 
- dtype ( - torch.dtype) – 返回張量的資料類型。
- device (Union[str, torch.device]) – 返回張量的設備。 
- low (Optional[Number]) – 設定給定範圍的下限(含)。如果提供了一個數字,它將被限制為給定 dtype 的最小可表示有限值。當為 - None(默認)時,此值將根據- dtype確定(請參閱上表)。默認值:- None。
- high (Optional[Number]) – - 設定給定範圍的上限(不含)。如果提供了一個數字,它將被限制為給定 dtype 的最大可表示有限值。當為 - None(默認)時,此值將根據- dtype確定(請參閱上表)。默認值:- None。- 自版本 2.1 起已棄用: 自 2.1 版本起,將 - low==high傳遞給- make_tensor()(用於浮點數或複數類型)已被棄用,並將在 2.3 版本中移除。請改用- torch.full()。
- requires_grad (Optional[bool]) – 自動求導是否應記錄在返回張量上的操作。默認值: - False。
- noncontiguous (Optional[bool]) – 如果為 True,則返回的張量將是非連續的。如果構造的張量少於兩個元素,則忽略此參數。與 - memory_format互斥。
- exclude_zero (Optional[bool]) – 如果為 - True,則零將根據- dtype替換為 dtype 的小正值。對於布林值和整數類型,零將替換為一。對於浮點數類型,它將替換為 dtype 的最小正正規數(- dtype的- finfo()物件的「微小」值),對於複數類型,它將替換為一個複數,其實部和虛部都是複數類型可表示的最小正正規數。默認值為- False。
- memory_format (Optional[torch.memory_format]) – 返回張量的記憶體格式。與 - noncontiguous互斥。
 
- 引發
- ValueError – 如果針對整數 dtype 傳遞了 - requires_grad=True
- ValueError – 如果 - low >= high。
- ValueError – 如果 - low或- high為- nan。
- ValueError – 如果同時傳遞了 - noncontiguous和- memory_format。
- TypeError – 如果此函數不支援 - dtype。
 
- 傳回類型
 - 範例 - >>> from torch.testing import make_tensor >>> # Creates a float tensor with values in [-1, 1) >>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1) tensor([ 0.1205, 0.2282, -0.6380]) >>> # Creates a bool tensor on CUDA >>> make_tensor((2, 2), device='cuda', dtype=torch.bool) tensor([[False, False], [False, True]], device='cuda:0') 
- torch.testing.assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg='')[source]¶
- 警告 - torch.testing.assert_allclose()自- 1.12版本起已棄用,並將在未來版本中移除。請改用- torch.testing.assert_close()。您可以在 這裡 找到詳細的升級說明。