MLX
Loading...
Searching...
No Matches
kernels.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#include <fmt/format.h>
4
5#include "mlx/array.h"
7
8namespace mlx::core {
9
10MTL::ComputePipelineState* get_arange_kernel(
12 const std::string& kernel_name,
13 const array& out);
14
15MTL::ComputePipelineState* get_unary_kernel(
17 const std::string& kernel_name,
18 Dtype out_type,
19 const std::string op);
20
21MTL::ComputePipelineState* get_binary_kernel(
23 const std::string& kernel_name,
24 Dtype in_type,
25 Dtype out_type,
26 const std::string op);
27
28MTL::ComputePipelineState* get_binary_two_kernel(
30 const std::string& kernel_name,
31 Dtype in_type,
32 Dtype out_type,
33 const std::string op);
34
35MTL::ComputePipelineState* get_ternary_kernel(
37 const std::string& kernel_name,
38 Dtype type,
39 const std::string op);
40
41MTL::ComputePipelineState* get_copy_kernel(
43 const std::string& kernel_name,
44 const array& in,
45 const array& out);
46
47MTL::ComputePipelineState* get_softmax_kernel(
49 const std::string& kernel_name,
50 bool precise,
51 const array& out);
52
53MTL::ComputePipelineState* get_scan_kernel(
55 const std::string& kernel_name,
56 bool reverse,
57 bool inclusive,
58 const std::string& reduce_type,
59 const array& in,
60 const array& out);
61
62MTL::ComputePipelineState* get_sort_kernel(
64 const std::string& kernel_name,
65 const array& in,
66 const array& out,
67 int bn,
68 int tn);
69
70MTL::ComputePipelineState* get_mb_sort_kernel(
72 const std::string& kernel_name,
73 const array& in,
74 const array& idx,
75 int bn,
76 int tn);
77
78MTL::ComputePipelineState* get_reduce_init_kernel(
80 const std::string& kernel_name,
81 const array& out);
82
83MTL::ComputePipelineState* get_reduce_kernel(
85 const std::string& kernel_name,
86 const std::string& func_name,
87 const std::string& op_name,
88 const array& in,
89 const array& out,
90 int ndim = -1,
91 int bm = -1,
92 int bn = -1);
93
94MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
96 const std::string& kernel_name,
97 const std::string& hash_name,
98 const metal::MTLFCList& func_consts,
99 const array& out,
100 bool transpose_a,
101 bool transpose_b,
102 int bm,
103 int bn,
104 int bk,
105 int wm,
106 int wn);
107
// Declaration only; the definition is not part of this header.
// Returns a MTL::ComputePipelineState*. Takes the Metal device, the kernel
// name, the input and output arrays, two transposition flags, five integer
// parameters (bm, bn, bk, wm, wn) and two alignment flags.
// NOTE(review): bm/bn/bk and wm/wn look like tile and per-group dimensions,
// and mn_aligned/k_aligned like "problem size is a multiple of the tile"
// flags -- confirm against the definition.
108MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
109 metal::Device& d,
110 const std::string& kernel_name,
111 const array& in,
112 const array& out,
113 bool transpose_a,
114 bool transpose_b,
115 int bm,
116 int bn,
117 int bk,
118 int wm,
119 int wn,
120 bool mn_aligned,
121 bool k_aligned);
122
// Declaration only; the definition is not part of this header.
// Returns a MTL::ComputePipelineState*. Takes the Metal device, the kernel
// name, the input and output arrays, and the `axbpy` flag.
// NOTE(review): `axbpy` presumably selects an "a*x + b*y" accumulation
// variant -- confirm against the definition.
123MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
124 metal::Device& d,
125 const std::string& kernel_name,
126 const array& in,
127 const array& out,
128 bool axbpy);
129
// Declaration only; the definition is not part of this header.
// Returns a MTL::ComputePipelineState*. Takes the Metal device, the kernel
// name, the output array, two optional mask arrays (either may be empty),
// two transposition flags, five integer parameters (bm, bn, bk, wm, wn) and
// two alignment flags.
// NOTE(review): mask_out/mask_op presumably mask the output and the operands
// respectively -- confirm against the definition.
130MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
131 metal::Device& d,
132 const std::string& kernel_name,
133 const array& out,
134 const std::optional<array>& mask_out,
135 const std::optional<array>& mask_op,
136 bool transpose_a,
137 bool transpose_b,
138 int bm,
139 int bn,
140 int bk,
141 int wm,
142 int wn,
143 bool mn_aligned,
144 bool k_aligned);
145
// Declaration only; the definition is not part of this header.
// Returns a MTL::ComputePipelineState*. Takes the Metal device, the kernel
// name, the output array, five integer parameters (bm, bn, bk, wm, wn), an
// integer channel-specialization value and a small-filter flag.
// NOTE(review): the meaning of n_channel_specialization and small_filter is
// not visible here -- confirm against the definition.
146MTL::ComputePipelineState* get_steel_conv_kernel(
147 metal::Device& d,
148 const std::string& kernel_name,
149 const array& out,
150 int bm,
151 int bn,
152 int bk,
153 int wm,
154 int wn,
155 int n_channel_specialization,
156 bool small_filter);
157
// Declaration only; the definition is not part of this header.
// Returns a MTL::ComputePipelineState*. Takes the Metal device, the kernel
// name, the output array, two optional mask arrays (either may be empty), a
// matrix-transposition flag, six integer parameters (bm, bn, sm, sn, tm, tn)
// and a contiguity flag.
// NOTE(review): the b*/s*/t* pairs look like block, simdgroup and thread
// tile sizes -- confirm against the definition.
158MTL::ComputePipelineState* get_gemv_masked_kernel(
159 metal::Device& d,
160 const std::string& kernel_name,
161 const array& out,
162 const std::optional<array>& mask_out,
163 const std::optional<array>& mask_op,
164 bool transpose_mat,
165 int bm,
166 int bn,
167 int sm,
168 int sn,
169 int tm,
170 int tn,
171 bool contiguous);
172
// Declaration only; the definition is not part of this header.
// Returns a MTL::ComputePipelineState*. Takes the Metal device, the kernel
// name, the output array and five integer parameters (bm, bn, bk, wm, wn).
// NOTE(review): bm/bn/bk and wm/wn look like tile and per-group dimensions
// -- confirm against the definition.
173MTL::ComputePipelineState* get_steel_conv_general_kernel(
174 metal::Device& d,
175 const std::string& kernel_name,
176 const array& out,
177 int bm,
178 int bn,
179 int bk,
180 int wm,
181 int wn);
182
// Declaration only; the definition is not part of this header.
// Returns a MTL::ComputePipelineState*. Takes the Metal device, the kernel
// name, a second (hash) name, a list of function constants and a template
// definition string.
// NOTE(review): `template_def` is presumably text produced by
// get_template_definition below -- confirm at the call sites.
183MTL::ComputePipelineState* get_fft_kernel(
184 metal::Device& d,
185 const std::string& kernel_name,
186 const std::string& hash_name,
187 const metal::MTLFCList& func_consts,
188 const std::string& template_def);
189
// Declaration only; the definition is not part of this header.
// Returns a MTL::ComputePipelineState*. Takes the Metal device, the kernel
// name and a template definition string.
// NOTE(review): `template_def` is presumably text produced by
// get_template_definition below -- confirm at the call sites.
190MTL::ComputePipelineState* get_quantized_kernel(
191 metal::Device& d,
192 const std::string& kernel_name,
193 const std::string& template_def);
194
// Create a GPU kernel template definition for JIT compilation.
//
// Builds the text of one explicit template instantiation:
//
//   template [[host_name("<name>")]] [[kernel]] decltype(<inst>) <inst>;
//
// where <inst> is `func<arg0, arg1, ...>`. The returned string begins and
// ends with a newline.
//
// name  Host-visible kernel name; inserted verbatim between the quotes.
// func  Name of the kernel function template; inserted verbatim.
// args  Template arguments. Each is rendered with operator<< into a
//       std::ostringstream, so any streamable type works (a bool therefore
//       prints as 1 or 0). With no arguments, <inst> is `func<>`.
//
// The result is assembled with streams only. Passing a runtime std::string
// as a fmt format string skips compile-time format checking and is rejected
// by newer fmt releases under C++20 unless wrapped in fmt::runtime. The
// output also no longer carries the source indentation that the previous
// raw-string literal left after the final newline.
template <typename... Args>
std::string
get_template_definition(std::string name, std::string func, Args... args) {
  // Render the instantiation once: func<arg0, arg1, ...>
  std::ostringstream inst_stream;
  inst_stream << func << "<";
  bool first = true;
  auto add_arg = [&inst_stream, &first](const auto& arg) {
    if (!first) {
      inst_stream << ", ";
    }
    first = false;
    inst_stream << arg;
  };
  (add_arg(args), ...);
  inst_stream << ">";
  const std::string inst = inst_stream.str();

  // The instantiation appears twice: inside decltype(...) and as the
  // declarator itself.
  std::ostringstream definition;
  definition << "\ntemplate [[host_name(\"" << name
             << "\")]] [[kernel]] decltype(" << inst << ") " << inst << ";\n";
  return definition.str();
}
216
217} // namespace mlx::core
Definition array.h:20
Definition device.h:87
Op op
Definition binary.h:129
std::vector< std::tuple< const void *, MTL::DataType, NS::UInteger > > MTLFCList
Definition device.h:38
Definition allocator.h:7
MTL::ComputePipelineState * get_copy_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out)
MTL::ComputePipelineState * get_unary_kernel(metal::Device &d, const std::string &kernel_name, Dtype out_type, const std::string op)
MTL::ComputePipelineState * get_steel_gemm_splitk_accum_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool axbpy)
MTL::ComputePipelineState * get_fft_kernel(metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const std::string &template_def)
MTL::ComputePipelineState * get_softmax_kernel(metal::Device &d, const std::string &kernel_name, bool precise, const array &out)
MTL::ComputePipelineState * get_binary_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)
MTL::ComputePipelineState * get_binary_two_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)
MTL::ComputePipelineState * get_reduce_init_kernel(metal::Device &d, const std::string &kernel_name, const array &out)
MTL::ComputePipelineState * get_ternary_kernel(metal::Device &d, const std::string &kernel_name, Dtype type, const std::string op)
MTL::ComputePipelineState * get_arange_kernel(metal::Device &d, const std::string &kernel_name, const array &out)
MTL::ComputePipelineState * get_reduce_kernel(metal::Device &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const array &in, const array &out, int ndim=-1, int bm=-1, int bn=-1)
MTL::ComputePipelineState * get_sort_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, int bn, int tn)
MTL::ComputePipelineState * get_steel_gemm_fused_kernel(metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn)
MTL::ComputePipelineState * get_gemv_masked_kernel(metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_mat, int bm, int bn, int sm, int sn, int tm, int tn, bool contiguous)
MTL::ComputePipelineState * get_quantized_kernel(metal::Device &d, const std::string &kernel_name, const std::string &template_def)
std::string get_template_definition(std::string name, std::string func, Args... args)
Definition kernels.h:198
MTL::ComputePipelineState * get_steel_gemm_masked_kernel(metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)
MTL::ComputePipelineState * get_steel_conv_general_kernel(metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn)
MTL::ComputePipelineState * get_steel_conv_kernel(metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn, int n_channel_specialization, bool small_filter)
MTL::ComputePipelineState * get_scan_kernel(metal::Device &d, const std::string &kernel_name, bool reverse, bool inclusive, const std::string &reduce_type, const array &in, const array &out)
MTL::ComputePipelineState * get_steel_gemm_splitk_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)
MTL::ComputePipelineState * get_mb_sort_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &idx, int bn, int tn)
Definition dtype.h:13