MLX
Loading...
Searching...
No Matches
utils.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <metal_math>
6
7// The correct bf16.h is included based on the metal version
8// by giving the correct path to -I during compilation
9// e.g. mlx/backend/metal/kernels/metal_3_0/ for Metal 3.0
10#include "bf16.h"
11
15
16typedef half float16_t;
17
19// Type limits utils
21
22template <typename U>
23struct Limits {
24 static const constant U max = metal::numeric_limits<U>::max();
25 static const constant U min = metal::numeric_limits<U>::min();
26 static const constant U finite_max = metal::numeric_limits<U>::max();
27 static const constant U finite_min = metal::numeric_limits<U>::min();
28};
29
// Declares a Limits<type> specialization whose bounds all come straight from
// metal::numeric_limits<type>: max/finite_max are numeric_limits::max() and
// min/finite_min are numeric_limits::min().
#define instantiate_default_limit(type)                 \
  template <>                                           \
  struct Limits<type> {                                 \
    static constexpr constant type max =                \
        metal::numeric_limits<type>::max();             \
    static constexpr constant type finite_max =         \
        metal::numeric_limits<type>::max();             \
    static constexpr constant type min =                \
        metal::numeric_limits<type>::min();             \
    static constexpr constant type finite_min =         \
        metal::numeric_limits<type>::min();             \
  };
40
49
// Declares a Limits<type> specialization for types that have an infinity:
// max/min are +/-numeric_limits<type>::infinity(), while finite_max and
// finite_min are the largest-magnitude finite values,
// +/-numeric_limits<type>::max().
#define instantiate_float_limit(type)                   \
  template <>                                           \
  struct Limits<type> {                                 \
    static constexpr constant type max =                \
        metal::numeric_limits<type>::infinity();        \
    static constexpr constant type finite_max =         \
        metal::numeric_limits<type>::max();             \
    static constexpr constant type min =                \
        -metal::numeric_limits<type>::infinity();       \
    static constexpr constant type finite_min =         \
        -metal::numeric_limits<type>::max();            \
  };
62
66
67template <>
68struct Limits<bool> {
69 static constexpr constant bool max = true;
70 static constexpr constant bool min = false;
71};
72
73template <>
75 static constexpr constant complex64_t max = complex64_t(
76 metal::numeric_limits<float>::infinity(),
77 metal::numeric_limits<float>::infinity());
78 static constexpr constant complex64_t min = complex64_t(
79 -metal::numeric_limits<float>::infinity(),
80 -metal::numeric_limits<float>::infinity());
81};
82
84// Indexing utils
86
// Requests full unrolling (clang "loop unroll(full)") of the loop written
// directly after it, e.g.
//   MLX_MTL_PRAGMA_UNROLL
//   for (int i = 0; i < N; ++i) { ... }
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
88
90// Single Array with generic dims
91
92template <typename StrideT, typename IdxT = StrideT>
93METAL_FUNC IdxT elem_to_loc(
94 uint elem,
95 constant const int* shape,
96 constant const StrideT* strides,
97 int ndim) {
98 IdxT loc = 0;
99 for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
100 loc += (elem % shape[i]) * IdxT(strides[i]);
101 elem /= shape[i];
102 }
103 return loc;
104}
105
106template <typename StrideT, typename IdxT = StrideT>
107METAL_FUNC IdxT elem_to_loc(
108 StrideT elem,
109 constant const int* shape,
110 constant const StrideT* strides,
111 int ndim) {
112 IdxT loc = 0;
113 for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
114 loc += (elem % shape[i]) * IdxT(strides[i]);
115 elem /= shape[i];
116 }
117 return loc;
118}
119
120// Non templated version to handle arbitrary dims
121template <typename StrideT, typename IdxT = StrideT>
122METAL_FUNC IdxT elem_to_loc(
123 uint3 elem,
124 constant const int* shape,
125 constant const StrideT* strides,
126 int ndim) {
127 IdxT loc =
128 elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);
129 for (int d = ndim - 3; d >= 0; --d) {
130 loc += (elem.z % shape[d]) * IdxT(strides[d]);
131 elem.z /= shape[d];
132 }
133 return loc;
134}
135
137// Single Array with fixed N dims
138
139template <typename StrideT, typename IdxT = StrideT>
140METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const StrideT& stride) {
141 return elem * IdxT(stride);
142}
143
144template <typename StrideT, typename IdxT = StrideT>
145METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const StrideT strides[2]) {
146 return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]);
147}
148
149template <typename StrideT, typename IdxT = StrideT>
150METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const StrideT strides[3]) {
151 return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) +
152 elem.z * IdxT(strides[0]);
153}
154
156// Multiple Arrays with generic dims
157
158template <typename StrideT, typename IdxT = StrideT>
159METAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd(
160 uint3 elem,
161 constant const int* shape,
162 constant const StrideT* a_strides,
163 constant const StrideT* b_strides,
164 int ndim) {
165 vec<IdxT, 2> loc = {
166 IdxT(
167 elem.x * IdxT(a_strides[ndim - 1]) +
168 IdxT(elem.y) * IdxT(a_strides[ndim - 2])),
169 IdxT(
170 elem.x * IdxT(b_strides[ndim - 1]) +
171 elem.y * IdxT(b_strides[ndim - 2]))};
172 for (int d = ndim - 3; d >= 0; --d) {
173 uint l = elem.z % shape[d];
174 loc.x += l * IdxT(a_strides[d]);
175 loc.y += l * IdxT(b_strides[d]);
176 elem.z /= shape[d];
177 }
178 return loc;
179}
180
181template <typename IdxT = size_t>
182METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
183 uint3 elem,
184 constant const int* shape,
185 constant const size_t* a_strides,
186 constant const size_t* b_strides,
187 constant const size_t* c_strides,
188 int ndim) {
189 vec<IdxT, 3> loc = {
190 elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]),
191 elem.x * IdxT(b_strides[ndim - 1]) + elem.y * IdxT(b_strides[ndim - 2]),
192 elem.x * IdxT(c_strides[ndim - 1]) + elem.y * IdxT(c_strides[ndim - 2])};
193 for (int d = ndim - 3; d >= 0; --d) {
194 uint l = elem.z % shape[d];
195 loc.x += l * IdxT(a_strides[d]);
196 loc.y += l * IdxT(b_strides[d]);
197 loc.z += l * IdxT(c_strides[d]);
198 elem.z /= shape[d];
199 }
200 return loc;
201}
202
204// Elem to loc in a loop utils
206
207template <int DIM, typename OffsetT = size_t, bool General = true>
209 int dim;
210 LoopedElemToLoc<DIM - 1, OffsetT, General> inner_looper;
211 OffsetT offset{0};
212 int index{0};
213
215
216 void next(const constant int* shape, const constant size_t* strides) {
217 if (dim == 0) {
218 return;
219 }
220 index++;
221 offset += OffsetT(strides[dim - 1]);
222 if (index >= shape[dim - 1]) {
223 index = 0;
224 inner_looper.next(shape, strides);
225 offset = inner_looper.offset;
226 }
227 }
228
229 void next(int n, const constant int* shape, const constant size_t* strides) {
230 if (dim == 0) {
231 return;
232 }
233 index += n;
234 offset += n * OffsetT(strides[dim - 1]);
235
236 if (index >= shape[dim - 1]) {
237 int extra = index - shape[dim - 1];
238 if (extra >= shape[dim - 1]) {
239 inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
240 extra = extra % shape[dim - 1];
241 } else {
242 inner_looper.next(shape, strides);
243 }
244 index = 0;
245 offset = inner_looper.offset;
246 if (extra > 0) {
247 next(extra, shape, strides);
248 }
249 }
250 }
251
252 OffsetT location() {
253 return offset;
254 }
255};
256
257template <typename OffsetT>
258struct LoopedElemToLoc<1, OffsetT, true> {
259 int dim;
260 OffsetT offset{0};
261 uint index{0};
262
264
265 void next(const constant int* shape, const constant size_t* strides) {
266 index++;
267 if (dim > 1) {
268 offset = elem_to_loc<size_t, OffsetT>(index, shape, strides, dim);
269 } else {
270 offset += OffsetT(strides[0]);
271 }
272 }
273
274 void next(int n, const constant int* shape, const constant size_t* strides) {
275 index += n;
276 if (dim > 1) {
277 offset = elem_to_loc<size_t, OffsetT>(index, shape, strides, dim);
278 } else {
279 offset = index * OffsetT(strides[0]);
280 }
281 }
282
283 OffsetT location() {
284 return offset;
285 }
286};
287
288template <typename OffsetT>
289struct LoopedElemToLoc<1, OffsetT, false> {
290 OffsetT offset{0};
291
293
294 void next(const constant int*, const constant size_t* strides) {
295 offset += OffsetT(strides[0]);
296 }
297
298 void next(int n, const constant int*, const constant size_t* strides) {
299 offset += n * OffsetT(strides[0]);
300 }
301
302 OffsetT location() {
303 return offset;
304 }
305};
306
308// Calculation utils
310
// Compute ceil((float)N / (float)M) with integer arithmetic only.
//
// The classic (N + M - 1) / M overflows when N is close to the maximum of its
// type. It also mis-rounds negative N whose division is exact (-4 / 2 would
// give -1). Working from the truncated quotient and the remainder avoids both
// problems. M must be non-zero.
template <typename T, typename U>
inline T ceildiv(T N, U M) {
  auto q = N / M;
  auto r = N % M;
  // Truncating division already rounded up when the exact quotient is
  // negative (remainder and divisor of opposite sign); only bump q when the
  // remainder is non-zero and has the same sign as the divisor.
  bool round_up = (r != 0) && ((r > 0) == (M > 0));
  return q + (round_up ? 1 : 0);
}
316
317// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
318inline float log1p(float x) {
319 float xp1 = 1.0f + x;
320 if (xp1 == Limits<float>::max) {
321 return Limits<float>::max;
322 }
323 if (xp1 == 1.0f) {
324 return x;
325 }
326
327 return x * (metal::log(xp1) / (xp1 - 1.0f));
328}
329
331 float xp1 = 1.0f + static_cast<float>(x);
332 if (xp1 == Limits<float>::max) {
334 }
335 if (xp1 == 1.0f) {
336 return x;
337 }
338
339 return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
340}
341
343// SIMD shuffle ops
345
346inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
347 return as_type<uint64_t>(
348 metal::simd_shuffle_down(as_type<uint2>(data), delta));
349}
350
351inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
352 return as_type<int64_t>(
353 metal::simd_shuffle_down(as_type<uint2>(data), delta));
354}
355
356inline bool simd_shuffle_down(bool data, uint16_t delta) {
357 return simd_shuffle_down(static_cast<uint32_t>(data), delta);
358}
359
360inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) {
361 return complex64_t(
362 simd_shuffle_down(data.real, delta), simd_shuffle_down(data.imag, delta));
363}
364
365inline uint64_t simd_shuffle_up(uint64_t data, uint16_t delta) {
366 return as_type<uint64_t>(metal::simd_shuffle_up(as_type<uint2>(data), delta));
367}
368
369inline int64_t simd_shuffle_up(int64_t data, uint16_t delta) {
370 return as_type<int64_t>(metal::simd_shuffle_up(as_type<uint2>(data), delta));
371}
372
373inline bool simd_shuffle_up(bool data, uint16_t delta) {
374 return simd_shuffle_up(static_cast<uint32_t>(data), delta);
375}
376
377inline complex64_t simd_shuffle_up(complex64_t data, uint16_t delta) {
378 return complex64_t(
379 simd_shuffle_up(data.real, delta), simd_shuffle_up(data.imag, delta));
380}
381
382inline uint64_t
383simd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta) {
384 return as_type<uint64_t>(metal::simd_shuffle_and_fill_up(
385 as_type<uint2>(data), as_type<uint2>(filling), delta));
386}
387
388inline int64_t
389simd_shuffle_and_fill_up(int64_t data, int64_t filling, uint16_t delta) {
390 return as_type<int64_t>(metal::simd_shuffle_and_fill_up(
391 as_type<uint2>(data), as_type<uint2>(filling), delta));
392}
393
394inline bool simd_shuffle_and_fill_up(bool data, bool filling, uint16_t delta) {
396 static_cast<uint32_t>(data), static_cast<uint32_t>(filling), delta);
397}
398
400 complex64_t data,
401 complex64_t filling,
402 uint16_t delta) {
403 return complex64_t(
404 simd_shuffle_and_fill_up(data.real, filling.real, delta),
405 simd_shuffle_and_fill_up(data.imag, filling.imag, delta));
406}
407
408inline uint64_t simd_shuffle(uint64_t data, uint16_t lane) {
409 return as_type<uint64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));
410}
411
412inline int64_t simd_shuffle(int64_t data, uint16_t lane) {
413 return as_type<int64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));
414}
415
416inline bool simd_shuffle(bool data, uint16_t lane) {
417 return simd_shuffle(static_cast<uint32_t>(data), lane);
418}
419
420inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) {
421 return complex64_t(
422 simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane));
423}
424
425// std::conditional is not included with Metal
426template <bool condition, typename T, typename U>
428 using type = U;
429};
430
431template <typename T, typename U>
432struct ConditionalType<true, T, U> {
433 using type = T;
434};
BufferHolder * next
Definition allocator.h:38
struct _MLX_BFloat16 bfloat16_t
Definition bf16.h:251
#define instantiate_float_limit(type)
Definition utils.h:50
METAL_FUNC IdxT elem_to_loc(uint elem, constant const int *shape, constant const StrideT *strides, int ndim)
Definition utils.h:93
float log1p(float x)
Definition utils.h:318
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const StrideT strides[2])
Definition utils.h:145
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const StrideT strides[3])
Definition utils.h:150
METAL_FUNC vec< IdxT, 3 > elem_to_loc_3_nd(uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, constant const size_t *c_strides, int ndim)
Definition utils.h:182
METAL_FUNC vec< IdxT, 2 > elem_to_loc_2_nd(uint3 elem, constant const int *shape, constant const StrideT *a_strides, constant const StrideT *b_strides, int ndim)
Definition utils.h:159
T ceildiv(T N, U M)
Compute ceil((float)N/(float)M)
Definition utils.h:313
#define instantiate_default_limit(type)
Definition utils.h:30
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const StrideT &stride)
Definition utils.h:140
half float16_t
Definition utils.h:16
METAL_FUNC bfloat16_t simd_shuffle_and_fill_up(bfloat16_t data, bfloat16_t filling_data, ushort delta, ushort modulo)
Definition bf16_math.h:377
METAL_FUNC bfloat16_t simd_shuffle(bfloat16_t data, ushort simd_lane_id)
Definition bf16_math.h:377
METAL_FUNC bfloat16_t log(bfloat16_t x)
Definition bf16_math.h:232
METAL_FUNC bfloat16_t simd_shuffle_down(bfloat16_t data, ushort delta)
Definition bf16_math.h:377
METAL_FUNC bfloat16_t simd_shuffle_up(bfloat16_t data, ushort delta)
Definition bf16_math.h:377
Definition bf16.h:48
T type
Definition utils.h:433
Definition utils.h:427
U type
Definition utils.h:428
Definition utils.h:23
static const constant U max
Definition utils.h:24
static const constant U finite_max
Definition utils.h:26
static const constant U min
Definition utils.h:25
static const constant U finite_min
Definition utils.h:27
void next(const constant int *, const constant size_t *strides)
Definition utils.h:294
LoopedElemToLoc(int)
Definition utils.h:292
OffsetT location()
Definition utils.h:302
void next(int n, const constant int *, const constant size_t *strides)
Definition utils.h:298
OffsetT location()
Definition utils.h:283
int dim
Definition utils.h:259
void next(int n, const constant int *shape, const constant size_t *strides)
Definition utils.h:274
LoopedElemToLoc(int dim)
Definition utils.h:263
void next(const constant int *shape, const constant size_t *strides)
Definition utils.h:265
Definition utils.h:208
void next(const constant int *shape, const constant size_t *strides)
Definition utils.h:216
LoopedElemToLoc(int dim)
Definition utils.h:214
void next(int n, const constant int *shape, const constant size_t *strides)
Definition utils.h:229
LoopedElemToLoc< DIM - 1, OffsetT, General > inner_looper
Definition utils.h:210
OffsetT location()
Definition utils.h:252
int index
Definition utils.h:212
OffsetT offset
Definition utils.h:211
int dim
Definition utils.h:209
Definition complex.h:20
float imag
Definition complex.h:22
float real
Definition complex.h:21