MLX
Loading...
Searching...
No Matches
kernels.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#include <fmt/format.h>
4
5#include "mlx/array.h"
7
8namespace mlx::core {
9
10MTL::ComputePipelineState* get_arange_kernel(
12 const std::string& kernel_name,
13 const array& out);
14
15MTL::ComputePipelineState* get_unary_kernel(
17 const std::string& kernel_name,
18 Dtype out_type,
19 const std::string op);
20
21MTL::ComputePipelineState* get_binary_kernel(
23 const std::string& kernel_name,
24 Dtype in_type,
25 Dtype out_type,
26 const std::string op);
27
28MTL::ComputePipelineState* get_binary_two_kernel(
30 const std::string& kernel_name,
31 Dtype in_type,
32 Dtype out_type,
33 const std::string op);
34
35MTL::ComputePipelineState* get_ternary_kernel(
37 const std::string& kernel_name,
38 Dtype type,
39 const std::string op);
40
41MTL::ComputePipelineState* get_copy_kernel(
43 const std::string& kernel_name,
44 const array& in,
45 const array& out);
46
47MTL::ComputePipelineState* get_softmax_kernel(
49 const std::string& kernel_name,
50 bool precise,
51 const array& out);
52
53MTL::ComputePipelineState* get_scan_kernel(
55 const std::string& kernel_name,
56 bool reverse,
57 bool inclusive,
58 const std::string& reduce_type,
59 const array& in,
60 const array& out);
61
62MTL::ComputePipelineState* get_sort_kernel(
64 const std::string& kernel_name,
65 const array& in,
66 const array& out,
67 int bn,
68 int tn);
69
70MTL::ComputePipelineState* get_mb_sort_kernel(
72 const std::string& kernel_name,
73 const array& in,
74 const array& idx,
75 int bn,
76 int tn);
77
78MTL::ComputePipelineState* get_reduce_init_kernel(
80 const std::string& kernel_name,
81 const array& out);
82
83MTL::ComputePipelineState* get_reduce_kernel(
85 const std::string& kernel_name,
86 const std::string& op_name,
87 const array& in,
88 const array& out);
89
90MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
92 const std::string& kernel_name,
93 const std::string& hash_name,
94 const metal::MTLFCList& func_consts,
95 const array& out,
96 bool transpose_a,
97 bool transpose_b,
98 int bm,
99 int bn,
100 int bk,
101 int wm,
102 int wn);
103
104MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
105 metal::Device& d,
106 const std::string& kernel_name,
107 const array& in,
108 const array& out,
109 bool transpose_a,
110 bool transpose_b,
111 int bm,
112 int bn,
113 int bk,
114 int wm,
115 int wn,
116 bool mn_aligned,
117 bool k_aligned);
118
119MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
120 metal::Device& d,
121 const std::string& kernel_name,
122 const array& in,
123 const array& out,
124 bool axbpy);
125
126MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
127 metal::Device& d,
128 const std::string& kernel_name,
129 const array& out,
130 const std::optional<array>& mask_out,
131 const std::optional<array>& mask_op,
132 bool transpose_a,
133 bool transpose_b,
134 int bm,
135 int bn,
136 int bk,
137 int wm,
138 int wn,
139 bool mn_aligned,
140 bool k_aligned);
141
142MTL::ComputePipelineState* get_steel_conv_kernel(
143 metal::Device& d,
144 const std::string& kernel_name,
145 const array& out,
146 int bm,
147 int bn,
148 int bk,
149 int wm,
150 int wn,
151 int n_channel_specialization,
152 bool small_filter);
153
154MTL::ComputePipelineState* get_gemv_masked_kernel(
155 metal::Device& d,
156 const std::string& kernel_name,
157 const array& out,
158 const std::optional<array>& mask_out,
159 const std::optional<array>& mask_op,
160 bool transpose_mat,
161 int bm,
162 int bn,
163 int sm,
164 int sn,
165 int tm,
166 int tn,
167 bool contiguous);
168
169MTL::ComputePipelineState* get_steel_conv_general_kernel(
170 metal::Device& d,
171 const std::string& kernel_name,
172 const array& out,
173 int bm,
174 int bn,
175 int bk,
176 int wm,
177 int wn);
178
179MTL::ComputePipelineState* get_fft_kernel(
180 metal::Device& d,
181 const std::string& kernel_name,
182 const std::string& hash_name,
183 const metal::MTLFCList& func_consts,
184 const std::string& template_def);
185
186MTL::ComputePipelineState* get_quantized_kernel(
187 metal::Device& d,
188 const std::string& kernel_name,
189 const std::string& template_def);
190
// Create a GPU kernel template definition for JIT compilation.
//
// `func` followed by the comma-separated `args` in angle brackets forms the
// instantiation, e.g. func = "f", args = ("float", 4) gives "f<float, 4>".
// Each argument is rendered with operator<<, so any streamable type works;
// with no arguments the instantiation is "f<>".
//
// The returned text is
//   "\n" template [[host_name("<name>")]] [[kernel]] decltype(<inst>) <inst>;
// followed by a newline and trailing indentation.
//
// The text is assembled with streams only. The previous version handed a
// run-time std::string to fmt::format as the format string, which cannot be
// checked at compile time and is rejected by {fmt}'s consteval format-string
// checking unless wrapped in fmt::runtime.
template <typename... Args>
std::string
get_template_definition(std::string name, std::string func, Args... args) {
  // Build "<func><arg0, arg1, ...>".
  std::ostringstream inst;
  inst << func << "<";
  bool first = true;
  auto add_arg = [&inst, &first](const auto& arg) {
    if (!first) {
      inst << ", ";
    }
    first = false;
    inst << arg;
  };
  (add_arg(args), ...);
  inst << ">";
  const std::string instantiation = inst.str();

  // NOTE(review): the trailing "\n  " mirrors the original raw-string literal,
  // whose closing line was indented; the exact indent width is assumed to be
  // two spaces -- confirm against the original source.
  std::ostringstream def;
  def << "\ntemplate [[host_name(\"" << name << "\")]] [[kernel]] decltype("
      << instantiation << ") " << instantiation << ";\n  ";
  return def.str();
}
212
213} // namespace mlx::core
Definition array.h:20
Definition device.h:66
Op op
Definition binary.h:141
std::vector< std::tuple< const void *, MTL::DataType, NS::UInteger > > MTLFCList
Definition device.h:17
Definition allocator.h:7
MTL::ComputePipelineState * get_copy_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out)
MTL::ComputePipelineState * get_unary_kernel(metal::Device &d, const std::string &kernel_name, Dtype out_type, const std::string op)
MTL::ComputePipelineState * get_steel_gemm_splitk_accum_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool axbpy)
MTL::ComputePipelineState * get_fft_kernel(metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const std::string &template_def)
MTL::ComputePipelineState * get_softmax_kernel(metal::Device &d, const std::string &kernel_name, bool precise, const array &out)
MTL::ComputePipelineState * get_binary_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)
MTL::ComputePipelineState * get_binary_two_kernel(metal::Device &d, const std::string &kernel_name, Dtype in_type, Dtype out_type, const std::string op)
MTL::ComputePipelineState * get_reduce_init_kernel(metal::Device &d, const std::string &kernel_name, const array &out)
MTL::ComputePipelineState * get_ternary_kernel(metal::Device &d, const std::string &kernel_name, Dtype type, const std::string op)
MTL::ComputePipelineState * get_reduce_kernel(metal::Device &d, const std::string &kernel_name, const std::string &op_name, const array &in, const array &out)
MTL::ComputePipelineState * get_arange_kernel(metal::Device &d, const std::string &kernel_name, const array &out)
MTL::ComputePipelineState * get_sort_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, int bn, int tn)
MTL::ComputePipelineState * get_steel_gemm_fused_kernel(metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn)
MTL::ComputePipelineState * get_gemv_masked_kernel(metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_mat, int bm, int bn, int sm, int sn, int tm, int tn, bool contiguous)
MTL::ComputePipelineState * get_quantized_kernel(metal::Device &d, const std::string &kernel_name, const std::string &template_def)
std::string get_template_definition(std::string name, std::string func, Args... args)
Definition kernels.h:194
MTL::ComputePipelineState * get_steel_gemm_masked_kernel(metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)
MTL::ComputePipelineState * get_steel_conv_general_kernel(metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn)
MTL::ComputePipelineState * get_steel_conv_kernel(metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn, int n_channel_specialization, bool small_filter)
MTL::ComputePipelineState * get_scan_kernel(metal::Device &d, const std::string &kernel_name, bool reverse, bool inclusive, const std::string &reduce_type, const array &in, const array &out)
MTL::ComputePipelineState * get_steel_gemm_splitk_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)
MTL::ComputePipelineState * get_mb_sort_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &idx, int bn, int tn)
Definition dtype.h:13