MLX
 
Loading...
Searching...
No Matches
utils.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <metal_math>
6
7// The correct bf16.h is included based on the metal version
8// by giving the correct path to -I during compilation
9// e.g. mlx/backend/metal/kernels/metal_3_0/ for Metal 3.0
10#include "bf16.h"
11
15
// Fixed-width style alias for Metal's 16-bit `half` floating-point type.
using float16_t = half;
17
19// Type limits utils
21
/// Fallback value limits for types without a dedicated Limits specialization.
///
/// All four members come from metal::numeric_limits<U>; `finite_max` equals
/// `max` and `finite_min` equals `min` here. (The instantiate_float_limit
/// specializations are where max/min become +/-infinity.)
///
/// Members are `static constexpr`, like every specialization in this file:
/// an in-class initializer on a `static const` data member is only valid for
/// integral/enum types, so a `const`-only declaration would fail to compile if
/// this primary template were ever instantiated for a non-integral U.
template <typename U>
struct Limits {
  static constexpr constant U max = metal::numeric_limits<U>::max();
  static constexpr constant U min = metal::numeric_limits<U>::min();
  static constexpr constant U finite_max = metal::numeric_limits<U>::max();
  static constexpr constant U finite_min = metal::numeric_limits<U>::min();
};
29
// Declares a full Limits<type> specialization whose values are taken directly
// from metal::numeric_limits<type>, so max == finite_max and
// min == finite_min (presumably used for the integer types).
#define instantiate_default_limit(type)           \
  template <>                                     \
  struct Limits<type> {                           \
    static constexpr constant type finite_max =   \
        metal::numeric_limits<type>::max();       \
    static constexpr constant type finite_min =   \
        metal::numeric_limits<type>::min();       \
    static constexpr constant type max =          \
        metal::numeric_limits<type>::max();       \
    static constexpr constant type min =          \
        metal::numeric_limits<type>::min();       \
  };
40
49
// Declares a Limits<type> specialization for a floating-point type:
// `max`/`min` are +/-infinity, while `finite_max`/`finite_min` are the
// largest finite value and its negation.
#define instantiate_float_limit(type)             \
  template <>                                     \
  struct Limits<type> {                           \
    static constexpr constant type finite_max =   \
        metal::numeric_limits<type>::max();       \
    static constexpr constant type finite_min =   \
        -metal::numeric_limits<type>::max();      \
    static constexpr constant type max =          \
        metal::numeric_limits<type>::infinity();  \
    static constexpr constant type min =          \
        -metal::numeric_limits<type>::infinity(); \
  };
62
66
/// Limits for bool: `false` is the minimum and `true` the maximum.
/// bool has no infinities, so `finite_max`/`finite_min` equal `max`/`min`.
/// They are declared so that generic code reading
/// Limits<T>::finite_max / finite_min also compiles for bool, as it does for
/// the numeric specializations.
template <>
struct Limits<bool> {
  static constexpr constant bool max = true;
  static constexpr constant bool min = false;
  static constexpr constant bool finite_max = true;
  static constexpr constant bool finite_min = false;
};
72
/// Limits for complex64_t: `max` has both components at +infinity and `min`
/// has both at -infinity. (The `struct Limits<complex64_t> {` line was
/// missing, leaving `template <>` followed directly by member declarations.)
template <>
struct Limits<complex64_t> {
  static constexpr constant complex64_t max = complex64_t(
      metal::numeric_limits<float>::infinity(),
      metal::numeric_limits<float>::infinity());
  static constexpr constant complex64_t min = complex64_t(
      -metal::numeric_limits<float>::infinity(),
      -metal::numeric_limits<float>::infinity());
};
82
84// Indexing utils
86
// Requests that clang fully unroll the loop immediately following the macro.
#define MLX_MTL_PRAGMA_UNROLL \
  _Pragma("clang loop unroll(full)")
88
90// Single Array with generic dims
91
/// Maps a flat (row-major) element index to a memory offset.
///
/// Peels coordinates off `elem` starting from the last axis: each axis
/// contributes (elem % shape[axis]) * strides[axis], and `elem` is then
/// divided by that axis' extent. The walk stops early once `elem` reaches 0,
/// since the remaining leading coordinates are all zero.
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
    IdxT elem,
    constant const int* shape,
    constant const int64_t* strides,
    int ndim) {
  IdxT loc = 0;
  int axis = ndim - 1;
  while (axis >= 0 && elem > 0) {
    auto coord = elem % shape[axis];
    loc += coord * IdxT(strides[axis]);
    elem /= shape[axis];
    --axis;
  }
  return loc;
}
105
// Memory offset for an element addressed by a 3-component position (templated
// on the offset type IdxT, handles any ndim >= 2):
//   x -> coordinate along axis ndim - 1,
//   y -> coordinate along axis ndim - 2,
//   z -> flattened index over the leading axes ndim - 3 .. 0, decomposed
//        from the innermost of those outward.
// strides[ndim - 2] is read unconditionally, so ndim must be at least 2.
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
    uint3 elem,
    constant const int* shape,
    constant const int64_t* strides,
    int ndim) {
  IdxT loc =
      elem.y * IdxT(strides[ndim - 2]) + elem.x * IdxT(strides[ndim - 1]);
  uint rest = elem.z;
  for (int axis = ndim - 3; axis >= 0; --axis) {
    loc += (rest % shape[axis]) * IdxT(strides[axis]);
    rest /= shape[axis];
  }
  return loc;
}
121
123// Single Array with fixed N dims
124
/// Offset of element `elem` along a single axis with the given stride.
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const int64_t& stride) {
  const IdxT step = IdxT(stride);
  return elem * step;
}
129
/// Offset of a 2-D position: `elem.y` indexes axis 0 (strides[0]) and
/// `elem.x` indexes axis 1 (strides[1]).
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const int64_t strides[2]) {
  return elem.y * IdxT(strides[0]) + elem.x * IdxT(strides[1]);
}
134
/// Offset of a 3-D position: `elem.z`, `elem.y`, `elem.x` index axes 0, 1, 2
/// with strides[0], strides[1], strides[2] respectively.
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const int64_t strides[3]) {
  IdxT loc = elem.z * IdxT(strides[0]);
  loc += elem.y * IdxT(strides[1]);
  loc += elem.x * IdxT(strides[2]);
  return loc;
}
140
142// Multiple Arrays with generic dims
143
/// Offsets of the same logical element in two arrays `a` and `b` that share
/// `shape` but have independent strides. Position encoding matches
/// elem_to_loc(uint3, ...): x -> axis ndim - 1, y -> axis ndim - 2, and z is
/// decomposed over axes ndim - 3 .. 0. Returns (offset_in_a, offset_in_b).
/// strides[ndim - 2] is read unconditionally, so ndim must be at least 2.
template <typename IdxT = int64_t>
METAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd(
    uint3 elem,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    int ndim) {
  // Both offsets use identical arithmetic. The a-offset previously cast
  // elem.y to IdxT on its own, which for a 32-bit signed IdxT made that one
  // term a signed (overflow-is-UB) multiply while the b-offset and the loop
  // below use unsigned arithmetic.
  vec<IdxT, 2> loc = {
      IdxT(
          elem.x * IdxT(a_strides[ndim - 1]) +
          elem.y * IdxT(a_strides[ndim - 2])),
      IdxT(
          elem.x * IdxT(b_strides[ndim - 1]) +
          elem.y * IdxT(b_strides[ndim - 2]))};
  for (int d = ndim - 3; d >= 0; --d) {
    uint l = elem.z % shape[d];
    loc.x += l * IdxT(a_strides[d]);
    loc.y += l * IdxT(b_strides[d]);
    elem.z /= shape[d];
  }
  return loc;
}
166
/// Offsets of the same logical element in three arrays `a`, `b`, `c` that
/// share `shape` but have independent strides. Position encoding matches
/// elem_to_loc(uint3, ...). Returns (offset_in_a, offset_in_b, offset_in_c).
/// strides[ndim - 2] is read unconditionally, so ndim must be at least 2.
template <typename IdxT = int64_t>
METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
    uint3 elem,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    constant const int64_t* c_strides,
    int ndim) {
  const int last = ndim - 1;
  const int second_last = ndim - 2;

  vec<IdxT, 3> loc;
  loc.x = IdxT(elem.x * IdxT(a_strides[last])) +
      IdxT(elem.y * IdxT(a_strides[second_last]));
  loc.y = IdxT(elem.x * IdxT(b_strides[last])) +
      IdxT(elem.y * IdxT(b_strides[second_last]));
  loc.z = IdxT(elem.x * IdxT(c_strides[last])) +
      IdxT(elem.y * IdxT(c_strides[second_last]));

  uint rest = elem.z;
  for (int axis = ndim - 3; axis >= 0; --axis) {
    const uint coord = rest % shape[axis];
    loc.x += coord * IdxT(a_strides[axis]);
    loc.y += coord * IdxT(b_strides[axis]);
    loc.z += coord * IdxT(c_strides[axis]);
    rest /= shape[axis];
  }
  return loc;
}
191
193// Elem to loc in a loop utils
195
/// Incrementally tracks the memory offset of a running element index over
/// axes 0 .. dim - 1 of an array, so a kernel can step through elements
/// without redoing the full index -> offset computation at every step.
///
/// This level owns axis `dim - 1`. When that axis wraps, it carries into
/// `inner_looper` (compile-time depth DIM - 1, runtime axis count dim - 1)
/// and restarts from the inner looper's offset. The recursion bottoms out in
/// a LoopedElemToLoc<1, OffsetT, General> specialization. A runtime `dim` of
/// 0 makes both next() overloads no-ops.
///
/// (The struct header and constructor were missing from this listing; the
/// constructor must forward `dim - 1` to `inner_looper`, whose
/// specializations are constructed from an int.)
template <int DIM, typename OffsetT = size_t, bool General = true>
struct LoopedElemToLoc {
  int dim;
  LoopedElemToLoc<DIM - 1, OffsetT, General> inner_looper;
  OffsetT offset{0};
  int index{0};

  LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}

  /// Advances one element along axis `dim - 1`. On wrap-around, resets this
  /// axis, advances the inner looper by one, and adopts its offset.
  void next(const constant int* shape, const constant int64_t* strides) {
    if (dim == 0) {
      return;
    }
    index++;
    offset += OffsetT(strides[dim - 1]);
    if (index >= shape[dim - 1]) {
      index = 0;
      inner_looper.next(shape, strides);
      offset = inner_looper.offset;
    }
  }

  /// Advances `n` elements along axis `dim - 1` (assumes n >= 0 — TODO
  /// confirm with callers). If the axis overflows, the number of complete
  /// wraps is pushed into the inner looper and any leftover steps are
  /// replayed from index 0 by a recursive call.
  void next(int n, const constant int* shape, const constant int64_t* strides) {
    if (dim == 0) {
      return;
    }
    index += n;
    offset += n * OffsetT(strides[dim - 1]);

    if (index >= shape[dim - 1]) {
      int extra = index - shape[dim - 1];
      if (extra >= shape[dim - 1]) {
        inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
        extra = extra % shape[dim - 1];
      } else {
        inner_looper.next(shape, strides);
      }
      index = 0;
      offset = inner_looper.offset;
      if (extra > 0) {
        next(extra, shape, strides);
      }
    }
  }

  /// Current memory offset of the tracked element.
  OffsetT location() {
    return offset;
  }
};
245
/// Base case of LoopedElemToLoc for General = true.
///
/// With dim > 1, each step recomputes the offset from the running `index`
/// via elem_to_loc over all `dim` axes. With dim <= 1, it simply moves along
/// strides[0]. (The constructor was missing from this listing; it only
/// records `dim`.)
template <typename OffsetT>
struct LoopedElemToLoc<1, OffsetT, true> {
  int dim;
  OffsetT offset{0};
  uint index{0};

  LoopedElemToLoc(int dim) : dim(dim) {}

  /// Advances by one element.
  void next(const constant int* shape, const constant int64_t* strides) {
    index++;
    if (dim > 1) {
      offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
    } else {
      offset += OffsetT(strides[0]);
    }
  }

  /// Advances by `n` elements.
  void next(int n, const constant int* shape, const constant int64_t* strides) {
    index += n;
    if (dim > 1) {
      offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
    } else {
      offset = index * OffsetT(strides[0]);
    }
  }

  /// Current memory offset of the tracked element.
  OffsetT location() {
    return offset;
  }
};
276
/// Base case of LoopedElemToLoc for General = false.
///
/// `shape` is ignored; the offset only advances by multiples of strides[0].
/// The int constructor argument is accepted (and ignored) so this type can
/// be built the same way as the other LoopedElemToLoc levels. (That
/// constructor was missing from this listing.)
template <typename OffsetT>
struct LoopedElemToLoc<1, OffsetT, false> {
  OffsetT offset{0};

  LoopedElemToLoc(int) {}

  /// Advances by one element.
  void next(const constant int*, const constant int64_t* strides) {
    offset += OffsetT(strides[0]);
  }

  /// Advances by `n` elements.
  void next(int n, const constant int*, const constant int64_t* strides) {
    offset += n * OffsetT(strides[0]);
  }

  /// Current memory offset of the tracked element.
  OffsetT location() {
    return offset;
  }
};
295
297// Calculation utils
299
/// Compute ceil((float)N/(float)M) with integer arithmetic, i.e. the number
/// of M-sized chunks needed to cover N items.
///
/// Assumes N >= 0 and M > 0: for negative N the truncating division does not
/// produce the mathematical ceiling, and N + M - 1 must not overflow the
/// promoted type. The result is converted back to T.
template <typename T, typename U>
inline T ceildiv(T N, U M) {
  const auto padded = N + M - 1;
  return padded / M;
}
305
// log(1 + x) computed accurately for small |x| using Goldberg's formulation:
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
// The rounding error in 1 + x is compensated by dividing log(1 + x) by the
// *computed* (xp1 - 1) and multiplying by x.
inline float log1p(float x) {
  const float xp1 = 1.0f + x;
  if (xp1 == 1.0f) {
    // x is lost when added to 1, and log1p(x) ~= x at that scale.
    return x;
  }
  if (xp1 == Limits<float>::max) {
    // Presumably +inf (via instantiate_float_limit): avoids evaluating
    // inf / inf in the formula below.
    return Limits<float>::max;
  }
  const float scale = metal::log(xp1) / (xp1 - 1.0f);
  return x * scale;
}
318
// bfloat16 overload of log1p: the same Goldberg formulation as the float
// version, evaluated in float precision and rounded back to bfloat16.
// (The function header and the overflow-branch return were missing here.)
inline bfloat16_t log1p(bfloat16_t x) {
  float xp1 = 1.0f + static_cast<float>(x);
  if (xp1 == Limits<float>::max) {
    // This early return is required. If Limits<float>::max is +inf (as the
    // instantiate_float_limit specialization defines), falling through would
    // compute x * (log(inf) / inf) = NaN.
    return Limits<bfloat16_t>::max;
  }
  if (xp1 == 1.0f) {
    return x;
  }

  return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
}
330
332// SIMD shuffle ops
334
// simd_shuffle_down overloads for types Metal does not shuffle directly:
//   * 64-bit integers are reinterpreted as uint2 (two 32-bit halves),
//     shuffled, and reinterpreted back;
//   * bool is widened to uint32_t and converted back on return;
//   * complex64_t shuffles its real and imaginary parts separately.

inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
  uint2 halves = as_type<uint2>(data);
  return as_type<uint64_t>(metal::simd_shuffle_down(halves, delta));
}

inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
  uint2 halves = as_type<uint2>(data);
  return as_type<int64_t>(metal::simd_shuffle_down(halves, delta));
}

inline bool simd_shuffle_down(bool data, uint16_t delta) {
  uint32_t widened = static_cast<uint32_t>(data);
  return simd_shuffle_down(widened, delta);
}

inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) {
  auto re = simd_shuffle_down(data.real, delta);
  auto im = simd_shuffle_down(data.imag, delta);
  return complex64_t(re, im);
}
353
// simd_shuffle_up overloads for types Metal does not shuffle directly:
//   * 64-bit integers travel as a uint2 bit pattern;
//   * bool is widened to uint32_t and converted back on return;
//   * complex64_t shuffles its real and imaginary parts separately.

inline uint64_t simd_shuffle_up(uint64_t data, uint16_t delta) {
  uint2 halves = as_type<uint2>(data);
  return as_type<uint64_t>(metal::simd_shuffle_up(halves, delta));
}

inline int64_t simd_shuffle_up(int64_t data, uint16_t delta) {
  uint2 halves = as_type<uint2>(data);
  return as_type<int64_t>(metal::simd_shuffle_up(halves, delta));
}

inline bool simd_shuffle_up(bool data, uint16_t delta) {
  uint32_t widened = static_cast<uint32_t>(data);
  return simd_shuffle_up(widened, delta);
}

inline complex64_t simd_shuffle_up(complex64_t data, uint16_t delta) {
  auto re = simd_shuffle_up(data.real, delta);
  auto im = simd_shuffle_up(data.imag, delta);
  return complex64_t(re, im);
}
370
// 64-bit simd_shuffle_and_fill_up overloads: the value and the fill value are
// both reinterpreted as uint2 bit patterns, shuffled by Metal, and the result
// is reinterpreted back to the 64-bit type.

inline uint64_t
simd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta) {
  uint2 data_bits = as_type<uint2>(data);
  uint2 fill_bits = as_type<uint2>(filling);
  return as_type<uint64_t>(
      metal::simd_shuffle_and_fill_up(data_bits, fill_bits, delta));
}

inline int64_t
simd_shuffle_and_fill_up(int64_t data, int64_t filling, uint16_t delta) {
  uint2 data_bits = as_type<uint2>(data);
  uint2 fill_bits = as_type<uint2>(filling);
  return as_type<int64_t>(
      metal::simd_shuffle_and_fill_up(data_bits, fill_bits, delta));
}
382
// bool overload: both operands are widened to uint32_t for the shuffle, and
// the result is converted back to bool on return.
// (The `return simd_shuffle_and_fill_up(` line was missing from this listing.)
inline bool simd_shuffle_and_fill_up(bool data, bool filling, uint16_t delta) {
  return simd_shuffle_and_fill_up(
      static_cast<uint32_t>(data), static_cast<uint32_t>(filling), delta);
}

// complex64_t overload: real and imaginary parts are shuffled independently,
// each with the matching component of `filling`.
// (The function header was missing from this listing.)
inline complex64_t simd_shuffle_and_fill_up(
    complex64_t data,
    complex64_t filling,
    uint16_t delta) {
  return complex64_t(
      simd_shuffle_and_fill_up(data.real, filling.real, delta),
      simd_shuffle_and_fill_up(data.imag, filling.imag, delta));
}
396
// simd_shuffle (read from an arbitrary lane) overloads for types Metal does
// not shuffle directly:
//   * 64-bit integers travel as a uint2 bit pattern;
//   * bool is widened to uint32_t and converted back on return;
//   * complex64_t shuffles its real and imaginary parts separately.

inline uint64_t simd_shuffle(uint64_t data, uint16_t lane) {
  uint2 halves = as_type<uint2>(data);
  return as_type<uint64_t>(metal::simd_shuffle(halves, lane));
}

inline int64_t simd_shuffle(int64_t data, uint16_t lane) {
  uint2 halves = as_type<uint2>(data);
  return as_type<int64_t>(metal::simd_shuffle(halves, lane));
}

inline bool simd_shuffle(bool data, uint16_t lane) {
  uint32_t widened = static_cast<uint32_t>(data);
  return simd_shuffle(widened, lane);
}

inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) {
  auto re = simd_shuffle(data.real, lane);
  auto im = simd_shuffle(data.imag, lane);
  return complex64_t(re, im);
}
413
// std::conditional is not included with Metal, so this is a minimal stand-in:
// ConditionalType<condition, T, U>::type is T when `condition` is true and U
// otherwise. (The primary template's `struct ConditionalType {` line was
// missing from this listing.)
template <bool condition, typename T, typename U>
struct ConditionalType {
  using type = U;
};

template <typename T, typename U>
struct ConditionalType<true, T, U> {
  using type = T;
};
struct _MLX_BFloat16 bfloat16_t
Definition bf16.h:251
#define instantiate_float_limit(type)
Definition utils.h:50
float log1p(float x)
Definition utils.h:307
METAL_FUNC IdxT elem_to_loc(IdxT elem, constant const int *shape, constant const int64_t *strides, int ndim)
Definition utils.h:93
METAL_FUNC vec< IdxT, 3 > elem_to_loc_3_nd(uint3 elem, constant const int *shape, constant const int64_t *a_strides, constant const int64_t *b_strides, constant const int64_t *c_strides, int ndim)
Definition utils.h:168
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const int64_t &stride)
Definition utils.h:126
T ceildiv(T N, U M)
Compute ceil((float)N/(float)M)
Definition utils.h:302
METAL_FUNC vec< IdxT, 2 > elem_to_loc_2_nd(uint3 elem, constant const int *shape, constant const int64_t *a_strides, constant const int64_t *b_strides, int ndim)
Definition utils.h:145
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const int64_t strides[2])
Definition utils.h:131
#define instantiate_default_limit(type)
Definition utils.h:30
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const int64_t strides[3])
Definition utils.h:136
half float16_t
Definition utils.h:16
METAL_FUNC bfloat16_t simd_shuffle_and_fill_up(bfloat16_t data, bfloat16_t filling_data, ushort delta, ushort modulo)
Definition bf16_math.h:377
METAL_FUNC bfloat16_t simd_shuffle(bfloat16_t data, ushort simd_lane_id)
Definition bf16_math.h:377
METAL_FUNC bfloat16_t log(bfloat16_t x)
Definition bf16_math.h:232
METAL_FUNC bfloat16_t simd_shuffle_down(bfloat16_t data, ushort delta)
Definition bf16_math.h:377
METAL_FUNC bfloat16_t simd_shuffle_up(bfloat16_t data, ushort delta)
Definition bf16_math.h:377
T type
Definition utils.h:422
Definition utils.h:416
U type
Definition utils.h:417
static constexpr constant bool min
Definition utils.h:70
static constexpr constant bool max
Definition utils.h:69
static constexpr constant complex64_t min
Definition utils.h:78
static constexpr constant complex64_t max
Definition utils.h:75
Definition utils.h:23
static const constant U max
Definition utils.h:24
static const constant U finite_max
Definition utils.h:26
static const constant U min
Definition utils.h:25
static const constant U finite_min
Definition utils.h:27
LoopedElemToLoc(int)
Definition utils.h:281
void next(int n, const constant int *, const constant int64_t *strides)
Definition utils.h:287
OffsetT location()
Definition utils.h:291
void next(const constant int *, const constant int64_t *strides)
Definition utils.h:283
OffsetT offset
Definition utils.h:279
uint index
Definition utils.h:250
OffsetT offset
Definition utils.h:249
OffsetT location()
Definition utils.h:272
void next(int n, const constant int *shape, const constant int64_t *strides)
Definition utils.h:263
int dim
Definition utils.h:248
LoopedElemToLoc(int dim)
Definition utils.h:252
void next(const constant int *shape, const constant int64_t *strides)
Definition utils.h:254
LoopedElemToLoc(int dim)
Definition utils.h:203
LoopedElemToLoc< DIM - 1, OffsetT, General > inner_looper
Definition utils.h:199
void next(const constant int *shape, const constant int64_t *strides)
Definition utils.h:205
OffsetT location()
Definition utils.h:241
int index
Definition utils.h:201
OffsetT offset
Definition utils.h:200
void next(int n, const constant int *shape, const constant int64_t *strides)
Definition utils.h:218
int dim
Definition utils.h:198
Definition complex.h:20
float imag
Definition complex.h:22
float real
Definition complex.h:21