torch.testing¶
- torch.testing.assert_close(actual, expected, *, allow_subclasses=True, rtol=None, atol=None, equal_nan=False, check_device=True, check_dtype=True, check_layout=True, check_stride=False, msg=None)[source]¶
- 断言 - actual和- expected接近。- 如果 - actual和- expected是带步长的、非量化的、实值且有限的,则当它们满足以下条件时,它们被认为是接近的:- 非有限值( - -inf和- inf)仅当它们相等时才被视为接近。- NaN仅当- equal_nan为- True时才被视为彼此相等。- 此外,它们只有在具有相同的 - device(如果- check_device为- True),
- dtype(如果- check_dtype为- True),
- layout(如果- check_layout为- True),以及
- 步幅(如果 - check_stride为- True)时才被视为接近。
 - 如果 - actual或- expected是元张量,则只执行属性检查。- 如果 - actual和- expected是稀疏矩阵(具有 COO、CSR、CSC、BSR 或 BSC 布局),则分别检查它们的步长成员。索引,即 COO 的- indices,CSR 和 BSR 的- crow_indices和- col_indices,或 CSC 和 BSC 布局的- ccol_indices和- row_indices,始终检查是否相等,而值则根据上述定义检查是否接近。- 如果 - actual和- expected是量化的,则如果它们具有相同的- qscheme()并且- dequantize()的结果根据上述定义接近,则它们被认为是接近的。- actual和- expected可以是- Tensor或任何可以使用- torch.Tensor构造的张量或标量类型,可以使用- torch.as_tensor()。除了 Python 标量之外,输入类型必须直接相关。此外,- actual和- expected可以是- Sequence或- Mapping,在这种情况下,如果它们的结构匹配并且所有元素根据上述定义被认为是接近的,则它们被认为是接近的。- 注意 - Python 标量是类型关系要求的例外,因为它们的 - type(),即- int、- float和- complex,等效于张量类型的- dtype。因此,可以检查不同类型的 Python 标量,但需要- check_dtype=False。- 参数
- actual (Any) – 实际输入。 
- expected (Any) – 预期输入。 
- allow_subclasses (bool) – 如果为 - True(默认)并且除了 Python 标量之外,允许直接相关类型的输入。否则需要类型相等。
- rtol (Optional[float]) – 相对容差。如果指定了 - atol,则也必须指定- atol。如果省略,则根据- dtype选择以下表格中的默认值。
- atol (可选[float]) – 绝对容差。如果指定了 - rtol,则也必须指定- atol。如果省略,则根据- dtype选择以下表格中的默认值。
- check_device (bool) – 如果为 - True(默认),则断言相应的张量位于相同的- device上。如果禁用此检查,则位于不同- device上的张量将在比较之前被移动到 CPU 上。
- check_dtype (bool) – 如果为 - True(默认),则断言相应的张量具有相同的- dtype。如果禁用此检查,则具有不同- dtype的张量将在比较之前被提升为公共- dtype(根据- torch.promote_types())。
- check_layout (bool) – 如果为 - True(默认),则断言相应的张量具有相同的- layout。如果禁用此检查,则具有不同- layout的张量将在比较之前被转换为带步长的张量。
- check_stride (bool) – 如果为 - True且相应的张量为带步长的,则断言它们具有相同的步长。
- msg (可选[Union[str, Callable[[str], str]]]) – 比较过程中发生错误时使用的可选错误消息。也可以作为可调用对象传递,在这种情况下,它将使用生成的 messages 进行调用,并应返回新的 messages。 
 
- 引发异常
- ValueError – 如果无法从输入中构建 - torch.Tensor。
- ValueError – 如果只指定了 - rtol或- atol。
- AssertionError – 如果对应的输入不是 Python 标量,并且没有直接关系。 
- AssertionError – 如果 - allow_subclasses为- False,但对应的输入不是 Python 标量,并且类型不同。
- AssertionError – 如果输入是 - Sequence,但它们的长度不匹配。
- AssertionError – 如果输入是 - Mapping,但它们的键集不匹配。
- AssertionError – 如果对应的张量没有相同的 - shape。
- AssertionError – 如果 - check_layout为- True,但相应的张量没有相同的- layout。
- AssertionError – 如果相应的张量中只有一个被量化。 
- AssertionError – 如果相应的张量被量化,但具有不同的 - qscheme()。
- AssertionError – 如果 - check_device为- True,但相应的张量不在同一个- device上。
- AssertionError – 如果 - check_dtype为- True,但相应的张量没有相同的- dtype。
- AssertionError – 如果 - check_stride为- True,但相应的带步长张量没有相同的步长。
- AssertionError – 如果相应的张量的值根据上述定义不接近。 
 
 - 下表显示了不同 - dtype的默认- rtol和- atol。如果- dtype不匹配,则使用两个容差中的最大值。- dtype- rtol- atol- float16- 1e-3- 1e-5- bfloat16- 1.6e-2- 1e-5- float32- 1.3e-6- 1e-5- float64- 1e-7- 1e-7- complex32- 1e-3- 1e-5- complex64- 1.3e-6- 1e-5- complex128- 1e-7- 1e-7- quint8- 1.3e-6- 1e-5- quint2x4- 1.3e-6- 1e-5- quint4x2- 1.3e-6- 1e-5- qint8- 1.3e-6- 1e-5- qint32- 1.3e-6- 1e-5- 其他 - 0.0- 0.0- 注意 - assert_close()具有高度可配置的严格默认设置。鼓励用户使用- partial()来适应他们的用例。例如,如果需要进行相等性检查,则可以定义一个- assert_equal,它默认情况下对每个- dtype使用零容差。- >>> import functools >>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0) >>> assert_equal(1e-9, 1e-10) Traceback (most recent call last): ... AssertionError: Scalars are not equal! Expected 1e-10 but got 1e-09. Absolute difference: 9.000000000000001e-10 Relative difference: 9.0 - 示例 - >>> # tensor to tensor comparison >>> expected = torch.tensor([1e0, 1e-1, 1e-2]) >>> actual = torch.acos(torch.cos(expected)) >>> torch.testing.assert_close(actual, expected) - >>> # scalar to scalar comparison >>> import math >>> expected = math.sqrt(2.0) >>> actual = 2.0 / math.sqrt(2.0) >>> torch.testing.assert_close(actual, expected) - >>> # numpy array to numpy array comparison >>> import numpy as np >>> expected = np.array([1e0, 1e-1, 1e-2]) >>> actual = np.arccos(np.cos(expected)) >>> torch.testing.assert_close(actual, expected) - >>> # sequence to sequence comparison >>> import numpy as np >>> # The types of the sequences do not have to match. They only have to have the same >>> # length and their elements have to match. >>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)] >>> actual = tuple(expected) >>> torch.testing.assert_close(actual, expected) - >>> # mapping to mapping comparison >>> from collections import OrderedDict >>> import numpy as np >>> foo = torch.tensor(1.0) >>> bar = 2.0 >>> baz = np.array(3.0) >>> # The types and a possible ordering of mappings do not have to match. They only >>> # have to have the same set of keys and their elements have to match. >>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)]) >>> actual = {"baz": baz, "bar": bar, "foo": foo} >>> torch.testing.assert_close(actual, expected) - >>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = expected.clone() >>> # By default, directly related instances can be compared >>> torch.testing.assert_close(torch.nn.Parameter(actual), expected) >>> # This check can be made more strict with allow_subclasses=False >>> torch.testing.assert_close( ... torch.nn.Parameter(actual), expected, allow_subclasses=False ... ) Traceback (most recent call last): ... TypeError: No comparison pair was able to handle inputs of type <class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'>. >>> # If the inputs are not directly related, they are never considered close >>> torch.testing.assert_close(actual.numpy(), expected) Traceback (most recent call last): ... TypeError: No comparison pair was able to handle inputs of type <class 'numpy.ndarray'> and <class 'torch.Tensor'>. >>> # Exceptions to these rules are Python scalars. They can be checked regardless of >>> # their type if check_dtype=False. >>> torch.testing.assert_close(1.0, 1, check_dtype=False) - >>> # NaN != NaN by default. >>> expected = torch.tensor(float("Nan")) >>> actual = expected.clone() >>> torch.testing.assert_close(actual, expected) Traceback (most recent call last): ... AssertionError: Scalars are not close! Expected nan but got nan. Absolute difference: nan (up to 1e-05 allowed) Relative difference: nan (up to 1.3e-06 allowed) >>> torch.testing.assert_close(actual, expected, equal_nan=True) - >>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = torch.tensor([1.0, 4.0, 5.0]) >>> # The default error message can be overwritten. >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!") Traceback (most recent call last): ... AssertionError: Argh, the tensors are not close! >>> # If msg is a callable, it can be used to augment the generated message with >>> # extra information >>> torch.testing.assert_close( ... actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter" ... ) Traceback (most recent call last): ... AssertionError: Header Tensor-likes are not close! Mismatched elements: 2 / 3 (66.7%) Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed) Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed) Footer 
- torch.testing.make_tensor(*shape, dtype, device, low=None, high=None, requires_grad=False, noncontiguous=False, exclude_zero=False, memory_format=None)[source]¶
- 创建一个具有给定 - shape、- device和- dtype的张量,并用从- [low, high)均匀抽取的值填充。- 如果指定了 - low或- high并且它们超出了- dtype可表示的有限值范围,则它们分别被钳制到可表示的最低或最高有限值。如果为- None,则下表描述了- low和- high的默认值,它们取决于- dtype。- dtype- low- high- 布尔类型 - 0- 2- 无符号整数类型 - 0- 10- 有符号整数类型 - -9- 10- 浮点类型 - -9- 9- 复数类型 - -9- 9- 参数
- shape (Tuple[int, ...]) – 定义输出张量形状的单个整数或整数序列。 
- dtype ( - torch.dtype) – 返回张量的数据类型。
- device (Union[str, torch.device]) – 返回张量的设备。 
- low (Optional[Number]) – 设置给定范围的下限(包含)。如果提供数字,则将其钳制到给定 dtype 的最小可表示有限值。当为 - None(默认)时,此值根据- dtype确定(参见上表)。默认值:- None。
- high (Optional[Number]) – - 设置给定范围的上限(不包含)。如果提供数字,则将其钳制到给定 dtype 的最大可表示有限值。当为 - None(默认)时,此值根据- dtype确定(参见上表)。默认值:- None。- 自版本 2.1 起已弃用: 自 2.1 起,将 - low==high传递给- make_tensor()用于浮点或复数类型已弃用,将在 2.3 中删除。请改用- torch.full()。
- requires_grad (Optional[bool]) – 是否应记录对返回张量的操作。默认值: - False。
- 非连续 (可选[布尔值]) – 如果为 True,则返回的张量将是非连续的。如果构造的张量元素少于两个,则忽略此参数。与 - memory_format互斥。
- 排除零 (可选[布尔值]) – 如果为 - True,则将零替换为根据- dtype的数据类型的小正值。对于布尔值和整数类型,零将被替换为 1。对于浮点类型,它将被替换为数据类型的最小正规数(- dtype的- finfo()对象的“微小”值),对于复数类型,它将被替换为实部和虚部都是复数类型可表示的最小正规数的复数。默认值为- False。
- 内存格式 (可选[torch.memory_format]) – 返回的张量的内存格式。与 - noncontiguous互斥。
 
- 引发异常
- ValueError – 如果为整数 dtype 传递了 - requires_grad=True
- ValueError – 如果 - low >= high。
- ValueError – 如果 - low或- high为- nan。
- ValueError – 如果同时传递了 - noncontiguous和- memory_format。
- TypeError – 如果 - dtype不受此函数支持。
 
- 返回类型
 - 示例 - >>> from torch.testing import make_tensor >>> # Creates a float tensor with values in [-1, 1) >>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1) tensor([ 0.1205, 0.2282, -0.6380]) >>> # Creates a bool tensor on CUDA >>> make_tensor((2, 2), device='cuda', dtype=torch.bool) tensor([[False, False], [False, True]], device='cuda:0') 
- torch.testing.assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg='')[source]¶
- 警告 - torch.testing.assert_allclose()自- 1.12版本起已弃用,将在未来版本中移除。请使用- torch.testing.assert_close()代替。您可以在 此处 找到详细的升级说明。