18 static const constant U
max = metal::numeric_limits<U>::max();
19 static const constant U
min = metal::numeric_limits<U>::min();
20 static const constant U
finite_max = metal::numeric_limits<U>::max();
21 static const constant U
finite_min = metal::numeric_limits<U>::min();
24#define instantiate_default_limit(type) \
26 struct Limits<type> { \
27 static constexpr constant type max = metal::numeric_limits<type>::max(); \
28 static constexpr constant type min = metal::numeric_limits<type>::min(); \
29 static constexpr constant type finite_max = \
30 metal::numeric_limits<type>::max(); \
31 static constexpr constant type finite_min = \
32 metal::numeric_limits<type>::min(); \
44#define instantiate_float_limit(type) \
46 struct Limits<type> { \
47 static constexpr constant type max = \
48 metal::numeric_limits<type>::infinity(); \
49 static constexpr constant type min = \
50 -metal::numeric_limits<type>::infinity(); \
51 static constexpr constant type finite_max = \
52 metal::numeric_limits<type>::max(); \
53 static constexpr constant type finite_min = \
54 -metal::numeric_limits<type>::max(); \
63 static constexpr constant
bool max =
true;
64 static constexpr constant
bool min =
false;
70 metal::numeric_limits<float>::infinity(),
71 metal::numeric_limits<float>::infinity());
73 -metal::numeric_limits<float>::infinity(),
74 -metal::numeric_limits<float>::infinity());
81#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
86template <
typename str
ide_t>
89 device
const int* shape,
90 device
const stride_t* strides,
93 for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
94 loc += (elem % shape[i]) * strides[i];
100template <
typename str
ide_t>
103 constant
const int* shape,
104 constant
const stride_t* strides,
107 for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
108 loc += (elem % shape[i]) * strides[i];
114template <
typename str
ide_t>
117 device
const int* shape,
118 device
const stride_t* strides,
121 for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
122 loc += (elem % shape[i]) * strides[i];
128template <
typename str
ide_t>
131 constant
const int* shape,
132 constant
const stride_t* strides,
135 for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
136 loc += (elem % shape[i]) * strides[i];
143template <
typename str
ide_t>
146 constant
const int* shape,
147 constant
const stride_t* strides,
149 stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
150 for (
int d = ndim - 3; d >= 0; --d) {
151 loc += (elem.z % shape[d]) * strides[d];
160template <
typename str
ide_t>
161METAL_FUNC stride_t
elem_to_loc_1(uint elem, constant
const stride_t& stride) {
162 return elem * stride;
165template <
typename str
ide_t>
168 return elem.x * strides[1] + elem.y * strides[0];
171template <
typename str
ide_t>
174 return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
180 device
const int* shape,
181 device
const size_t* strides) {
182 size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
185 for (
int d = NDIM - 2; d >= 0; --d) {
186 elem /= shape[d + 1];
187 loc += (elem % shape[d]) * strides[d];
196 constant
const int shape[NDIM],
197 constant
const size_t strides[NDIM]) {
198 size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
199 for (
int d = NDIM - 3; d >= 0; --d) {
200 loc += (elem.z % shape[d]) * strides[d];
209 constant
const int shape[NDIM],
210 constant
const int64_t strides[NDIM]) {
211 int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
214 for (
int d = NDIM - 2; d >= 0; --d) {
215 elem /= shape[d + 1];
216 loc += (elem % shape[d]) * strides[d];
225 constant
const int shape[NDIM],
226 constant
const int64_t strides[NDIM]) {
227 int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
228 for (
int d = NDIM - 3; d >= 0; --d) {
229 loc += (elem.z % shape[d]) * strides[d];
240 constant
const int* shape,
241 constant
const size_t* a_strides,
242 constant
const size_t* b_strides,
246 elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
248 elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
249 for (
int d = ndim - 3; d >= 0; --d) {
250 uint l = elem.z % shape[d];
251 loc.x += l * a_strides[d];
252 loc.y += l * b_strides[d];
260 constant
const int* shape,
261 constant
const size_t* a_strides,
262 constant
const size_t* b_strides,
263 constant
const size_t* c_strides,
267 elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
269 elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]),
271 elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])};
272 for (
int d = ndim - 3; d >= 0; --d) {
273 uint l = elem.z % shape[d];
274 loc.x += l * a_strides[d];
275 loc.y += l * b_strides[d];
276 loc.z += l * c_strides[d];
288 constant
const int shape[NDIM],
289 constant
const size_t a_strides[NDIM],
290 constant
const size_t b_strides[NDIM]) {
293 elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
295 elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
296 for (
int d = NDIM - 3; d >= 0; --d) {
297 uint l = elem.z % shape[d];
298 loc.x += l * a_strides[d];
299 loc.y += l * b_strides[d];
308 constant
const int shape[NDIM],
309 constant
const size_t a_strides[NDIM],
310 constant
const size_t b_strides[NDIM],
311 constant
const size_t c_strides[NDIM]) {
314 elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
316 elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
318 elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
319 for (
int d = NDIM - 3; d >= 0; --d) {
320 uint l = elem.z % shape[d];
321 loc.x += l * a_strides[d];
322 loc.y += l * b_strides[d];
323 loc.z += l * c_strides[d];
333template <
int dim,
typename offset_t =
size_t>
339 void next(
const constant
int* shape,
const constant
size_t* strides) {
341 offset += strides[dim - 1];
343 if (
index >= shape[dim - 1]) {
350 void next(
int n,
const constant
int* shape,
const constant
size_t* strides) {
352 offset += n * strides[dim - 1];
354 if (
index >= shape[dim - 1]) {
355 int extra =
index - shape[dim - 1];
360 next(extra, shape, strides);
366 location(offset_t,
const constant
int*,
const constant
size_t*,
int) {
371template <
typename offset_t>
375 void next(
const constant
int*,
const constant
size_t* strides) {
379 void next(
int n,
const constant
int*,
const constant
size_t* strides) {
384 location(offset_t,
const constant
int*,
const constant
size_t*,
int) {
389template <
typename offset_t>
391 void next(
const constant
int*,
const constant
size_t*) {}
392 void next(
int,
const constant
int*,
const constant
size_t*) {}
396 const constant
int* shape,
397 const constant
size_t* strides,
408template <
typename T,
typename U>
410 return (N + M - 1) / M;
415 float xp1 = 1.0f + x;
427 float xp1 = 1.0f +
static_cast<float>(x);
443 return as_type<uint64_t>(
448 return as_type<int64_t>(
static const constant U max
Definition utils.h:18
static const constant U finite_max
Definition utils.h:20
static const constant U min
Definition utils.h:19
static const constant U finite_min
Definition utils.h:21
float imag
Definition complex.h:22
float real
Definition complex.h:21
void next(int, const constant int *, const constant size_t *)
Definition utils.h:392
offset_t location(offset_t idx, const constant int *shape, const constant size_t *strides, int ndim)
Definition utils.h:394
void next(const constant int *, const constant size_t *)
Definition utils.h:391
offset_t location(offset_t, const constant int *, const constant size_t *, int)
Definition utils.h:384
void next(const constant int *, const constant size_t *strides)
Definition utils.h:375
void next(int n, const constant int *, const constant size_t *strides)
Definition utils.h:379
void next(const constant int *shape, const constant size_t *strides)
Definition utils.h:339
offset_t offset
Definition utils.h:336
int index
Definition utils.h:337
looped_elem_to_loc< dim - 1, offset_t > inner_looper
Definition utils.h:335
offset_t location(offset_t, const constant int *, const constant size_t *, int)
Definition utils.h:366
void next(int n, const constant int *shape, const constant size_t *strides)
Definition utils.h:350