MLX
Loading...
Searching...
No Matches
utils.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <metal_math>
8
10// Type limits utils
12
13template <typename U>
14struct Limits {
15 static const constant U max = metal::numeric_limits<U>::max();
16 static const constant U min = metal::numeric_limits<U>::min();
17 static const constant U finite_max = metal::numeric_limits<U>::max();
18 static const constant U finite_min = metal::numeric_limits<U>::min();
19};
20
// Defines the Limits<type> specialization for a type whose bounds come
// straight from metal::numeric_limits: finite_max/finite_min are given the
// same initializers as max/min.
#define instantiate_default_limit(type)                                      \
  template <>                                                                \
  struct Limits<type> {                                                      \
    static constexpr constant type max = metal::numeric_limits<type>::max(); \
    static constexpr constant type finite_max =                              \
        metal::numeric_limits<type>::max();                                  \
    static constexpr constant type min = metal::numeric_limits<type>::min(); \
    static constexpr constant type finite_min =                              \
        metal::numeric_limits<type>::min();                                  \
  };
31
40
// Defines the Limits<type> specialization for a type that has an infinity:
// max/min are +infinity/-infinity, while finite_max/finite_min are
// +numeric_limits::max() and -numeric_limits::max().
#define instantiate_float_limit(type)             \
  template <>                                     \
  struct Limits<type> {                           \
    static constexpr constant type max =          \
        metal::numeric_limits<type>::infinity();  \
    static constexpr constant type finite_max =   \
        metal::numeric_limits<type>::max();       \
    static constexpr constant type min =          \
        -metal::numeric_limits<type>::infinity(); \
    static constexpr constant type finite_min =   \
        -metal::numeric_limits<type>::max();      \
  };
53
57
58template <>
59struct Limits<bool> {
60 static constexpr constant bool max = true;
61 static constexpr constant bool min = false;
62};
63
65// Indexing utils
67
// Place immediately before a loop to request full unrolling from clang.
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
69
71// Single Array with generic dims
72
73template <typename stride_t>
74METAL_FUNC stride_t elem_to_loc(
75 uint elem,
76 device const int* shape,
77 device const stride_t* strides,
78 int ndim) {
79 stride_t loc = 0;
80 for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
81 loc += (elem % shape[i]) * strides[i];
82 elem /= shape[i];
83 }
84 return loc;
85}
86
87template <typename stride_t>
88METAL_FUNC stride_t elem_to_loc(
89 uint elem,
90 constant const int* shape,
91 constant const stride_t* strides,
92 int ndim) {
93 stride_t loc = 0;
94 for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
95 loc += (elem % shape[i]) * strides[i];
96 elem /= shape[i];
97 }
98 return loc;
99}
100
101// Non templated version to handle arbitrary dims
102template <typename stride_t>
103METAL_FUNC stride_t elem_to_loc(
104 uint3 elem,
105 constant const int* shape,
106 constant const stride_t* strides,
107 int ndim) {
108 stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
109 for (int d = ndim - 3; d >= 0; --d) {
110 loc += (elem.z % shape[d]) * strides[d];
111 elem.z /= shape[d];
112 }
113 return loc;
114}
115
117// Single Array with fixed N dims
118
119template <typename stride_t>
120METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) {
121 return elem * stride;
122}
123
124template <typename stride_t>
125METAL_FUNC stride_t
126elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) {
127 return elem.x * strides[1] + elem.y * strides[0];
128}
129
130template <typename stride_t>
131METAL_FUNC stride_t
132elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
133 return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
134}
135
136template <int NDIM>
137METAL_FUNC size_t elem_to_loc_nd(
138 uint elem,
139 device const int* shape,
140 device const size_t* strides) {
141 size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
142
144 for (int d = NDIM - 2; d >= 0; --d) {
145 elem /= shape[d + 1];
146 loc += (elem % shape[d]) * strides[d];
147 }
148
149 return loc;
150}
151
152template <int NDIM>
153METAL_FUNC size_t elem_to_loc_nd(
154 uint3 elem,
155 constant const int shape[NDIM],
156 constant const size_t strides[NDIM]) {
157 size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
158 for (int d = NDIM - 3; d >= 0; --d) {
159 loc += (elem.z % shape[d]) * strides[d];
160 elem.z /= shape[d];
161 }
162 return loc;
163}
164
165template <int NDIM>
166METAL_FUNC int64_t elem_to_loc_nd(
167 uint elem,
168 constant const int shape[NDIM],
169 constant const int64_t strides[NDIM]) {
170 int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
171
173 for (int d = NDIM - 2; d >= 0; --d) {
174 elem /= shape[d + 1];
175 loc += (elem % shape[d]) * strides[d];
176 }
177
178 return loc;
179}
180
181template <int NDIM>
182METAL_FUNC int64_t elem_to_loc_nd(
183 uint3 elem,
184 constant const int shape[NDIM],
185 constant const int64_t strides[NDIM]) {
186 int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
187 for (int d = NDIM - 3; d >= 0; --d) {
188 loc += (elem.z % shape[d]) * strides[d];
189 elem.z /= shape[d];
190 }
191 return loc;
192}
193
195// Multiple Arrays with generic dims
196
197METAL_FUNC uint2 elem_to_loc_2_nd(
198 uint3 elem,
199 constant const int* shape,
200 constant const size_t* a_strides,
201 constant const size_t* b_strides,
202 int ndim) {
203 uint2 loc = {
204 static_cast<uint>(
205 elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
206 static_cast<uint>(
207 elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
208 for (int d = ndim - 3; d >= 0; --d) {
209 uint l = elem.z % shape[d];
210 loc.x += l * a_strides[d];
211 loc.y += l * b_strides[d];
212 elem.z /= shape[d];
213 }
214 return loc;
215}
216
217METAL_FUNC uint3 elem_to_loc_3_nd(
218 uint3 elem,
219 constant const int* shape,
220 constant const size_t* a_strides,
221 constant const size_t* b_strides,
222 constant const size_t* c_strides,
223 int ndim) {
224 uint3 loc = {
225 static_cast<uint>(
226 elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
227 static_cast<uint>(
228 elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]),
229 static_cast<uint>(
230 elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])};
231 for (int d = ndim - 3; d >= 0; --d) {
232 uint l = elem.z % shape[d];
233 loc.x += l * a_strides[d];
234 loc.y += l * b_strides[d];
235 loc.z += l * c_strides[d];
236 elem.z /= shape[d];
237 }
238 return loc;
239}
240
242// Multiple Arrays with fixed N dims
243
244template <int NDIM>
245METAL_FUNC uint2 elem_to_loc_2_nd(
246 uint3 elem,
247 constant const int shape[NDIM],
248 constant const size_t a_strides[NDIM],
249 constant const size_t b_strides[NDIM]) {
250 uint2 loc = {
251 static_cast<uint>(
252 elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
253 static_cast<uint>(
254 elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
255 for (int d = NDIM - 3; d >= 0; --d) {
256 uint l = elem.z % shape[d];
257 loc.x += l * a_strides[d];
258 loc.y += l * b_strides[d];
259 elem.z /= shape[d];
260 }
261 return loc;
262}
263
264template <int NDIM>
265METAL_FUNC uint3 elem_to_loc_3_nd(
266 uint3 elem,
267 constant const int shape[NDIM],
268 constant const size_t a_strides[NDIM],
269 constant const size_t b_strides[NDIM],
270 constant const size_t c_strides[NDIM]) {
271 uint3 loc = {
272 static_cast<uint>(
273 elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
274 static_cast<uint>(
275 elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
276 static_cast<uint>(
277 elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
278 for (int d = NDIM - 3; d >= 0; --d) {
279 uint l = elem.z % shape[d];
280 loc.x += l * a_strides[d];
281 loc.y += l * b_strides[d];
282 loc.z += l * c_strides[d];
283 elem.z /= shape[d];
284 }
285 return loc;
286}
287
289// Calculation utils
291
// Compute ceil(N / M) using integer arithmetic only.
//
// M must be non-zero. The result is formed as the truncated quotient plus one
// when there is a remainder, rather than as (N + M - 1) / M, because the
// latter wraps around for N close to the maximum size_t value and then
// returns a result that is too small.
inline size_t ceildiv(size_t N, size_t M) {
  return N / M + (N % M != 0 ? 1 : 0);
}
296
297// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
298inline float log1p(float x) {
299 float xp1 = 1.0f + x;
300 if (xp1 == Limits<float>::max) {
301 return Limits<float>::max;
302 }
303 if (xp1 == 1.0f) {
304 return x;
305 }
306
307 return x * (metal::log(xp1) / (xp1 - 1.0f));
308}
309
311 float xp1 = 1.0f + static_cast<float>(x);
312 if (xp1 == Limits<float>::max) {
314 }
315 if (xp1 == 1.0f) {
316 return x;
317 }
318
319 return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
320}
321
323// SIMD shuffle ops
325
326inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
327 return as_type<uint64_t>(
328 metal::simd_shuffle_down(as_type<uint2>(data), delta));
329}
330
331inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
332 return as_type<int64_t>(
333 metal::simd_shuffle_down(as_type<uint2>(data), delta));
334}
335
336inline bool simd_shuffle_down(bool data, uint16_t delta) {
337 return simd_shuffle_down(static_cast<uint32_t>(data), delta);
338}
struct _MLX_BFloat16 bfloat16_t
Definition bf16.h:257
#define MLX_MTL_PRAGMA_UNROLL
Definition utils.h:68
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t &stride)
Definition utils.h:120
#define instantiate_float_limit(type)
Definition utils.h:41
float log1p(float x)
Definition utils.h:298
METAL_FUNC stride_t elem_to_loc_3(uint3 elem, constant const stride_t strides[3])
Definition utils.h:132
METAL_FUNC stride_t elem_to_loc(uint elem, device const int *shape, device const stride_t *strides, int ndim)
Definition utils.h:74
METAL_FUNC uint2 elem_to_loc_2_nd(uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, int ndim)
Definition utils.h:197
size_t ceildiv(size_t N, size_t M)
Compute ceil((float)N/(float)M)
Definition utils.h:293
METAL_FUNC uint3 elem_to_loc_3_nd(uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, constant const size_t *c_strides, int ndim)
Definition utils.h:217
METAL_FUNC size_t elem_to_loc_nd(uint elem, device const int *shape, device const size_t *strides)
Definition utils.h:137
#define instantiate_default_limit(type)
Definition utils.h:21
METAL_FUNC stride_t elem_to_loc_2(uint2 elem, constant const stride_t strides[2])
Definition utils.h:126
METAL_FUNC bfloat16_t log(bfloat16_t x)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t simd_shuffle_down(bfloat16_t data, ushort delta)
Definition bf16_math.h:391
Definition bf16.h:54
Definition utils.h:14
static const constant U max
Definition utils.h:15
static const constant U finite_max
Definition utils.h:17
static const constant U min
Definition utils.h:16
static const constant U finite_min
Definition utils.h:18