MLX
Loading...
Searching...
No Matches
utils.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <metal_math>
9
// Short alias so kernels can spell the half-precision type as float16_t.
using float16_t = half;
11
13// Type limits utils
15
/// Numeric bounds used by kernels for element type U.
///
/// This primary template serves every U without an explicit specialization
/// (presumably the integer types passed to instantiate_default_limit — confirm
/// against the instantiation list). The "infinite" bounds (max/min) and the
/// finite bounds (finite_max/finite_min) are both the
/// metal::numeric_limits<U>::max()/min() extremes.
///
/// NOTE: for floating-point U, numeric_limits<U>::min() is the smallest
/// positive normal value, not the lowest value. Floating-point types should
/// therefore use a specialization such as instantiate_float_limit.
template <typename U>
struct Limits {
  // constexpr (instead of plain const) matches every specialization in this
  // file. It also makes the in-class initializers valid for non-integral
  // literal types, which `static const` does not allow.
  static constexpr constant U max = metal::numeric_limits<U>::max();
  static constexpr constant U min = metal::numeric_limits<U>::min();
  static constexpr constant U finite_max = metal::numeric_limits<U>::max();
  static constexpr constant U finite_min = metal::numeric_limits<U>::min();
};
23
// Declares an explicit specialization Limits<type> in which the infinite and
// finite bounds coincide: both are metal::numeric_limits<type>::min()/max().
#define instantiate_default_limit(type)                                    \
  template <>                                                              \
  struct Limits<type> {                                                    \
    static constexpr constant type min = metal::numeric_limits<type>::min(); \
    static constexpr constant type finite_min =                            \
        metal::numeric_limits<type>::min();                                \
    static constexpr constant type max = metal::numeric_limits<type>::max(); \
    static constexpr constant type finite_max =                            \
        metal::numeric_limits<type>::max();                                \
  };
34
43
// Declares an explicit specialization Limits<type> for a floating-point type.
// min/max are -/+ infinity. finite_min/finite_max are -/+ the largest finite
// value, metal::numeric_limits<type>::max().
#define instantiate_float_limit(type)             \
  template <>                                     \
  struct Limits<type> {                           \
    static constexpr constant type min =          \
        -metal::numeric_limits<type>::infinity(); \
    static constexpr constant type finite_min =   \
        -metal::numeric_limits<type>::max();      \
    static constexpr constant type max =          \
        metal::numeric_limits<type>::infinity();  \
    static constexpr constant type finite_max =   \
        metal::numeric_limits<type>::max();       \
  };
56
60
/// Limits for bool: false is the lowest value and true the highest.
///
/// bool has no infinities, so the finite bounds equal min/max. They are
/// declared so that Limits<bool> exposes the same members as the other
/// specializations. Without them, generic code that names
/// Limits<T>::finite_max or Limits<T>::finite_min fails to compile for bool.
template <>
struct Limits<bool> {
  static constexpr constant bool max = true;
  static constexpr constant bool min = false;
  static constexpr constant bool finite_max = true;
  static constexpr constant bool finite_min = false;
};
66
/// Limits for complex64_t. max sets both the real and imaginary parts to
/// +infinity; min sets both to -infinity. No finite_min/finite_max members are
/// declared for this type.
template <>
struct Limits<complex64_t> {
  static constexpr constant complex64_t min = complex64_t(
      -metal::numeric_limits<float>::infinity(),
      -metal::numeric_limits<float>::infinity());
  static constexpr constant complex64_t max = complex64_t(
      metal::numeric_limits<float>::infinity(),
      metal::numeric_limits<float>::infinity());
};
76
78// Indexing utils
80
// Emits `#pragma clang loop unroll(full)` for the loop that follows.
// _Pragma is used because a #pragma directive cannot appear in a macro body.
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
82
84// Single Array with generic dims
85
/// Maps a flat element index to a memory offset in a strided array.
///
/// Coordinates are peeled off from the last dimension (ndim - 1) towards the
/// first with `%` and `/`, so the last dimension varies fastest. Each
/// coordinate is scaled by its stride and the products are summed. The walk
/// stops as soon as the remaining index is zero, because every further
/// coordinate would also be zero.
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
    uint elem,
    constant const int* shape,
    constant const stride_t* strides,
    int ndim) {
  stride_t offset = 0;
  int dim = ndim - 1;
  while (dim >= 0 && elem > 0) {
    offset += (elem % shape[dim]) * strides[dim];
    elem /= shape[dim];
    dim--;
  }
  return offset;
}
99
/// Same mapping as the `uint elem` overload, but the flat index is passed as
/// stride_t. When stride_t is a 64-bit type, this lets indices wider than
/// 32 bits be decomposed.
///
/// Dimensions are visited from the last to the first. The loop exits early
/// once the remaining index is no longer positive.
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
    stride_t elem,
    constant const int* shape,
    constant const stride_t* strides,
    int ndim) {
  stride_t offset = 0;
  for (int dim = ndim - 1; dim >= 0; --dim) {
    if (!(elem > 0)) {
      break;
    }
    offset += (elem % shape[dim]) * strides[dim];
    elem /= shape[dim];
  }
  return offset;
}
113
// Overload for a 3-D grid position over a runtime number of dims (templated
// only on the stride type, not on a compile-time dim count):
//   elem.x -> coordinate along the last dim (strides[ndim - 1]),
//   elem.y -> coordinate along the second-to-last dim (strides[ndim - 2]),
//   elem.z -> flat index over the remaining leading ndim - 2 dims.
// NOTE(review): strides[ndim - 2] is read unconditionally, so this assumes
// ndim >= 2 — confirm at call sites.
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc(
    uint3 elem,
    constant const int* shape,
    constant const stride_t* strides,
    int ndim) {
  const int last = ndim - 1;
  stride_t loc = elem.x * strides[last] + elem.y * strides[last - 1];
  uint rest = elem.z;
  for (int d = last - 2; d >= 0; d--) {
    loc += (rest % shape[d]) * strides[d];
    rest /= shape[d];
  }
  return loc;
}
128
130// Single Array with fixed N dims
131
/// Offset of element `elem` in a 1-D array: the index scaled by its stride.
template <typename stride_t>
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) {
  const stride_t loc = elem * stride;
  return loc;
}
136
/// Offset of a 2-D position. elem.y is the coordinate along the first dim
/// (strides[0]); elem.x is the coordinate along the last dim (strides[1]).
template <typename stride_t>
METAL_FUNC stride_t
elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) {
  return elem.y * strides[0] + elem.x * strides[1];
}
142
/// Offset of a 3-D position. The vector components are in reverse dim order:
/// x pairs with the last stride and z with the first.
template <typename stride_t>
METAL_FUNC stride_t
elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
  return elem.x * strides[2] // last dim
      + elem.y * strides[1] // middle dim
      + elem.z * strides[0]; // first dim
}
148
150// Multiple Arrays with generic dims
151
/// Offsets of one grid position in two arrays (a and b) that share `shape`
/// but have independent strides.
///
/// The decomposition matches the uint3 elem_to_loc overload:
///   - elem.x indexes the last dim;
///   - elem.y indexes the second-to-last dim;
///   - elem.z is a flat index over the leading ndim - 2 dims.
/// NOTE(review): reads strides[ndim - 2] unconditionally, so it assumes
/// ndim >= 2.
///
/// Every coordinate is widened to ulong before it is multiplied by a stride.
/// With a 32-bit stride_t (e.g. int), the old `uint * stride` products were
/// computed in 32 bits and could wrap before reaching the 64-bit result. For
/// size_t / 64-bit strides the result is unchanged.
template <typename stride_t>
METAL_FUNC ulong2 elem_to_loc_2_nd(
    uint3 elem,
    constant const int* shape,
    constant const stride_t* a_strides,
    constant const stride_t* b_strides,
    int ndim) {
  const ulong x = elem.x;
  const ulong y = elem.y;
  ulong2 loc = {
      ulong(x * a_strides[ndim - 1] + y * a_strides[ndim - 2]),
      ulong(x * b_strides[ndim - 1] + y * b_strides[ndim - 2])};
  for (int d = ndim - 3; d >= 0; --d) {
    const ulong l = elem.z % shape[d];
    loc.x += l * a_strides[d];
    loc.y += l * b_strides[d];
    elem.z /= shape[d];
  }
  return loc;
}
170
/// Offsets of one grid position in three arrays (a, b, c) that share `shape`
/// but have independent size_t strides.
///
/// elem.x indexes the last dim and elem.y the second-to-last. elem.z is a flat
/// index over the leading ndim - 2 dims, peeled off from dim ndim - 3 down to
/// dim 0.
/// NOTE(review): assumes ndim >= 2.
METAL_FUNC ulong3 elem_to_loc_3_nd(
    uint3 elem,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const size_t* c_strides,
    int ndim) {
  const int inner = ndim - 1;
  const int outer = ndim - 2;
  ulong3 loc = {
      elem.x * a_strides[inner] + elem.y * a_strides[outer],
      elem.x * b_strides[inner] + elem.y * b_strides[outer],
      elem.x * c_strides[inner] + elem.y * c_strides[outer]};
  uint batch = elem.z;
  for (int d = ndim - 3; d >= 0; --d) {
    const uint coord = batch % shape[d];
    batch /= shape[d];
    loc += ulong3(
        coord * a_strides[d], coord * b_strides[d], coord * c_strides[d]);
  }
  return loc;
}
191
193// Elem to loc in a loop utils
195
/// Incrementally tracks the memory offset of a position that advances through
/// the last `dim` dimensions of an array. This avoids a full elem_to_loc per
/// step.
///
/// `index` is the coordinate along dimension dim - 1 and wraps at
/// shape[dim - 1]. On a wrap, the carry is passed to `inner_looper`, which
/// tracks the preceding dim - 1 dimensions. `offset` always equals
/// inner_looper.offset + index * strides[dim - 1].
template <int dim, typename offset_t = size_t>
struct looped_elem_to_loc {
  looped_elem_to_loc<dim - 1, offset_t> inner_looper;
  offset_t offset{0};
  int index{0};

  /// Advances by exactly one position.
  void next(const constant int* shape, const constant size_t* strides) {
    index++;
    offset += strides[dim - 1];

    if (index >= shape[dim - 1]) {
      index = 0;
      inner_looper.next(shape, strides);
      offset = inner_looper.offset;
    }
  }

  /// Advances by n positions (n is assumed to be non-negative).
  ///
  /// All whole wraps of dimension dim - 1 are carried into inner_looper in a
  /// single call. The previous version recursed once per wrap, and never
  /// terminated when shape[dim - 1] == 0 and n > 0. A zero-extent dimension
  /// has nothing to iterate, so this version returns without doing anything
  /// (and without dividing by zero).
  void next(int n, const constant int* shape, const constant size_t* strides) {
    const int extent = shape[dim - 1];
    if (extent == 0) {
      return;
    }

    index += n;
    if (index < extent) {
      // Still inside the current row of dimension dim - 1.
      offset += n * strides[dim - 1];
      return;
    }

    // Carry the number of completed wraps into the outer dims, then rebuild
    // the offset from the outer offset plus the remaining coordinate.
    inner_looper.next(index / extent, shape, strides);
    index %= extent;
    offset = inner_looper.offset + index * strides[dim - 1];
  }

  /// Returns the tracked offset; the arguments are unused at this level.
  offset_t
  location(offset_t, const constant int*, const constant size_t*, int) {
    return offset;
  }
};
233
/// Single tracked dimension: no wrap-around is tracked at this level, so
/// advancing just adds multiples of strides[0] to the offset.
template <typename offset_t>
struct looped_elem_to_loc<1, offset_t> {
  offset_t offset{0};

  void next(int n, const constant int*, const constant size_t* strides) {
    offset += n * strides[0];
  }

  void next(const constant int* shape, const constant size_t* strides) {
    next(1, shape, strides);
  }

  offset_t
  location(offset_t, const constant int*, const constant size_t*, int) {
    return offset;
  }
};
251
/// Zero tracked dimensions: stepping does nothing. location() recomputes the
/// offset from the flat index with elem_to_loc over all ndim dims.
template <typename offset_t>
struct looped_elem_to_loc<0, offset_t> {
  void next(const constant int*, const constant size_t*) {}
  void next(int, const constant int*, const constant size_t*) {}

  offset_t location(
      offset_t flat_index,
      const constant int* shape,
      const constant size_t* strides,
      int ndim) {
    return elem_to_loc(flat_index, shape, strides, ndim);
  }
};
265
267// Calculation utils
269
/// Integer ceiling of N / M, computed as (N + M - 1) / M.
/// Intended for positive M; it relies on N + M - 1 not overflowing.
template <typename T, typename U>
inline T ceildiv(T N, U M) {
  const auto padded = N + M - 1;
  return padded / M;
}
275
// log(1 + x) using Goldberg's correction, from "What Every Computer Scientist
// Should Know About Floating-Point Arithmetic":
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
// The log of the rounded sum u = 1 + x is divided by the rounded difference
// u - 1 and then multiplied by x. This compensates for the rounding error
// made when forming 1 + x.
inline float log1p(float x) {
  const float u = 1.0f + x;
  if (u == Limits<float>::max) {
    // u reached Limits<float>::max (presumably +inf via
    // instantiate_float_limit); return that value unchanged.
    return Limits<float>::max;
  }
  // When x is too small to change 1.0f, log1p(x) ~= x to working precision.
  return (u == 1.0f) ? x : x * (metal::log(u) / (u - 1.0f));
}
288
290 float xp1 = 1.0f + static_cast<float>(x);
291 if (xp1 == Limits<float>::max) {
293 }
294 if (xp1 == 1.0f) {
295 return x;
296 }
297
298 return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
299}
300
302// SIMD shuffle ops
304
// simd_shuffle_down overloads for types Metal does not shuffle natively:
//   - 64-bit integers are reinterpreted as a uint2 (two 32-bit halves);
//   - bool is widened to uint32_t;
//   - complex64_t shuffles its real and imaginary parts separately.
inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
  const uint2 halves = as_type<uint2>(data);
  return as_type<uint64_t>(metal::simd_shuffle_down(halves, delta));
}

inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
  const uint2 halves = as_type<uint2>(data);
  return as_type<int64_t>(metal::simd_shuffle_down(halves, delta));
}

inline bool simd_shuffle_down(bool data, uint16_t delta) {
  const uint32_t widened = static_cast<uint32_t>(data);
  return metal::simd_shuffle_down(widened, delta);
}

inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) {
  const float re = metal::simd_shuffle_down(data.real, delta);
  const float im = metal::simd_shuffle_down(data.imag, delta);
  return complex64_t(re, im);
}
323
// simd_shuffle_up overloads for types Metal does not shuffle natively:
// 64-bit integers go through uint2 halves, bool is widened to uint32_t, and
// complex64_t is shuffled component by component.
inline uint64_t simd_shuffle_up(uint64_t data, uint16_t delta) {
  const uint2 halves = as_type<uint2>(data);
  return as_type<uint64_t>(metal::simd_shuffle_up(halves, delta));
}

inline int64_t simd_shuffle_up(int64_t data, uint16_t delta) {
  const uint2 halves = as_type<uint2>(data);
  return as_type<int64_t>(metal::simd_shuffle_up(halves, delta));
}

inline bool simd_shuffle_up(bool data, uint16_t delta) {
  const uint32_t widened = static_cast<uint32_t>(data);
  return metal::simd_shuffle_up(widened, delta);
}

inline complex64_t simd_shuffle_up(complex64_t data, uint16_t delta) {
  const float re = metal::simd_shuffle_up(data.real, delta);
  const float im = metal::simd_shuffle_up(data.imag, delta);
  return complex64_t(re, im);
}
340
// 64-bit simd_shuffle_and_fill_up: both the data and the filling value are
// reinterpreted as uint2 so Metal's 32-bit lane shuffle can move each half.
inline uint64_t
simd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta) {
  const uint2 data_halves = as_type<uint2>(data);
  const uint2 fill_halves = as_type<uint2>(filling);
  return as_type<uint64_t>(
      metal::simd_shuffle_and_fill_up(data_halves, fill_halves, delta));
}

inline int64_t
simd_shuffle_and_fill_up(int64_t data, int64_t filling, uint16_t delta) {
  const uint2 data_halves = as_type<uint2>(data);
  const uint2 fill_halves = as_type<uint2>(filling);
  return as_type<int64_t>(
      metal::simd_shuffle_and_fill_up(data_halves, fill_halves, delta));
}
352
// bool: widen both values to uint32_t, shuffle, then narrow back to bool.
inline bool simd_shuffle_and_fill_up(bool data, bool filling, uint16_t delta) {
  return metal::simd_shuffle_and_fill_up(
      static_cast<uint32_t>(data), static_cast<uint32_t>(filling), delta);
}

// complex64_t: shuffle the real and imaginary parts independently, each
// filled from the matching component of `filling`.
inline complex64_t simd_shuffle_and_fill_up(
    complex64_t data,
    complex64_t filling,
    uint16_t delta) {
  const float re =
      metal::simd_shuffle_and_fill_up(data.real, filling.real, delta);
  const float im =
      metal::simd_shuffle_and_fill_up(data.imag, filling.imag, delta);
  return complex64_t(re, im);
}
366
// simd_shuffle (read from an explicit lane) overloads for types Metal does
// not shuffle natively: 64-bit integers via uint2 halves, bool via uint32_t,
// complex64_t per component.
inline uint64_t simd_shuffle(uint64_t data, uint16_t lane) {
  const uint2 halves = as_type<uint2>(data);
  return as_type<uint64_t>(metal::simd_shuffle(halves, lane));
}

inline int64_t simd_shuffle(int64_t data, uint16_t lane) {
  const uint2 halves = as_type<uint2>(data);
  return as_type<int64_t>(metal::simd_shuffle(halves, lane));
}

inline bool simd_shuffle(bool data, uint16_t lane) {
  const uint32_t widened = static_cast<uint32_t>(data);
  return metal::simd_shuffle(widened, lane);
}

inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) {
  const float re = metal::simd_shuffle(data.real, lane);
  const float im = metal::simd_shuffle(data.imag, lane);
  return complex64_t(re, im);
}
BufferHolder * next
Definition allocator.h:38
struct _MLX_BFloat16 bfloat16_t
Definition bf16.h:257
METAL_FUNC ulong2 elem_to_loc_2_nd(uint3 elem, constant const int *shape, constant const stride_t *a_strides, constant const stride_t *b_strides, int ndim)
Definition utils.h:153
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t &stride)
Definition utils.h:133
#define instantiate_float_limit(type)
Definition utils.h:44
float log1p(float x)
Definition utils.h:277
METAL_FUNC stride_t elem_to_loc_3(uint3 elem, constant const stride_t strides[3])
Definition utils.h:145
METAL_FUNC ulong3 elem_to_loc_3_nd(uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, constant const size_t *c_strides, int ndim)
Definition utils.h:171
T ceildiv(T N, U M)
Compute ceil((float)N/(float)M)
Definition utils.h:272
METAL_FUNC stride_t elem_to_loc(uint elem, constant const int *shape, constant const stride_t *strides, int ndim)
Definition utils.h:87
#define instantiate_default_limit(type)
Definition utils.h:24
half float16_t
Definition utils.h:10
METAL_FUNC stride_t elem_to_loc_2(uint2 elem, constant const stride_t strides[2])
Definition utils.h:139
METAL_FUNC bfloat16_t simd_shuffle_and_fill_up(bfloat16_t data, bfloat16_t filling_data, ushort delta, ushort modulo)
Definition bf16_math.h:391
METAL_FUNC bfloat16_t simd_shuffle(bfloat16_t data, ushort simd_lane_id)
Definition bf16_math.h:391
METAL_FUNC bfloat16_t log(bfloat16_t x)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t simd_shuffle_down(bfloat16_t data, ushort delta)
Definition bf16_math.h:391
METAL_FUNC bfloat16_t simd_shuffle_up(bfloat16_t data, ushort delta)
Definition bf16_math.h:391
std::vector< ptrdiff_t > stride_t
Definition pocketfft.h:103
Definition bf16.h:54
Definition utils.h:17
static const constant U max
Definition utils.h:18
static const constant U finite_max
Definition utils.h:20
static const constant U min
Definition utils.h:19
static const constant U finite_min
Definition utils.h:21
Definition complex.h:20
float imag
Definition complex.h:22
float real
Definition complex.h:21
void next(int, const constant int *, const constant size_t *)
Definition utils.h:255
offset_t location(offset_t idx, const constant int *shape, const constant size_t *strides, int ndim)
Definition utils.h:257
void next(const constant int *, const constant size_t *)
Definition utils.h:254
offset_t location(offset_t, const constant int *, const constant size_t *, int)
Definition utils.h:247
void next(const constant int *, const constant size_t *strides)
Definition utils.h:238
void next(int n, const constant int *, const constant size_t *strides)
Definition utils.h:242
Definition utils.h:197
void next(const constant int *shape, const constant size_t *strides)
Definition utils.h:202
offset_t offset
Definition utils.h:199
int index
Definition utils.h:200
looped_elem_to_loc< dim - 1, offset_t > inner_looper
Definition utils.h:198
offset_t location(offset_t, const constant int *, const constant size_t *, int)
Definition utils.h:229
void next(int n, const constant int *shape, const constant size_t *strides)
Definition utils.h:213