MLX
Loading...
Searching...
No Matches
utils.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <metal_math>
9
10typedef half float16_t;
11
13// Type limits utils
15
16template <typename U>
17struct Limits {
18 static const constant U max = metal::numeric_limits<U>::max();
19 static const constant U min = metal::numeric_limits<U>::min();
20 static const constant U finite_max = metal::numeric_limits<U>::max();
21 static const constant U finite_min = metal::numeric_limits<U>::min();
22};
23
// Explicitly specializes Limits<T> for a type whose plain and finite limits
// coincide: all four members come straight from metal::numeric_limits<T>
// (max() for the upper pair, min() for the lower pair).
#define instantiate_default_limit(T)                                   \
  template <>                                                          \
  struct Limits<T> {                                                   \
    static constexpr constant T max = metal::numeric_limits<T>::max(); \
    static constexpr constant T finite_max =                           \
        metal::numeric_limits<T>::max();                               \
    static constexpr constant T min = metal::numeric_limits<T>::min(); \
    static constexpr constant T finite_min =                           \
        metal::numeric_limits<T>::min();                               \
  };
34
43
// Explicitly specializes Limits<T> for a floating-point type:
//   max / min               -> +infinity / -infinity
//   finite_max / finite_min -> +/- numeric_limits<T>::max()
#define instantiate_float_limit(T)             \
  template <>                                  \
  struct Limits<T> {                           \
    static constexpr constant T finite_max =   \
        metal::numeric_limits<T>::max();       \
    static constexpr constant T finite_min =   \
        -metal::numeric_limits<T>::max();      \
    static constexpr constant T max =          \
        metal::numeric_limits<T>::infinity();  \
    static constexpr constant T min =          \
        -metal::numeric_limits<T>::infinity(); \
  };
56
60
61template <>
62struct Limits<bool> {
63 static constexpr constant bool max = true;
64 static constexpr constant bool min = false;
65};
66
67template <>
69 static constexpr constant complex64_t max = complex64_t(
70 metal::numeric_limits<float>::infinity(),
71 metal::numeric_limits<float>::infinity());
72 static constexpr constant complex64_t min = complex64_t(
73 -metal::numeric_limits<float>::infinity(),
74 -metal::numeric_limits<float>::infinity());
75};
76
78// Indexing utils
80
// Emits `#pragma clang loop unroll(full)`, asking the compiler to fully
// unroll the loop that immediately follows the macro.
#define MLX_MTL_PRAGMA_UNROLL \
  _Pragma("clang loop unroll(full)")
82
84// Single Array with generic dims
85
86template <typename stride_t>
87METAL_FUNC stride_t elem_to_loc(
88 uint elem,
89 device const int* shape,
90 device const stride_t* strides,
91 int ndim) {
92 stride_t loc = 0;
93 for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
94 loc += (elem % shape[i]) * strides[i];
95 elem /= shape[i];
96 }
97 return loc;
98}
99
100template <typename stride_t>
101METAL_FUNC stride_t elem_to_loc(
102 uint elem,
103 constant const int* shape,
104 constant const stride_t* strides,
105 int ndim) {
106 stride_t loc = 0;
107 for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
108 loc += (elem % shape[i]) * strides[i];
109 elem /= shape[i];
110 }
111 return loc;
112}
113
114template <typename stride_t>
115METAL_FUNC stride_t elem_to_loc(
116 stride_t elem,
117 device const int* shape,
118 device const stride_t* strides,
119 int ndim) {
120 stride_t loc = 0;
121 for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
122 loc += (elem % shape[i]) * strides[i];
123 elem /= shape[i];
124 }
125 return loc;
126}
127
128template <typename stride_t>
129METAL_FUNC stride_t elem_to_loc(
130 stride_t elem,
131 constant const int* shape,
132 constant const stride_t* strides,
133 int ndim) {
134 stride_t loc = 0;
135 for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
136 loc += (elem % shape[i]) * strides[i];
137 elem /= shape[i];
138 }
139 return loc;
140}
141
142// Non templated version to handle arbitrary dims
143template <typename stride_t>
144METAL_FUNC stride_t elem_to_loc(
145 uint3 elem,
146 constant const int* shape,
147 constant const stride_t* strides,
148 int ndim) {
149 stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
150 for (int d = ndim - 3; d >= 0; --d) {
151 loc += (elem.z % shape[d]) * strides[d];
152 elem.z /= shape[d];
153 }
154 return loc;
155}
156
158// Single Array with fixed N dims
159
160template <typename stride_t>
161METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) {
162 return elem * stride;
163}
164
165template <typename stride_t>
166METAL_FUNC stride_t
167elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) {
168 return elem.x * strides[1] + elem.y * strides[0];
169}
170
171template <typename stride_t>
172METAL_FUNC stride_t
173elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
174 return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
175}
176
177template <int NDIM>
178METAL_FUNC size_t elem_to_loc_nd(
179 uint elem,
180 device const int* shape,
181 device const size_t* strides) {
182 size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
183
185 for (int d = NDIM - 2; d >= 0; --d) {
186 elem /= shape[d + 1];
187 loc += (elem % shape[d]) * strides[d];
188 }
189
190 return loc;
191}
192
193template <int NDIM>
194METAL_FUNC size_t elem_to_loc_nd(
195 uint3 elem,
196 constant const int shape[NDIM],
197 constant const size_t strides[NDIM]) {
198 size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
199 for (int d = NDIM - 3; d >= 0; --d) {
200 loc += (elem.z % shape[d]) * strides[d];
201 elem.z /= shape[d];
202 }
203 return loc;
204}
205
206template <int NDIM>
207METAL_FUNC int64_t elem_to_loc_nd(
208 uint elem,
209 constant const int shape[NDIM],
210 constant const int64_t strides[NDIM]) {
211 int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
212
214 for (int d = NDIM - 2; d >= 0; --d) {
215 elem /= shape[d + 1];
216 loc += (elem % shape[d]) * strides[d];
217 }
218
219 return loc;
220}
221
222template <int NDIM>
223METAL_FUNC int64_t elem_to_loc_nd(
224 uint3 elem,
225 constant const int shape[NDIM],
226 constant const int64_t strides[NDIM]) {
227 int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
228 for (int d = NDIM - 3; d >= 0; --d) {
229 loc += (elem.z % shape[d]) * strides[d];
230 elem.z /= shape[d];
231 }
232 return loc;
233}
234
236// Multiple Arrays with generic dims
237
238METAL_FUNC uint2 elem_to_loc_2_nd(
239 uint3 elem,
240 constant const int* shape,
241 constant const size_t* a_strides,
242 constant const size_t* b_strides,
243 int ndim) {
244 uint2 loc = {
245 static_cast<uint>(
246 elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
247 static_cast<uint>(
248 elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
249 for (int d = ndim - 3; d >= 0; --d) {
250 uint l = elem.z % shape[d];
251 loc.x += l * a_strides[d];
252 loc.y += l * b_strides[d];
253 elem.z /= shape[d];
254 }
255 return loc;
256}
257
258METAL_FUNC uint3 elem_to_loc_3_nd(
259 uint3 elem,
260 constant const int* shape,
261 constant const size_t* a_strides,
262 constant const size_t* b_strides,
263 constant const size_t* c_strides,
264 int ndim) {
265 uint3 loc = {
266 static_cast<uint>(
267 elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
268 static_cast<uint>(
269 elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]),
270 static_cast<uint>(
271 elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])};
272 for (int d = ndim - 3; d >= 0; --d) {
273 uint l = elem.z % shape[d];
274 loc.x += l * a_strides[d];
275 loc.y += l * b_strides[d];
276 loc.z += l * c_strides[d];
277 elem.z /= shape[d];
278 }
279 return loc;
280}
281
283// Multiple Arrays with fixed N dims
284
285template <int NDIM>
286METAL_FUNC uint2 elem_to_loc_2_nd(
287 uint3 elem,
288 constant const int shape[NDIM],
289 constant const size_t a_strides[NDIM],
290 constant const size_t b_strides[NDIM]) {
291 uint2 loc = {
292 static_cast<uint>(
293 elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
294 static_cast<uint>(
295 elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
296 for (int d = NDIM - 3; d >= 0; --d) {
297 uint l = elem.z % shape[d];
298 loc.x += l * a_strides[d];
299 loc.y += l * b_strides[d];
300 elem.z /= shape[d];
301 }
302 return loc;
303}
304
305template <int NDIM>
306METAL_FUNC uint3 elem_to_loc_3_nd(
307 uint3 elem,
308 constant const int shape[NDIM],
309 constant const size_t a_strides[NDIM],
310 constant const size_t b_strides[NDIM],
311 constant const size_t c_strides[NDIM]) {
312 uint3 loc = {
313 static_cast<uint>(
314 elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
315 static_cast<uint>(
316 elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
317 static_cast<uint>(
318 elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
319 for (int d = NDIM - 3; d >= 0; --d) {
320 uint l = elem.z % shape[d];
321 loc.x += l * a_strides[d];
322 loc.y += l * b_strides[d];
323 loc.z += l * c_strides[d];
324 elem.z /= shape[d];
325 }
326 return loc;
327}
328
330// Elem to loc in a loop utils
332
333template <int dim, typename offset_t = size_t>
336 offset_t offset{0};
337 int index{0};
338
339 void next(const constant int* shape, const constant size_t* strides) {
340 index++;
341 offset += strides[dim - 1];
342
343 if (index >= shape[dim - 1]) {
344 index = 0;
345 inner_looper.next(shape, strides);
346 offset = inner_looper.offset;
347 }
348 }
349
350 void next(int n, const constant int* shape, const constant size_t* strides) {
351 index += n;
352 offset += n * strides[dim - 1];
353
354 if (index >= shape[dim - 1]) {
355 int extra = index - shape[dim - 1];
356 index = 0;
357 inner_looper.next(shape, strides);
358 offset = inner_looper.offset;
359 if (extra > 0) {
360 next(extra, shape, strides);
361 }
362 }
363 }
364
365 offset_t
366 location(offset_t, const constant int*, const constant size_t*, int) {
367 return offset;
368 }
369};
370
371template <typename offset_t>
372struct looped_elem_to_loc<1, offset_t> {
373 offset_t offset{0};
374
375 void next(const constant int*, const constant size_t* strides) {
376 offset += strides[0];
377 }
378
379 void next(int n, const constant int*, const constant size_t* strides) {
380 offset += n * strides[0];
381 }
382
383 offset_t
384 location(offset_t, const constant int*, const constant size_t*, int) {
385 return offset;
386 }
387};
388
389template <typename offset_t>
390struct looped_elem_to_loc<0, offset_t> {
391 void next(const constant int*, const constant size_t*) {}
392 void next(int, const constant int*, const constant size_t*) {}
393
394 offset_t location(
395 offset_t idx,
396 const constant int* shape,
397 const constant size_t* strides,
398 int ndim) {
399 return elem_to_loc(idx, shape, strides, ndim);
400 }
401};
402
404// Calculation utils
406
// Integer ceiling division, ceil(N / M) for non-negative N and positive M.
// The numerator is biased by M - 1 before truncating division.
template <typename T, typename U>
inline T ceildiv(T N, U M) {
  const auto biased = N + M - 1;
  return biased / M;
}
412
413// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
414inline float log1p(float x) {
415 float xp1 = 1.0f + x;
416 if (xp1 == Limits<float>::max) {
417 return Limits<float>::max;
418 }
419 if (xp1 == 1.0f) {
420 return x;
421 }
422
423 return x * (metal::log(xp1) / (xp1 - 1.0f));
424}
425
427 float xp1 = 1.0f + static_cast<float>(x);
428 if (xp1 == Limits<float>::max) {
430 }
431 if (xp1 == 1.0f) {
432 return x;
433 }
434
435 return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
436}
437
439// SIMD shuffle ops
441
442inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
443 return as_type<uint64_t>(
444 metal::simd_shuffle_down(as_type<uint2>(data), delta));
445}
446
447inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
448 return as_type<int64_t>(
449 metal::simd_shuffle_down(as_type<uint2>(data), delta));
450}
451
452inline bool simd_shuffle_down(bool data, uint16_t delta) {
453 return simd_shuffle_down(static_cast<uint32_t>(data), delta);
454}
455
456inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) {
457 return complex64_t(
458 simd_shuffle_down(data.real, delta), simd_shuffle_down(data.imag, delta));
459}
BufferHolder * next
Definition allocator.h:37
struct _MLX_BFloat16 bfloat16_t
Definition bf16.h:257
#define MLX_MTL_PRAGMA_UNROLL
Definition utils.h:81
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t &stride)
Definition utils.h:161
#define instantiate_float_limit(type)
Definition utils.h:44
float log1p(float x)
Definition utils.h:414
METAL_FUNC stride_t elem_to_loc_3(uint3 elem, constant const stride_t strides[3])
Definition utils.h:173
METAL_FUNC stride_t elem_to_loc(uint elem, device const int *shape, device const stride_t *strides, int ndim)
Definition utils.h:87
METAL_FUNC uint2 elem_to_loc_2_nd(uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, int ndim)
Definition utils.h:238
METAL_FUNC uint3 elem_to_loc_3_nd(uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, constant const size_t *c_strides, int ndim)
Definition utils.h:258
T ceildiv(T N, U M)
Compute ceil((float)N/(float)M)
Definition utils.h:409
METAL_FUNC size_t elem_to_loc_nd(uint elem, device const int *shape, device const size_t *strides)
Definition utils.h:178
#define instantiate_default_limit(type)
Definition utils.h:24
half float16_t
Definition utils.h:10
METAL_FUNC stride_t elem_to_loc_2(uint2 elem, constant const stride_t strides[2])
Definition utils.h:167
METAL_FUNC bfloat16_t log(bfloat16_t x)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t simd_shuffle_down(bfloat16_t data, ushort delta)
Definition bf16_math.h:391
Definition bf16.h:54
Definition utils.h:17
static const constant U max
Definition utils.h:18
static const constant U finite_max
Definition utils.h:20
static const constant U min
Definition utils.h:19
static const constant U finite_min
Definition utils.h:21
Definition complex.h:20
float imag
Definition complex.h:22
float real
Definition complex.h:21
void next(int, const constant int *, const constant size_t *)
Definition utils.h:392
offset_t location(offset_t idx, const constant int *shape, const constant size_t *strides, int ndim)
Definition utils.h:394
void next(const constant int *, const constant size_t *)
Definition utils.h:391
offset_t location(offset_t, const constant int *, const constant size_t *, int)
Definition utils.h:384
void next(const constant int *, const constant size_t *strides)
Definition utils.h:375
void next(int n, const constant int *, const constant size_t *strides)
Definition utils.h:379
Definition utils.h:334
void next(const constant int *shape, const constant size_t *strides)
Definition utils.h:339
offset_t offset
Definition utils.h:336
int index
Definition utils.h:337
looped_elem_to_loc< dim - 1, offset_t > inner_looper
Definition utils.h:335
offset_t location(offset_t, const constant int *, const constant size_t *, int)
Definition utils.h:366
void next(int n, const constant int *shape, const constant size_t *strides)
Definition utils.h:350