MLX
 
Loading...
Searching...
No Matches
kernels.h
Go to the documentation of this file.
// Copyright © 2024 Apple Inc.

#pragma once

#include <optional>
#include <sstream>
#include <string>

#include <fmt/format.h>

#include "mlx/array.h"
#include "mlx/backend/metal/device.h"

10namespace mlx::core {
11
12MTL::ComputePipelineState* get_arange_kernel(
14 const std::string& kernel_name,
15 const array& out);
16
17MTL::ComputePipelineState* get_unary_kernel(
19 const std::string& kernel_name,
20 Dtype in_type,
21 Dtype out_type,
22 const std::string op);
23
24MTL::ComputePipelineState* get_binary_kernel(
26 const std::string& kernel_name,
27 Dtype in_type,
28 Dtype out_type,
29 const std::string op);
30
31MTL::ComputePipelineState* get_binary_two_kernel(
33 const std::string& kernel_name,
34 Dtype in_type,
35 Dtype out_type,
36 const std::string op);
37
38MTL::ComputePipelineState* get_ternary_kernel(
40 const std::string& kernel_name,
41 Dtype type,
42 const std::string op);
43
44MTL::ComputePipelineState* get_copy_kernel(
46 const std::string& kernel_name,
47 const array& in,
48 const array& out);
49
50MTL::ComputePipelineState* get_dynamic_copy_kernel(
52 const std::string& kernel_name,
53 const array& in,
54 const array& out);
55
56MTL::ComputePipelineState* get_softmax_kernel(
58 const std::string& kernel_name,
59 bool precise,
60 const array& out);
61
62MTL::ComputePipelineState* get_scan_kernel(
64 const std::string& kernel_name,
65 bool reverse,
66 bool inclusive,
67 const std::string& reduce_type,
68 const array& in,
69 const array& out);
70
71MTL::ComputePipelineState* get_sort_kernel(
73 const std::string& kernel_name,
74 const array& in,
75 const array& out,
76 int bn,
77 int tn);
78
79MTL::ComputePipelineState* get_mb_sort_kernel(
81 const std::string& kernel_name,
82 const array& in,
83 const array& idx,
84 int bn,
85 int tn);
86
87MTL::ComputePipelineState* get_reduce_init_kernel(
89 const std::string& kernel_name,
90 const std::string& func_name,
91 const std::string& op_name,
92 const Dtype& out_type);
93
94MTL::ComputePipelineState* get_reduce_kernel(
96 const std::string& kernel_name,
97 const std::string& func_name,
98 const std::string& op_name,
99 const Dtype& in_type,
100 const Dtype& out_type,
101 const std::string& idx_t,
102 int ndim = -1,
103 int bm = -1,
104 int bn = -1);
105
106MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
107 metal::Device& d,
108 const std::string& kernel_name,
109 const std::string& hash_name,
110 const metal::MTLFCList& func_consts,
111 const array& out,
112 bool transpose_a,
113 bool transpose_b,
114 int bm,
115 int bn,
116 int bk,
117 int wm,
118 int wn);
119
120MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
121 metal::Device& d,
122 const std::string& kernel_name,
123 const array& in,
124 const array& out,
125 bool transpose_a,
126 bool transpose_b,
127 int bm,
128 int bn,
129 int bk,
130 int wm,
131 int wn,
132 bool mn_aligned,
133 bool k_aligned);
134
135MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
136 metal::Device& d,
137 const std::string& kernel_name,
138 const array& in,
139 const array& out,
140 bool axbpy);
141
142MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
143 metal::Device& d,
144 const std::string& kernel_name,
145 const array& out,
146 const std::optional<array>& mask_out,
147 const std::optional<array>& mask_op,
148 bool transpose_a,
149 bool transpose_b,
150 int bm,
151 int bn,
152 int bk,
153 int wm,
154 int wn,
155 bool mn_aligned,
156 bool k_aligned);
157
158MTL::ComputePipelineState* get_steel_conv_kernel(
159 metal::Device& d,
160 const std::string& kernel_name,
161 const array& out,
162 int bm,
163 int bn,
164 int bk,
165 int wm,
166 int wn,
167 int n_channel_specialization,
168 bool small_filter);
169
170MTL::ComputePipelineState* get_gemv_masked_kernel(
171 metal::Device& d,
172 const std::string& kernel_name,
173 const array& out,
174 const std::optional<array>& mask_out,
175 const std::optional<array>& mask_op,
176 bool transpose_mat,
177 int bm,
178 int bn,
179 int sm,
180 int sn,
181 int tm,
182 int tn,
183 bool contiguous);
184
185MTL::ComputePipelineState* get_steel_conv_general_kernel(
186 metal::Device& d,
187 const std::string& kernel_name,
188 const array& out,
189 int bm,
190 int bn,
191 int bk,
192 int wm,
193 int wn);
194
195MTL::ComputePipelineState* get_fft_kernel(
196 metal::Device& d,
197 const std::string& kernel_name,
198 const std::string& hash_name,
199 const metal::MTLFCList& func_consts,
200 const std::string& template_def);
201
202MTL::ComputePipelineState* get_quantized_kernel(
203 metal::Device& d,
204 const std::string& kernel_name,
205 const std::string& template_def);
206
207// Create a GPU kernel template definition for JIT compilation
208template <typename... Args>
209std::string
210get_template_definition(std::string name, std::string func, Args... args) {
211 std::ostringstream s;
212 s << func << "<";
213 bool first = true;
214 auto add_arg = [&s, &first](const auto& arg) {
215 if (!first) {
216 s << ", ";
217 }
218 first = false;
219 s << arg;
220 };
221 (add_arg(args), ...);
222 s << ">";
223 return fmt::format(
224 "\ntemplate [[host_name(\"{0}\")]] [[kernel]] decltype({1}) {1};\n",
225 name,
226 s.str());
227}
228
229} // namespace mlx::core
Definition array.h:24
Definition device.h:165
array contiguous(const array &a, bool allow_col_major=false, StreamOrDevice s={})
std::vector< std::tuple< const void *, MTL::DataType, NS::UInteger > > MTLFCList
Definition device.h:38
Definition allocator.h:7
MTL::ComputePipelineState * get_copy_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out)
MTL::ComputePipelineState * get_steel_gemm_splitk_accum_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool axbpy)
MTL::ComputePipelineState * get_reduce_kernel(metal::Device &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const Dtype &in_type, const Dtype &out_type, const std::string &idx_t, int ndim=-1, int bm=-1, int bn=-1)
MTL::ComputePipelineState * get_fft_kernel(metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const std::string &template_def)
MTL::ComputePipelineState * get_softmax_kernel(metal::Device &d, const std::string &kernel_name, bool precise, const array &out)
MTL::ComputePipelineState * get_binary_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)
MTL::ComputePipelineState * get_binary_two_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)
MTL::ComputePipelineState * get_ternary_kernel(metal::Device &d, const std::string &kernel_name, Dtype type, const std::string op)
MTL::ComputePipelineState * get_arange_kernel(metal::Device &d, const std::string &kernel_name, const array &out)
MTL::ComputePipelineState * get_sort_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, int bn, int tn)
MTL::ComputePipelineState * get_steel_gemm_fused_kernel(metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn)
MTL::ComputePipelineState * get_gemv_masked_kernel(metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_mat, int bm, int bn, int sm, int sn, int tm, int tn, bool contiguous)
MTL::ComputePipelineState * get_quantized_kernel(metal::Device &d, const std::string &kernel_name, const std::string &template_def)
std::string get_template_definition(std::string name, std::string func, Args... args)
Definition kernels.h:210
MTL::ComputePipelineState * get_steel_gemm_masked_kernel(metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)
MTL::ComputePipelineState * get_steel_conv_general_kernel(metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn)
std::vector< array > Args
Definition export.h:11
MTL::ComputePipelineState * get_steel_conv_kernel(metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn, int n_channel_specialization, bool small_filter)
MTL::ComputePipelineState * get_dynamic_copy_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out)
MTL::ComputePipelineState * get_reduce_init_kernel(metal::Device &d, const std::string &kernel_name, const std::string &func_name, const std::string &op_name, const Dtype &out_type)
MTL::ComputePipelineState * get_scan_kernel(metal::Device &d, const std::string &kernel_name, bool reverse, bool inclusive, const std::string &reduce_type, const array &in, const array &out)
MTL::ComputePipelineState * get_steel_gemm_splitk_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)
MTL::ComputePipelineState * get_mb_sort_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &idx, int bn, int tn)
MTL::ComputePipelineState * get_unary_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)
Definition dtype.h:13