// Generic limit constants for a type U, all stored in the `constant` address
// space. max/min are metal::numeric_limits<U>::max()/min(); finite_max and
// finite_min are initialised from the very same calls, so here they equal
// max/min.
// NOTE(review): presumably members of the primary Limits<U> template -- the
// enclosing struct header is not visible in this listing; confirm.
24 static const constant U
max = metal::numeric_limits<U>::max();
25 static const constant U
min = metal::numeric_limits<U>::min();
26 static const constant U
finite_max = metal::numeric_limits<U>::max();
27 static const constant U
finite_min = metal::numeric_limits<U>::min();
// Emits a Limits<type> specialisation whose four constants are taken directly
// from metal::numeric_limits<type>: max and finite_max are both max(), min and
// finite_min are both min(). No infinities are involved.
30#define instantiate_default_limit(type) \
32 struct Limits<type> { \
33 static constexpr constant type max = metal::numeric_limits<type>::max(); \
34 static constexpr constant type min = metal::numeric_limits<type>::min(); \
35 static constexpr constant type finite_max = \
36 metal::numeric_limits<type>::max(); \
37 static constexpr constant type finite_min = \
38 metal::numeric_limits<type>::min(); \
// Emits a Limits<type> specialisation in which max/min are +infinity and
// -infinity, while finite_max/finite_min are numeric_limits<type>::max() and
// its negation.
// NOTE(review): finite_min uses -max() rather than min(); presumably because
// numeric_limits<T>::min() for floating-point types is the smallest positive
// normal value, not the most negative one -- confirm against the Metal spec.
50#define instantiate_float_limit(type) \
52 struct Limits<type> { \
53 static constexpr constant type max = \
54 metal::numeric_limits<type>::infinity(); \
55 static constexpr constant type min = \
56 -metal::numeric_limits<type>::infinity(); \
57 static constexpr constant type finite_max = \
58 metal::numeric_limits<type>::max(); \
59 static constexpr constant type finite_min = \
60 -metal::numeric_limits<type>::max(); \
// Boolean limit constants: max is true and min is false.
// NOTE(review): presumably the body of a Limits<bool> specialisation -- the
// enclosing struct header is not visible in this listing; confirm.
69 static constexpr constant
bool max =
true;
70 static constexpr constant
bool min =
false;
76 metal::numeric_limits<float>::infinity(),
77 metal::numeric_limits<float>::infinity());
79 -metal::numeric_limits<float>::infinity(),
80 -metal::numeric_limits<float>::infinity());
// Expands to the clang loop pragma `clang loop unroll(full)`. Place it
// immediately before a loop to request that the compiler fully unroll it.
87#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
92template <
typename Str
ideT,
typename IdxT = Str
ideT>
95 constant
const int* shape,
96 constant
const StrideT* strides,
99 for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
100 loc += (elem % shape[i]) * IdxT(strides[i]);
106template <
typename Str
ideT,
typename IdxT = Str
ideT>
109 constant
const int* shape,
110 constant
const StrideT* strides,
113 for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
114 loc += (elem % shape[i]) * IdxT(strides[i]);
121template <
typename Str
ideT,
typename IdxT = Str
ideT>
124 constant
const int* shape,
125 constant
const StrideT* strides,
128 elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);
129 for (
int d = ndim - 3; d >= 0; --d) {
130 loc += (elem.z % shape[d]) * IdxT(strides[d]);
139template <
typename Str
ideT,
typename IdxT = Str
ideT>
141 return elem * IdxT(stride);
// Converts a 2-component element index into a single location of type IdxT.
// elem.x is multiplied by strides[1] and elem.y by strides[0]; each stride is
// converted to IdxT before the multiplication, and the two products are summed.
// IdxT defaults to the stride type.
144template <
typename Str
ideT,
typename IdxT = Str
ideT>
145METAL_FUNC IdxT
elem_to_loc_2(uint2 elem, constant
const StrideT strides[2]) {
146 return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]);
// Converts a 3-component element index into a single location of type IdxT.
// elem.x is multiplied by strides[2], elem.y by strides[1] and elem.z by
// strides[0]; each stride is converted to IdxT before the multiplication, and
// the three products are summed. IdxT defaults to the stride type.
149template <
typename Str
ideT,
typename IdxT = Str
ideT>
150METAL_FUNC IdxT
elem_to_loc_3(uint3 elem, constant
const StrideT strides[3]) {
151 return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) +
152 elem.z * IdxT(strides[0]);
158template <
typename Str
ideT,
typename IdxT = Str
ideT>
161 constant
const int* shape,
162 constant
const StrideT* a_strides,
163 constant
const StrideT* b_strides,
167 elem.x * IdxT(a_strides[ndim - 1]) +
168 IdxT(elem.y) * IdxT(a_strides[ndim - 2])),
170 elem.x * IdxT(b_strides[ndim - 1]) +
171 elem.y * IdxT(b_strides[ndim - 2]))};
172 for (
int d = ndim - 3; d >= 0; --d) {
173 uint l = elem.z % shape[d];
174 loc.x += l * IdxT(a_strides[d]);
175 loc.y += l * IdxT(b_strides[d]);
181template <
typename IdxT =
size_t>
184 constant
const int* shape,
185 constant
const size_t* a_strides,
186 constant
const size_t* b_strides,
187 constant
const size_t* c_strides,
190 elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]),
191 elem.x * IdxT(b_strides[ndim - 1]) + elem.y * IdxT(b_strides[ndim - 2]),
192 elem.x * IdxT(c_strides[ndim - 1]) + elem.y * IdxT(c_strides[ndim - 2])};
193 for (
int d = ndim - 3; d >= 0; --d) {
194 uint l = elem.z % shape[d];
195 loc.x += l * IdxT(a_strides[d]);
196 loc.y += l * IdxT(b_strides[d]);
197 loc.z += l * IdxT(c_strides[d]);
207template <
int DIM,
typename OffsetT =
size_t,
bool General = true>
216 void next(
const constant
int* shape,
const constant
size_t* strides) {
229 void next(
int n,
const constant
int* shape,
const constant
size_t* strides) {
238 if (extra >= shape[
dim - 1]) {
240 extra = extra % shape[
dim - 1];
247 next(extra, shape, strides);
257template <
typename OffsetT>
265 void next(
const constant
int* shape,
const constant
size_t* strides) {
270 offset += OffsetT(strides[0]);
274 void next(
int n,
const constant
int* shape,
const constant
size_t* strides) {
288template <
typename OffsetT>
294 void next(
const constant
int*,
const constant
size_t* strides) {
295 offset += OffsetT(strides[0]);
298 void next(
int n,
const constant
int*,
const constant
size_t* strides) {
299 offset += n * OffsetT(strides[0]);
312template <
typename T,
typename U>
314 return (N + M - 1) / M;
319 float xp1 = 1.0f + x;
331 float xp1 = 1.0f +
static_cast<float>(x);
347 return as_type<uint64_t>(
352 return as_type<int64_t>(
385 as_type<uint2>(data), as_type<uint2>(filling), delta));
391 as_type<uint2>(data), as_type<uint2>(filling), delta));
396 static_cast<uint32_t
>(data),
static_cast<uint32_t
>(filling), delta);
426template <
bool condition,
typename T,
typename U>
431template <
typename T,
typename U>
T type
Definition utils.h:433
U type
Definition utils.h:428
static const constant U max
Definition utils.h:24
static const constant U finite_max
Definition utils.h:26
static const constant U min
Definition utils.h:25
static const constant U finite_min
Definition utils.h:27
void next(const constant int *, const constant size_t *strides)
Definition utils.h:294
LoopedElemToLoc(int)
Definition utils.h:292
OffsetT location()
Definition utils.h:302
void next(int n, const constant int *, const constant size_t *strides)
Definition utils.h:298
OffsetT location()
Definition utils.h:283
int dim
Definition utils.h:259
void next(int n, const constant int *shape, const constant size_t *strides)
Definition utils.h:274
LoopedElemToLoc(int dim)
Definition utils.h:263
void next(const constant int *shape, const constant size_t *strides)
Definition utils.h:265
void next(const constant int *shape, const constant size_t *strides)
Definition utils.h:216
LoopedElemToLoc(int dim)
Definition utils.h:214
void next(int n, const constant int *shape, const constant size_t *strides)
Definition utils.h:229
LoopedElemToLoc< DIM - 1, OffsetT, General > inner_looper
Definition utils.h:210
OffsetT location()
Definition utils.h:252
int index
Definition utils.h:212
OffsetT offset
Definition utils.h:211
int dim
Definition utils.h:209
float imag
Definition complex.h:22
float real
Definition complex.h:21