safe_int_mm¶
- torchao.quantization.safe_int_mm(input: Tensor, mat2: Tensor) Tensor [source]¶
执行安全的整数矩阵乘法,考虑 torch.compile、cublas 和回退情况的不同路径。
- 参数:
input (torch.Tensor) – 形状为 [i, j] 的输入张量。
mat2 (torch.Tensor) – 要与之相乘的矩阵,形状为 [j, k]。
- 返回:
矩阵乘法的结果。
- 返回类型:
- Raises:
AssertionError – 如果张量不在同一设备上。