MLX
Loading...
Searching...
No Matches
kernels.h
Go to the documentation of this file.
// Copyright © 2024 Apple Inc.

#include <optional>
#include <string>

#include "mlx/array.h"

6namespace mlx::core {
7
8MTL::ComputePipelineState* get_arange_kernel(
10 const std::string& kernel_name,
11 const array& out);
12
13MTL::ComputePipelineState* get_unary_kernel(
15 const std::string& kernel_name,
16 const array& out);
17
18MTL::ComputePipelineState* get_binary_kernel(
20 const std::string& kernel_name,
21 const array& in,
22 const array& out);
23
24MTL::ComputePipelineState* get_binary_two_kernel(
26 const std::string& kernel_name,
27 const array& in,
28 const array& out);
29
30MTL::ComputePipelineState* get_ternary_kernel(
32 const std::string& kernel_name,
33 const array& out);
34
35MTL::ComputePipelineState* get_copy_kernel(
37 const std::string& kernel_name,
38 const array& in,
39 const array& out);
40
41MTL::ComputePipelineState* get_softmax_kernel(
43 const std::string& kernel_name,
44 bool precise,
45 const array& out);
46
47MTL::ComputePipelineState* get_scan_kernel(
49 const std::string& kernel_name,
50 bool reverse,
51 bool inclusive,
52 const std::string& reduce_type,
53 const array& in,
54 const array& out);
55
56MTL::ComputePipelineState* get_sort_kernel(
58 const std::string& kernel_name,
59 const array& in,
60 const array& out,
61 int bn,
62 int tn);
63
64MTL::ComputePipelineState* get_mb_sort_kernel(
66 const std::string& kernel_name,
67 const array& in,
68 const array& idx,
69 int bn,
70 int tn);
71
72MTL::ComputePipelineState* get_reduce_init_kernel(
74 const std::string& kernel_name,
75 const array& out);
76
77MTL::ComputePipelineState* get_reduce_kernel(
79 const std::string& kernel_name,
80 const std::string& op_name,
81 const array& in,
82 const array& out);
83
84MTL::ComputePipelineState* get_steel_gemm_fused_kernel(
86 const std::string& kernel_name,
87 const std::string& hash_name,
88 const metal::MTLFCList& func_consts,
89 const array& out,
90 bool transpose_a,
91 bool transpose_b,
92 int bm,
93 int bn,
94 int bk,
95 int wm,
96 int wn);
97
98MTL::ComputePipelineState* get_steel_gemm_splitk_kernel(
100 const std::string& kernel_name,
101 const array& in,
102 const array& out,
103 bool transpose_a,
104 bool transpose_b,
105 int bm,
106 int bn,
107 int bk,
108 int wm,
109 int wn,
110 bool mn_aligned,
111 bool k_aligned);
112
113MTL::ComputePipelineState* get_steel_gemm_splitk_accum_kernel(
114 metal::Device& d,
115 const std::string& kernel_name,
116 const array& in,
117 const array& out,
118 bool axbpy);
119
120MTL::ComputePipelineState* get_steel_gemm_masked_kernel(
121 metal::Device& d,
122 const std::string& kernel_name,
123 const array& out,
124 const std::optional<array>& mask_out,
125 const std::optional<array>& mask_op,
126 bool transpose_a,
127 bool transpose_b,
128 int bm,
129 int bn,
130 int bk,
131 int wm,
132 int wn,
133 bool mn_aligned,
134 bool k_aligned);
135
136MTL::ComputePipelineState* get_steel_conv_kernel(
137 metal::Device& d,
138 const std::string& kernel_name,
139 const array& out,
140 int bm,
141 int bn,
142 int bk,
143 int wm,
144 int wn,
145 int n_channel_specialization,
146 bool small_filter);
147
148MTL::ComputePipelineState* get_steel_conv_general_kernel(
149 metal::Device& d,
150 const std::string& kernel_name,
151 const array& out,
152 int bm,
153 int bn,
154 int bk,
155 int wm,
156 int wn);
157
158MTL::ComputePipelineState* get_fft_kernel(
159 metal::Device& d,
160 const std::string& kernel_name,
161 const std::string& hash_name,
162 const int tg_mem_size,
163 const std::string& in_type,
164 const std::string& out_type,
165 int step,
166 bool real,
167 const metal::MTLFCList& func_consts);
168
169} // namespace mlx::core
Definition array.h:20
Definition device.h:117
std::vector< std::tuple< const void *, MTL::DataType, NS::UInteger > > MTLFCList
Definition device.h:36
Definition allocator.h:7
MTL::ComputePipelineState * get_copy_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out)
MTL::ComputePipelineState * get_binary_two_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out)
MTL::ComputePipelineState * get_steel_gemm_splitk_accum_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool axbpy)
MTL::ComputePipelineState * get_ternary_kernel(metal::Device &d, const std::string &kernel_name, const array &out)
MTL::ComputePipelineState * get_softmax_kernel(metal::Device &d, const std::string &kernel_name, bool precise, const array &out)
MTL::ComputePipelineState * get_fft_kernel(metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const int tg_mem_size, const std::string &in_type, const std::string &out_type, int step, bool real, const metal::MTLFCList &func_consts)
MTL::ComputePipelineState * get_reduce_init_kernel(metal::Device &d, const std::string &kernel_name, const array &out)
MTL::ComputePipelineState * get_reduce_kernel(metal::Device &d, const std::string &kernel_name, const std::string &op_name, const array &in, const array &out)
MTL::ComputePipelineState * get_arange_kernel(metal::Device &d, const std::string &kernel_name, const array &out)
MTL::ComputePipelineState * get_sort_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, int bn, int tn)
MTL::ComputePipelineState * get_steel_gemm_fused_kernel(metal::Device &d, const std::string &kernel_name, const std::string &hash_name, const metal::MTLFCList &func_consts, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn)
MTL::ComputePipelineState * get_steel_gemm_masked_kernel(metal::Device &d, const std::string &kernel_name, const array &out, const std::optional< array > &mask_out, const std::optional< array > &mask_op, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)
MTL::ComputePipelineState * get_steel_conv_general_kernel(metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn)
MTL::ComputePipelineState * get_unary_kernel(metal::Device &d, const std::string &kernel_name, const array &out)
MTL::ComputePipelineState * get_steel_conv_kernel(metal::Device &d, const std::string &kernel_name, const array &out, int bm, int bn, int bk, int wm, int wn, int n_channel_specialization, bool small_filter)
MTL::ComputePipelineState * get_binary_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out)
MTL::ComputePipelineState * get_scan_kernel(metal::Device &d, const std::string &kernel_name, bool reverse, bool inclusive, const std::string &reduce_type, const array &in, const array &out)
MTL::ComputePipelineState * get_steel_gemm_splitk_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &out, bool transpose_a, bool transpose_b, int bm, int bn, int bk, int wm, int wn, bool mn_aligned, bool k_aligned)
MTL::ComputePipelineState * get_mb_sort_kernel(metal::Device &d, const std::string &kernel_name, const array &in, const array &idx, int bn, int tn)