// Generic limit constants for a type U.
// NOTE(review): presumably members of the primary Limits template; the
// enclosing declaration is not visible in this listing -- confirm.
// max and finite_max both take metal::numeric_limits<U>::max(); min and
// finite_min both take metal::numeric_limits<U>::min(), so in this generic
// case the "finite" variants equal the plain ones.
24 static const constant U
max = metal::numeric_limits<U>::max();
25 static const constant U
min = metal::numeric_limits<U>::min();
26 static const constant U
finite_max = metal::numeric_limits<U>::max();
27 static const constant U
finite_min = metal::numeric_limits<U>::min();
// Emits a `struct Limits<type>` whose four constexpr constants come directly
// from metal::numeric_limits<type>: max and finite_max are ::max(), min and
// finite_min are ::min().
// NOTE(review): this listing omits some lines of the macro (including its
// closing lines); check the full header before editing the body.
30#define instantiate_default_limit(type) \
32 struct Limits<type> { \
33 static constexpr constant type max = metal::numeric_limits<type>::max(); \
34 static constexpr constant type min = metal::numeric_limits<type>::min(); \
35 static constexpr constant type finite_max = \
36 metal::numeric_limits<type>::max(); \
37 static constexpr constant type finite_min = \
38 metal::numeric_limits<type>::min(); \
// Emits a `struct Limits<type>` for a floating-point type. Unlike the default
// variant, max/min are +infinity/-infinity, while finite_max/finite_min are
// +numeric_limits<type>::max() / -numeric_limits<type>::max().
// finite_min negates max() instead of using ::min() -- presumably because
// ::min() for floating types is the smallest positive value, as in standard
// C++; confirm against the Metal specification.
// NOTE(review): this listing omits some lines of the macro (including its
// closing lines); check the full header before editing the body.
50#define instantiate_float_limit(type) \
52 struct Limits<type> { \
53 static constexpr constant type max = \
54 metal::numeric_limits<type>::infinity(); \
55 static constexpr constant type min = \
56 -metal::numeric_limits<type>::infinity(); \
57 static constexpr constant type finite_max = \
58 metal::numeric_limits<type>::max(); \
59 static constexpr constant type finite_min = \
60 -metal::numeric_limits<type>::max(); \
// Limit constants for bool: max is true and min is false.
// NOTE(review): presumably members of a Limits<bool> specialization; the
// enclosing struct header is not visible in this listing -- confirm.
69 static constexpr constant
bool max =
true;
70 static constexpr constant
bool min =
false;
76 metal::numeric_limits<float>::infinity(),
77 metal::numeric_limits<float>::infinity());
79 -metal::numeric_limits<float>::infinity(),
80 -metal::numeric_limits<float>::infinity());
// Expands to the clang loop pragma `unroll(full)`. Written with _Pragma so it
// can be produced by a macro; place it immediately before the loop that the
// unroll request should apply to.
87#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
92template <
typename IdxT =
int64_t>
95 constant
const int* shape,
96 constant
const int64_t* strides,
99 for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
100 loc += (elem % shape[i]) * IdxT(strides[i]);
107template <
typename IdxT =
int64_t>
110 constant
const int* shape,
111 constant
const int64_t* strides,
114 elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);
115 for (
int d = ndim - 3; d >= 0; --d) {
116 loc += (elem.z % shape[d]) * IdxT(strides[d]);
125template <
typename IdxT =
int64_t>
127 return elem * IdxT(stride);
// Maps a 2-component element index to a single offset using two strides.
// elem.x is paired with strides[1] and elem.y with strides[0]; each stride is
// converted to IdxT (int64_t unless overridden) before it is multiplied.
// Returns the sum of the two products as IdxT.
130template <
typename IdxT =
int64_t>
131METAL_FUNC IdxT
elem_to_loc_2(uint2 elem, constant
const int64_t strides[2]) {
132 return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]);
// Maps a 3-component element index to a single offset using three strides.
// elem.x is paired with strides[2], elem.y with strides[1] and elem.z with
// strides[0]; each stride is converted to IdxT (int64_t unless overridden)
// before it is multiplied. Returns the sum of the three products as IdxT.
135template <
typename IdxT =
int64_t>
136METAL_FUNC IdxT
elem_to_loc_3(uint3 elem, constant
const int64_t strides[3]) {
137 return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) +
138 elem.z * IdxT(strides[0]);
144template <
typename IdxT =
int64_t>
147 constant
const int* shape,
148 constant
const int64_t* a_strides,
149 constant
const int64_t* b_strides,
153 elem.x * IdxT(a_strides[ndim - 1]) +
154 IdxT(elem.y) * IdxT(a_strides[ndim - 2])),
156 elem.x * IdxT(b_strides[ndim - 1]) +
157 elem.y * IdxT(b_strides[ndim - 2]))};
158 for (
int d = ndim - 3; d >= 0; --d) {
159 uint l = elem.z % shape[d];
160 loc.x += l * IdxT(a_strides[d]);
161 loc.y += l * IdxT(b_strides[d]);
167template <
typename IdxT =
int64_t>
170 constant
const int* shape,
171 constant
const int64_t* a_strides,
172 constant
const int64_t* b_strides,
173 constant
const int64_t* c_strides,
176 IdxT(elem.x * IdxT(a_strides[ndim - 1])) +
177 IdxT(elem.y * IdxT(a_strides[ndim - 2])),
178 IdxT(elem.x * IdxT(b_strides[ndim - 1])) +
179 IdxT(elem.y * IdxT(b_strides[ndim - 2])),
180 IdxT(elem.x * IdxT(c_strides[ndim - 1])) +
181 IdxT(elem.y * IdxT(c_strides[ndim - 2]))};
182 for (
int d = ndim - 3; d >= 0; --d) {
183 uint l = elem.z % shape[d];
184 loc.x += l * IdxT(a_strides[d]);
185 loc.y += l * IdxT(b_strides[d]);
186 loc.z += l * IdxT(c_strides[d]);
196template <
int DIM,
typename OffsetT =
size_t,
bool General = true>
205 void next(
const constant
int* shape,
const constant int64_t* strides) {
218 void next(
int n,
const constant
int* shape,
const constant int64_t* strides) {
227 if (extra >= shape[
dim - 1]) {
229 extra = extra % shape[
dim - 1];
236 next(extra, shape, strides);
246template <
typename OffsetT>
254 void next(
const constant
int* shape,
const constant int64_t* strides) {
259 offset += OffsetT(strides[0]);
263 void next(
int n,
const constant
int* shape,
const constant int64_t* strides) {
277template <
typename OffsetT>
283 void next(
const constant
int*,
const constant int64_t* strides) {
284 offset += OffsetT(strides[0]);
287 void next(
int n,
const constant
int*,
const constant int64_t* strides) {
288 offset += n * OffsetT(strides[0]);
301template <
typename T,
typename U>
303 return (N + M - 1) / M;
308 float xp1 = 1.0f + x;
320 float xp1 = 1.0f +
static_cast<float>(x);
336 return as_type<uint64_t>(
341 return as_type<int64_t>(
374 as_type<uint2>(data), as_type<uint2>(filling), delta));
380 as_type<uint2>(data), as_type<uint2>(filling), delta));
385 static_cast<uint32_t
>(data),
static_cast<uint32_t
>(filling), delta);
415template <
bool condition,
typename T,
typename U>
420template <
typename T,
typename U>
U type
Definition utils.h:417
T type
Definition utils.h:422
static constexpr constant bool min
Definition utils.h:70
static const constant U max
Definition utils.h:24
static const constant U finite_max
Definition utils.h:26
static const constant U min
Definition utils.h:25
static constexpr constant complex64_t min
Definition utils.h:78
static constexpr constant complex64_t max
Definition utils.h:75
static constexpr constant bool max
Definition utils.h:69
static const constant U finite_min
Definition utils.h:27
LoopedElemToLoc(int)
Definition utils.h:281
uint index
Definition utils.h:250
OffsetT offset
Definition utils.h:249
LoopedElemToLoc(int dim)
Definition utils.h:203
void next(int n, const constant int *, const constant int64_t *strides)
Definition utils.h:287
OffsetT location()
Definition utils.h:272
void next(int n, const constant int *shape, const constant int64_t *strides)
Definition utils.h:263
int dim
Definition utils.h:248
OffsetT location()
Definition utils.h:291
LoopedElemToLoc< DIM - 1, OffsetT, General > inner_looper
Definition utils.h:199
void next(const constant int *shape, const constant int64_t *strides)
Definition utils.h:205
void next(const constant int *, const constant int64_t *strides)
Definition utils.h:283
OffsetT location()
Definition utils.h:241
LoopedElemToLoc(int dim)
Definition utils.h:252
int index
Definition utils.h:201
OffsetT offset
Definition utils.h:200
void next(const constant int *shape, const constant int64_t *strides)
Definition utils.h:254
void next(int n, const constant int *shape, const constant int64_t *strides)
Definition utils.h:218
OffsetT offset
Definition utils.h:279
int dim
Definition utils.h:198
float imag
Definition complex.h:22
float real
Definition complex.h:21