MLX
Loading...
Searching...
No Matches
kernels.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#include <fmt/format.h>
4
5#include "mlx/array.h"
7
8namespace mlx::core {
9
10MTL::ComputePipelineState* get_arange_kernel(
12 const std::string& kernel_name,
13 const array& out);
14
15MTL::ComputePipelineState* get_unary_kernel(
17 const std::string& kernel_name,
18 Dtype in_type,
19 Dtype out_type,
20 const std::string op);
21
22MTL::ComputePipelineState* get_binary_kernel(
24 const std::string& kernel_name,
25 Dtype in_type,
26 Dtype out_type,
27 const std::string op);
28
29MTL::ComputePipelineState* get_binary_two_kernel(
31 const std::string& kernel_name,
32 Dtype in_type,
33 Dtype out_type,
34 const std::string op);
35
36MTL::ComputePipelineState* get_ternary_kernel(
38 const std::string& kernel_name,
39 Dtype type,
40 const std::string op);
41
42MTL::ComputePipelineState* get_copy_kernel(
44 const std::string& kernel_name,
45 const array& in,
46 const array& out);
47
48MTL::ComputePipelineState* get_softmax_kernel(
50 const std::string& kernel_name,
51 bool precise,
52 const array& out);
53
54MTL::ComputePipelineState* get_scan_kernel(
56 const std::string& kernel_name,
57 bool reverse,
58 bool inclusive,
59 const std::string& reduce_type,
60 const array& in,
61 const array& out);
62
63MTL::ComputePipelineState* get_sort_kernel(
65 const std::string& kernel_name,
66 const array& in,
67 const array& out,
68 int bn,
69 int tn);
70
71MTL::ComputePipelineState* get_mb_sort_kernel(
73 const std::string& kernel_name,
74 const array& in,
75 const array& idx,
76 int bn,
77 int tn);
78
79MTL::ComputePipelineState* get_reduce_init_kernel(
81 const std::string& kernel_name,
82 const std::string& func_name,
83 const std::string& op_name,
84 const array& out);
85
86MTL::ComputePipelineState* get_reduce_kernel(
88 const std::string& kernel_name,
89 const std::string& func_name,
90 const std::string& op_name,
91 const array& in,
92 const array& out,
93 int ndim = -1,
94 int bm = -1,
95 int bn = -1);
96
97MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
99 const std::string& kernel_name,
100 const std::string& hash_name,
101 const metal::MTLFCList& func_consts,
102 const array& out,
103 bool transpose_a,
104 bool transpose_b,
105 int bm,
106 int bn,
107 int bk,
108 int wm,
109 int wn);
110
111MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
112 metal::Device& d,
113 const std::string& kernel_name,
114 const array& in,
115 const array& out,
116 bool transpose_a,
117 bool transpose_b,
118 int bm,
119 int bn,
120 int bk,
121 int wm,
122 int wn,
123 bool mn_aligned,
124 bool k_aligned);
125
126MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
127 metal::Device& d,
128 const std::string& kernel_name,
129 const array& in,
130 const array& out,
131 bool axbpy);
132
133MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
134 metal::Device& d,
135 const std::string& kernel_name,
136 const array& out,
137 const std::optional<array>& mask_out,
138 const std::optional<array>& mask_op,
139 bool transpose_a,
140 bool transpose_b,
141 int bm,
142 int bn,
143 int bk,
144 int wm,
145 int wn,
146 bool mn_aligned,
147 bool k_aligned);
148
149MTL::ComputePipelineState* get_steel_conv_kernel(
150 metal::Device& d,
151 const std::string& kernel_name,
152 const array& out,
153 int bm,
154 int bn,
155 int bk,
156 int wm,
157 int wn,
158 int n_channel_specialization,
159 bool small_filter);
160
161MTL::ComputePipelineState* get_gemv_masked_kernel(
162 metal::Device& d,
163 const std::string& kernel_name,
164 const array& out,
165 const std::optional<array>& mask_out,
166 const std::optional<array>& mask_op,
167 bool transpose_mat,
168 int bm,
169 int bn,
170 int sm,
171 int sn,
172 int tm,
173 int tn,
174 bool contiguous);
175
176MTL::ComputePipelineState* get_steel_conv_general_kernel(
177 metal::Device& d,
178 const std::string& kernel_name,
179 const array& out,
180 int bm,
181 int bn,
182 int bk,
183 int wm,
184 int wn);
185
186MTL::ComputePipelineState* get_fft_kernel(
187 metal::Device& d,
188 const std::string& kernel_name,
189 const std::string& hash_name,
190 const metal::MTLFCList& func_consts,
191 const std::string& template_def);
192
193MTL::ComputePipelineState* get_quantized_kernel(
194 metal::Device& d,
195 const std::string& kernel_name,
196 const std::string& template_def);
197
// Create a GPU kernel template definition for JIT compilation.
//
// Produces the text
//   "\ntemplate [[host_name(\"<name>\")]] [[kernel]] decltype(<inst>) <inst>;\n"
// where <inst> is `func<arg0, arg1, ...>`. Each element of `args` is written
// with operator<< into a std::ostringstream and the elements are separated by
// ", ". An empty `args` pack yields `func<>`.
//
//   name: text placed inside host_name("...").
//   func: text placed in front of the `<...>` argument list.
//   args: template arguments; every element must be streamable to
//         std::ostream. bool values are therefore written as 1 / 0.
//
// All inputs are taken by const reference so that no string is copied just to
// be read. The result is assembled with plain concatenation, which gives the
// same characters as the positional format string
// "\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n".
template <typename... Args>
std::string get_template_definition(
    const std::string& name,
    const std::string& func,
    const Args&... args) {
  // Build the explicit instantiation `func<arg0, arg1, ...>` once; it is
  // needed twice in the final text.
  std::ostringstream inst;
  inst << func << '<';
  bool first = true;
  auto append_arg = [&inst, &first](const auto& arg) {
    if (!first) {
      inst << ", ";
    }
    first = false;
    inst << arg;
  };
  (append_arg(args), ...);
  inst << '>';
  const std::string instantiation = inst.str();

  std::string definition = "\ntemplate [[host_name(\"";
  definition += name;
  definition += "\")]] [[kernel]] decltype(";
  definition += instantiation;
  definition += ") ";
  definition += instantiation;
  definition += ";\n";
  return definition;
}
219
220} // namespace mlx::core
Definition array.h:20
Definition device.h:131
Op op
Definition binary.h:129
std::vector< std::tuple< const void *, MTL::DataType, NS::UInteger > > MTLFCList
Definition device.h:38
Definition allocator.h:7
MTL::ComputePipelineState * get_copy_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out)
MTL::ComputePipelineState * get_steel_gemm_splitk_accum_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool axbpy)
MTL::ComputePipelineState * get_fft_kernel(metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const std::string &template_def)
MTL::ComputePipelineState * get_softmax_kernel(metal::Device &d, const std::string &kernel_name, bool precise, const array &out)
MTL::ComputePipelineState * get_reduce_init_kernel(metal::Device &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const array &out)
MTL::ComputePipelineState * get_binary_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)
MTL::ComputePipelineState * get_binary_two_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)
MTL::ComputePipelineState * get_ternary_kernel(metal::Device &d, const std::string &kernel_name, Dtype type, const std::string op)
MTL::ComputePipelineState * get_arange_kernel(metal::Device &d, const std::string &kernel_name, const array &out)
MTL::ComputePipelineState * get_reduce_kernel(metal::Device &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const array &in, const array &out, int ndim=-1, int bm=-1, int bn=-1)
MTL::ComputePipelineState * get_sort_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, int bn, int tn)
MTL::ComputePipelineState * get_steel_gemm_fused_kernel(metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn)
MTL::ComputePipelineState * get_gemv_masked_kernel(metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_mat, int bm, int bn, int sm, int sn, int tm, int tn, bool contiguous)
MTL::ComputePipelineState * get_quantized_kernel(metal::Device &d, const std::string &kernel_name, const std::string &template_def)
std::string get_template_definition(std::string name, std::string func, Args... args)
Definition kernels.h:201
MTL::ComputePipelineState * get_steel_gemm_masked_kernel(metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)
MTL::ComputePipelineState * get_steel_conv_general_kernel(metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn)
MTL::ComputePipelineState * get_steel_conv_kernel(metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn, int n_channel_specialization, bool small_filter)
MTL::ComputePipelineState * get_scan_kernel(metal::Device &d, const std::string &kernel_name, bool reverse, bool inclusive, const std::string &reduce_type, const array &in, const array &out)
MTL::ComputePipelineState * get_steel_gemm_splitk_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)
MTL::ComputePipelineState * get_mb_sort_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &idx, int bn, int tn)
MTL::ComputePipelineState * get_unary_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)
Definition dtype.h:13