MLX
 
Loading...
Searching...
No Matches
kernels.h
Go to the documentation of this file.
// Copyright © 2024 Apple Inc.
2
#include <optional>
#include <sstream>
#include <string>

#include <fmt/format.h>

#include "mlx/array.h"
#include "mlx/backend/metal/device.h"
7
8namespace mlx::core {
9
10MTL::ComputePipelineState* get_arange_kernel(
12 const std::string& kernel_name,
13 const array& out);
14
15MTL::ComputePipelineState* get_unary_kernel(
17 const std::string& kernel_name,
18 Dtype in_type,
19 Dtype out_type,
20 const std::string op);
21
22MTL::ComputePipelineState* get_binary_kernel(
24 const std::string& kernel_name,
25 Dtype in_type,
26 Dtype out_type,
27 const std::string op);
28
29MTL::ComputePipelineState* get_binary_two_kernel(
31 const std::string& kernel_name,
32 Dtype in_type,
33 Dtype out_type,
34 const std::string op);
35
36MTL::ComputePipelineState* get_ternary_kernel(
38 const std::string& kernel_name,
39 Dtype type,
40 const std::string op);
41
42MTL::ComputePipelineState* get_copy_kernel(
44 const std::string& kernel_name,
45 const array& in,
46 const array& out);
47
48MTL::ComputePipelineState* get_dynamic_copy_kernel(
50 const std::string& kernel_name,
51 const array& in,
52 const array& out);
53
54MTL::ComputePipelineState* get_softmax_kernel(
56 const std::string& kernel_name,
57 bool precise,
58 const array& out);
59
60MTL::ComputePipelineState* get_scan_kernel(
62 const std::string& kernel_name,
63 bool reverse,
64 bool inclusive,
65 const std::string& reduce_type,
66 const array& in,
67 const array& out);
68
69MTL::ComputePipelineState* get_sort_kernel(
71 const std::string& kernel_name,
72 const array& in,
73 const array& out,
74 int bn,
75 int tn);
76
77MTL::ComputePipelineState* get_mb_sort_kernel(
79 const std::string& kernel_name,
80 const array& in,
81 const array& idx,
82 int bn,
83 int tn);
84
85MTL::ComputePipelineState* get_reduce_init_kernel(
87 const std::string& kernel_name,
88 const std::string& func_name,
89 const std::string& op_name,
90 const Dtype& out_type);
91
92MTL::ComputePipelineState* get_reduce_kernel(
94 const std::string& kernel_name,
95 const std::string& func_name,
96 const std::string& op_name,
97 const Dtype& in_type,
98 const Dtype& out_type,
99 const std::string& idx_t,
100 int ndim = -1,
101 int bm = -1,
102 int bn = -1);
103
104MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
105 metal::Device& d,
106 const std::string& kernel_name,
107 const std::string& hash_name,
108 const metal::MTLFCList& func_consts,
109 const array& out,
110 bool transpose_a,
111 bool transpose_b,
112 int bm,
113 int bn,
114 int bk,
115 int wm,
116 int wn);
117
118MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
119 metal::Device& d,
120 const std::string& kernel_name,
121 const array& in,
122 const array& out,
123 bool transpose_a,
124 bool transpose_b,
125 int bm,
126 int bn,
127 int bk,
128 int wm,
129 int wn,
130 bool mn_aligned,
131 bool k_aligned);
132
133MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
134 metal::Device& d,
135 const std::string& kernel_name,
136 const array& in,
137 const array& out,
138 bool axbpy);
139
140MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
141 metal::Device& d,
142 const std::string& kernel_name,
143 const array& out,
144 const std::optional<array>& mask_out,
145 const std::optional<array>& mask_op,
146 bool transpose_a,
147 bool transpose_b,
148 int bm,
149 int bn,
150 int bk,
151 int wm,
152 int wn,
153 bool mn_aligned,
154 bool k_aligned);
155
156MTL::ComputePipelineState* get_steel_conv_kernel(
157 metal::Device& d,
158 const std::string& kernel_name,
159 const array& out,
160 int bm,
161 int bn,
162 int bk,
163 int wm,
164 int wn,
165 int n_channel_specialization,
166 bool small_filter);
167
168MTL::ComputePipelineState* get_gemv_masked_kernel(
169 metal::Device& d,
170 const std::string& kernel_name,
171 const array& out,
172 const std::optional<array>& mask_out,
173 const std::optional<array>& mask_op,
174 bool transpose_mat,
175 int bm,
176 int bn,
177 int sm,
178 int sn,
179 int tm,
180 int tn,
181 bool contiguous);
182
183MTL::ComputePipelineState* get_steel_conv_general_kernel(
184 metal::Device& d,
185 const std::string& kernel_name,
186 const array& out,
187 int bm,
188 int bn,
189 int bk,
190 int wm,
191 int wn);
192
193MTL::ComputePipelineState* get_fft_kernel(
194 metal::Device& d,
195 const std::string& kernel_name,
196 const std::string& hash_name,
197 const metal::MTLFCList& func_consts,
198 const std::string& template_def);
199
200MTL::ComputePipelineState* get_quantized_kernel(
201 metal::Device& d,
202 const std::string& kernel_name,
203 const std::string& template_def);
204
// Create a GPU kernel template definition for JIT compilation.
//
// The arguments in `args` are written with `operator<<`, separated by ", ",
// between `<` and `>` after `func`, giving the instantiation `func<args...>`.
// The returned string is
//
//   "\ntemplate [[host_name(\"<name>\")]] [[kernel]] decltype(<inst>) <inst>;\n"
//
// where <inst> is that instantiation. An empty pack yields `func<>`.
//
// Arguments are streamed with default stream flags, so a `bool` shows up as
// `1` or `0`, not `true`/`false`.
template <typename... Args>
std::string
get_template_definition(std::string name, std::string func, Args... args) {
  // Spell the instantiation once; it is emitted twice below.
  std::ostringstream inst;
  inst << func << "<";
  bool first = true;
  auto add_arg = [&inst, &first](const auto& arg) {
    if (!first) {
      inst << ", ";
    }
    first = false;
    inst << arg;
  };
  (add_arg(args), ...);
  inst << ">";
  const std::string instantiation = inst.str();

  std::ostringstream out;
  out << "\ntemplate [[host_name(\"" << name << "\")]] [[kernel]] decltype("
      << instantiation << ") " << instantiation << ";\n";
  return out.str();
}
226
227} // namespace mlx::core
Definition array.h:24
Definition device.h:158
array contiguous(const array &a, bool allow_col_major=false, StreamOrDevice s={})
std::vector< std::tuple< const void *, MTL::DataType, NS::UInteger > > MTLFCList
Definition device.h:38
Definition allocator.h:7
MTL::ComputePipelineState * get_copy_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out)
MTL::ComputePipelineState * get_steel_gemm_splitk_accum_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool axbpy)
MTL::ComputePipelineState * get_reduce_kernel(metal::Device &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const Dtype &in_type, const Dtype &out_type, const std::string &idx_t, int ndim=-1, int bm=-1, int bn=-1)
MTL::ComputePipelineState * get_fft_kernel(metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const std::string &template_def)
MTL::ComputePipelineState * get_softmax_kernel(metal::Device &d, const std::string &kernel_name, bool precise, const array &out)
MTL::ComputePipelineState * get_binary_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)
MTL::ComputePipelineState * get_binary_two_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)
MTL::ComputePipelineState * get_ternary_kernel(metal::Device &d, const std::string &kernel_name, Dtype type, const std::string op)
MTL::ComputePipelineState * get_arange_kernel(metal::Device &d, const std::string &kernel_name, const array &out)
MTL::ComputePipelineState * get_sort_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, int bn, int tn)
MTL::ComputePipelineState * get_steel_gemm_fused_kernel(metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn)
MTL::ComputePipelineState * get_gemv_masked_kernel(metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_mat, int bm, int bn, int sm, int sn, int tm, int tn, bool contiguous)
MTL::ComputePipelineState * get_quantized_kernel(metal::Device &d, const std::string &kernel_name, const std::string &template_def)
std::string get_template_definition(std::string name, std::string func, Args... args)
Definition kernels.h:208
MTL::ComputePipelineState * get_steel_gemm_masked_kernel(metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)
MTL::ComputePipelineState * get_steel_conv_general_kernel(metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn)
std::vector< array > Args
Definition export.h:11
MTL::ComputePipelineState * get_steel_conv_kernel(metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn, int n_channel_specialization, bool small_filter)
MTL::ComputePipelineState * get_dynamic_copy_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out)
MTL::ComputePipelineState * get_reduce_init_kernel(metal::Device &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const Dtype &out_type)
MTL::ComputePipelineState * get_scan_kernel(metal::Device &d, const std::string &kernel_name, bool reverse, bool inclusive, const std::string &reduce_type, const array &in, const array &out)
MTL::ComputePipelineState * get_steel_gemm_splitk_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)
MTL::ComputePipelineState * get_mb_sort_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &idx, int bn, int tn)
MTL::ComputePipelineState * get_unary_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)
Definition dtype.h:13