MLX
Loading...
Searching...
No Matches
utils.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
3#pragma once
4
5#include <metal_math>
9
10typedef half float16_t;
11
13// Type limits utils
15
16template <typename U>
17struct Limits {
18 static const constant U max = metal::numeric_limits<U>::max();
19 static const constant U min = metal::numeric_limits<U>::min();
20 static const constant U finite_max = metal::numeric_limits<U>::max();
21 static const constant U finite_min = metal::numeric_limits<U>::min();
22};
23
// Emits an explicit specialisation Limits<type> whose four members all come
// from metal::numeric_limits<type>: max and finite_max are max(), min and
// finite_min are min().
#define instantiate_default_limit(type)                                        \
  template <>                                                                  \
  struct Limits<type> {                                                        \
    static constexpr constant type min = metal::numeric_limits<type>::min();  \
    static constexpr constant type finite_min =                                \
        metal::numeric_limits<type>::min();                                    \
    static constexpr constant type max = metal::numeric_limits<type>::max();  \
    static constexpr constant type finite_max =                                \
        metal::numeric_limits<type>::max();                                    \
  };
34
43
// Emits an explicit specialisation Limits<type> for types that have an
// infinity: max / min are +infinity / -infinity, while finite_max /
// finite_min are numeric_limits<type>::max() and its negation.
#define instantiate_float_limit(type)                \
  template <>                                        \
  struct Limits<type> {                              \
    static constexpr constant type finite_max =      \
        metal::numeric_limits<type>::max();          \
    static constexpr constant type finite_min =      \
        -metal::numeric_limits<type>::max();         \
    static constexpr constant type max =             \
        metal::numeric_limits<type>::infinity();     \
    static constexpr constant type min =             \
        -metal::numeric_limits<type>::infinity();    \
  };
56
60
61template <>
62struct Limits<bool> {
63 static constexpr constant bool max = true;
64 static constexpr constant bool min = false;
65};
66
68// Indexing utils
70
// Expands to a clang loop pragma that requests full unrolling; meant to be
// placed directly before the loop it should apply to.
71#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
72
74// Single Array with generic dims
75
76template <typename stride_t>
77METAL_FUNC stride_t elem_to_loc(
78 uint elem,
79 device const int* shape,
80 device const stride_t* strides,
81 int ndim) {
82 stride_t loc = 0;
83 for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
84 loc += (elem % shape[i]) * strides[i];
85 elem /= shape[i];
86 }
87 return loc;
88}
89
90template <typename stride_t>
91METAL_FUNC stride_t elem_to_loc(
92 uint elem,
93 constant const int* shape,
94 constant const stride_t* strides,
95 int ndim) {
96 stride_t loc = 0;
97 for (int i = ndim - 1; i >= 0 && elem > 0; --i) {
98 loc += (elem % shape[i]) * strides[i];
99 elem /= shape[i];
100 }
101 return loc;
102}
103
104// Non templated version to handle arbitrary dims
105template <typename stride_t>
106METAL_FUNC stride_t elem_to_loc(
107 uint3 elem,
108 constant const int* shape,
109 constant const stride_t* strides,
110 int ndim) {
111 stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
112 for (int d = ndim - 3; d >= 0; --d) {
113 loc += (elem.z % shape[d]) * strides[d];
114 elem.z /= shape[d];
115 }
116 return loc;
117}
118
120// Single Array with fixed N dims
121
122template <typename stride_t>
123METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t& stride) {
124 return elem * stride;
125}
126
127template <typename stride_t>
128METAL_FUNC stride_t
129elem_to_loc_2(uint2 elem, constant const stride_t strides[2]) {
130 return elem.x * strides[1] + elem.y * strides[0];
131}
132
133template <typename stride_t>
134METAL_FUNC stride_t
135elem_to_loc_3(uint3 elem, constant const stride_t strides[3]) {
136 return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
137}
138
139template <int NDIM>
140METAL_FUNC size_t elem_to_loc_nd(
141 uint elem,
142 device const int* shape,
143 device const size_t* strides) {
144 size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
145
147 for (int d = NDIM - 2; d >= 0; --d) {
148 elem /= shape[d + 1];
149 loc += (elem % shape[d]) * strides[d];
150 }
151
152 return loc;
153}
154
155template <int NDIM>
156METAL_FUNC size_t elem_to_loc_nd(
157 uint3 elem,
158 constant const int shape[NDIM],
159 constant const size_t strides[NDIM]) {
160 size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
161 for (int d = NDIM - 3; d >= 0; --d) {
162 loc += (elem.z % shape[d]) * strides[d];
163 elem.z /= shape[d];
164 }
165 return loc;
166}
167
168template <int NDIM>
169METAL_FUNC int64_t elem_to_loc_nd(
170 uint elem,
171 constant const int shape[NDIM],
172 constant const int64_t strides[NDIM]) {
173 int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
174
176 for (int d = NDIM - 2; d >= 0; --d) {
177 elem /= shape[d + 1];
178 loc += (elem % shape[d]) * strides[d];
179 }
180
181 return loc;
182}
183
184template <int NDIM>
185METAL_FUNC int64_t elem_to_loc_nd(
186 uint3 elem,
187 constant const int shape[NDIM],
188 constant const int64_t strides[NDIM]) {
189 int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
190 for (int d = NDIM - 3; d >= 0; --d) {
191 loc += (elem.z % shape[d]) * strides[d];
192 elem.z /= shape[d];
193 }
194 return loc;
195}
196
198// Multiple Arrays with generic dims
199
200METAL_FUNC uint2 elem_to_loc_2_nd(
201 uint3 elem,
202 constant const int* shape,
203 constant const size_t* a_strides,
204 constant const size_t* b_strides,
205 int ndim) {
206 uint2 loc = {
207 static_cast<uint>(
208 elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
209 static_cast<uint>(
210 elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
211 for (int d = ndim - 3; d >= 0; --d) {
212 uint l = elem.z % shape[d];
213 loc.x += l * a_strides[d];
214 loc.y += l * b_strides[d];
215 elem.z /= shape[d];
216 }
217 return loc;
218}
219
220METAL_FUNC uint3 elem_to_loc_3_nd(
221 uint3 elem,
222 constant const int* shape,
223 constant const size_t* a_strides,
224 constant const size_t* b_strides,
225 constant const size_t* c_strides,
226 int ndim) {
227 uint3 loc = {
228 static_cast<uint>(
229 elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
230 static_cast<uint>(
231 elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]),
232 static_cast<uint>(
233 elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])};
234 for (int d = ndim - 3; d >= 0; --d) {
235 uint l = elem.z % shape[d];
236 loc.x += l * a_strides[d];
237 loc.y += l * b_strides[d];
238 loc.z += l * c_strides[d];
239 elem.z /= shape[d];
240 }
241 return loc;
242}
243
245// Multiple Arrays with fixed N dims
246
247template <int NDIM>
248METAL_FUNC uint2 elem_to_loc_2_nd(
249 uint3 elem,
250 constant const int shape[NDIM],
251 constant const size_t a_strides[NDIM],
252 constant const size_t b_strides[NDIM]) {
253 uint2 loc = {
254 static_cast<uint>(
255 elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
256 static_cast<uint>(
257 elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
258 for (int d = NDIM - 3; d >= 0; --d) {
259 uint l = elem.z % shape[d];
260 loc.x += l * a_strides[d];
261 loc.y += l * b_strides[d];
262 elem.z /= shape[d];
263 }
264 return loc;
265}
266
267template <int NDIM>
268METAL_FUNC uint3 elem_to_loc_3_nd(
269 uint3 elem,
270 constant const int shape[NDIM],
271 constant const size_t a_strides[NDIM],
272 constant const size_t b_strides[NDIM],
273 constant const size_t c_strides[NDIM]) {
274 uint3 loc = {
275 static_cast<uint>(
276 elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
277 static_cast<uint>(
278 elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
279 static_cast<uint>(
280 elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
281 for (int d = NDIM - 3; d >= 0; --d) {
282 uint l = elem.z % shape[d];
283 loc.x += l * a_strides[d];
284 loc.y += l * b_strides[d];
285 loc.z += l * c_strides[d];
286 elem.z /= shape[d];
287 }
288 return loc;
289}
290
292// Calculation utils
294
// Integer ceiling division: returns the smallest q such that q * M >= N.
// M must be non-zero.
inline size_t ceildiv(size_t N, size_t M) {
  // Computed as 1 + (N - 1) / M so that no intermediate exceeds N. The
  // textbook form (N + M - 1) / M wraps around when N + M - 1 does not fit
  // in size_t and then yields a quotient that is far too small.
  return N == 0 ? 0 : 1 + (N - 1) / M;
}
299
300// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
301inline float log1p(float x) {
302 float xp1 = 1.0f + x;
303 if (xp1 == Limits<float>::max) {
304 return Limits<float>::max;
305 }
306 if (xp1 == 1.0f) {
307 return x;
308 }
309
310 return x * (metal::log(xp1) / (xp1 - 1.0f));
311}
312
314 float xp1 = 1.0f + static_cast<float>(x);
315 if (xp1 == Limits<float>::max) {
317 }
318 if (xp1 == 1.0f) {
319 return x;
320 }
321
322 return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
323}
324
326// SIMD shuffle ops
328
329inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
330 return as_type<uint64_t>(
331 metal::simd_shuffle_down(as_type<uint2>(data), delta));
332}
333
334inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
335 return as_type<int64_t>(
336 metal::simd_shuffle_down(as_type<uint2>(data), delta));
337}
338
339inline bool simd_shuffle_down(bool data, uint16_t delta) {
340 return simd_shuffle_down(static_cast<uint32_t>(data), delta);
341}
struct _MLX_BFloat16 bfloat16_t
Definition bf16.h:257
#define MLX_MTL_PRAGMA_UNROLL
Definition utils.h:71
METAL_FUNC stride_t elem_to_loc_1(uint elem, constant const stride_t &stride)
Definition utils.h:123
#define instantiate_float_limit(type)
Definition utils.h:44
float log1p(float x)
Definition utils.h:301
METAL_FUNC stride_t elem_to_loc_3(uint3 elem, constant const stride_t strides[3])
Definition utils.h:135
METAL_FUNC stride_t elem_to_loc(uint elem, device const int *shape, device const stride_t *strides, int ndim)
Definition utils.h:77
METAL_FUNC uint2 elem_to_loc_2_nd(uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, int ndim)
Definition utils.h:200
size_t ceildiv(size_t N, size_t M)
Compute ceil((float)N/(float)M)
Definition utils.h:296
METAL_FUNC uint3 elem_to_loc_3_nd(uint3 elem, constant const int *shape, constant const size_t *a_strides, constant const size_t *b_strides, constant const size_t *c_strides, int ndim)
Definition utils.h:220
METAL_FUNC size_t elem_to_loc_nd(uint elem, device const int *shape, device const size_t *strides)
Definition utils.h:140
#define instantiate_default_limit(type)
Definition utils.h:24
half float16_t
Definition utils.h:10
METAL_FUNC stride_t elem_to_loc_2(uint2 elem, constant const stride_t strides[2])
Definition utils.h:129
METAL_FUNC bfloat16_t log(bfloat16_t x)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t simd_shuffle_down(bfloat16_t data, ushort delta)
Definition bf16_math.h:391
Definition bf16.h:54
Definition utils.h:17
static const constant U max
Definition utils.h:18
static const constant U finite_max
Definition utils.h:20
static const constant U min
Definition utils.h:19
static const constant U finite_min
Definition utils.h:21