safe_int_mm¶
- torchao.quantization.safe_int_mm(input: Tensor, mat2: Tensor) Tensor [源代码]¶
执行安全的整数矩阵乘法,考虑了 torch.compile、cublas 和回退情况下的不同路径。
- 参数:
input (torch.Tensor) – 输入张量,形状为 [i, j]。
mat2 (torch.Tensor) – 用于乘法的矩阵,形状为 [j, k]。
- 返回值:
矩阵乘法的结果。
- 返回类型:
- 抛出:
AssertionError – 如果张量不在同一设备上。