torch.cuda.jiterator._create_multi_output_jit_fn¶
- torch.cuda.jiterator._create_multi_output_jit_fn(code_string, num_outputs, **kwargs)[source][source]¶
为元素级操作创建一个 jiterator 生成的 CUDA 内核,该内核支持返回一个或多个输出。
- 参数
- 返回类型
示例
code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }" jitted_fn = create_jit_fn(code_string, alpha=1.0) a = torch.rand(3, device='cuda') b = torch.rand(3, device='cuda') # invoke jitted function like a regular python function result = jitted_fn(a, b, alpha=3.14)
警告
此 API 处于 beta 阶段,未来版本可能会更改。
警告
此 API 仅支持最多 8 个输入和 8 个输出