18 static const constant U
max = metal::numeric_limits<U>::max();
19 static const constant U
min = metal::numeric_limits<U>::min();
20 static const constant U
finite_max = metal::numeric_limits<U>::max();
21 static const constant U
finite_min = metal::numeric_limits<U>::min();
24#define instantiate_default_limit(type) \
26 struct Limits<type> { \
27 static constexpr constant type max = metal::numeric_limits<type>::max(); \
28 static constexpr constant type min = metal::numeric_limits<type>::min(); \
29 static constexpr constant type finite_max = \
30 metal::numeric_limits<type>::max(); \
31 static constexpr constant type finite_min = \
32 metal::numeric_limits<type>::min(); \
44#define instantiate_float_limit(type) \
46 struct Limits<type> { \
47 static constexpr constant type max = \
48 metal::numeric_limits<type>::infinity(); \
49 static constexpr constant type min = \
50 -metal::numeric_limits<type>::infinity(); \
51 static constexpr constant type finite_max = \
52 metal::numeric_limits<type>::max(); \
53 static constexpr constant type finite_min = \
54 -metal::numeric_limits<type>::max(); \
63 static constexpr constant
bool max =
true;
64 static constexpr constant
bool min =
false;
70 metal::numeric_limits<float>::infinity(),
71 metal::numeric_limits<float>::infinity());
73 -metal::numeric_limits<float>::infinity(),
74 -metal::numeric_limits<float>::infinity());
81#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
86template <
typename str
ide_t>
89 constant
const int* shape,
90 constant
const stride_t* strides,
93 for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
94 loc += (elem % shape[i]) * strides[i];
100template <
typename str
ide_t>
103 constant
const int* shape,
104 constant
const stride_t* strides,
107 for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
108 loc += (elem % shape[i]) * strides[i];
115template <
typename str
ide_t>
118 constant
const int* shape,
119 constant
const stride_t* strides,
121 stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
122 for (
int d = ndim - 3; d >= 0; --d) {
123 loc += (elem.z % shape[d]) * strides[d];
132template <
typename str
ide_t>
133METAL_FUNC stride_t
elem_to_loc_1(uint elem, constant
const stride_t& stride) {
134 return elem * stride;
137template <
typename str
ide_t>
140 return elem.x * strides[1] + elem.y * strides[0];
143template <
typename str
ide_t>
146 return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
152template <
typename str
ide_t>
155 constant
const int* shape,
156 constant
const stride_t* a_strides,
157 constant
const stride_t* b_strides,
160 ulong(elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
161 ulong(elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
162 for (
int d = ndim - 3; d >= 0; --d) {
163 uint l = elem.z % shape[d];
164 loc.x += l * a_strides[d];
165 loc.y += l * b_strides[d];
173 constant
const int* shape,
174 constant
const size_t* a_strides,
175 constant
const size_t* b_strides,
176 constant
const size_t* c_strides,
179 elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2],
180 elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2],
181 elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2]};
182 for (
int d = ndim - 3; d >= 0; --d) {
183 uint l = elem.z % shape[d];
184 loc.x += l * a_strides[d];
185 loc.y += l * b_strides[d];
186 loc.z += l * c_strides[d];
196template <
int dim,
typename offset_t =
size_t>
202 void next(
const constant
int* shape,
const constant
size_t* strides) {
204 offset += strides[dim - 1];
206 if (
index >= shape[dim - 1]) {
213 void next(
int n,
const constant
int* shape,
const constant
size_t* strides) {
215 offset += n * strides[dim - 1];
217 if (
index >= shape[dim - 1]) {
218 int extra =
index - shape[dim - 1];
223 next(extra, shape, strides);
229 location(offset_t,
const constant
int*,
const constant
size_t*,
int) {
234template <
typename offset_t>
238 void next(
const constant
int*,
const constant
size_t* strides) {
242 void next(
int n,
const constant
int*,
const constant
size_t* strides) {
247 location(offset_t,
const constant
int*,
const constant
size_t*,
int) {
252template <
typename offset_t>
254 void next(
const constant
int*,
const constant
size_t*) {}
255 void next(
int,
const constant
int*,
const constant
size_t*) {}
259 const constant
int* shape,
260 const constant
size_t* strides,
271template <
typename T,
typename U>
273 return (N + M - 1) / M;
278 float xp1 = 1.0f + x;
290 float xp1 = 1.0f +
static_cast<float>(x);
306 return as_type<uint64_t>(
311 return as_type<int64_t>(
344 as_type<uint2>(data), as_type<uint2>(filling), delta));
350 as_type<uint2>(data), as_type<uint2>(filling), delta));
355 static_cast<uint32_t
>(data),
static_cast<uint32_t
>(filling), delta);
std::vector< ptrdiff_t > stride_t
Definition pocketfft.h:103
static const constant U max
Definition utils.h:18
static const constant U finite_max
Definition utils.h:20
static const constant U min
Definition utils.h:19
static const constant U finite_min
Definition utils.h:21
float imag
Definition complex.h:22
float real
Definition complex.h:21
void next(int, const constant int *, const constant size_t *)
Definition utils.h:255
offset_t location(offset_t idx, const constant int *shape, const constant size_t *strides, int ndim)
Definition utils.h:257
void next(const constant int *, const constant size_t *)
Definition utils.h:254
offset_t location(offset_t, const constant int *, const constant size_t *, int)
Definition utils.h:247
void next(const constant int *, const constant size_t *strides)
Definition utils.h:238
void next(int n, const constant int *, const constant size_t *strides)
Definition utils.h:242
void next(const constant int *shape, const constant size_t *strides)
Definition utils.h:202
offset_t offset
Definition utils.h:199
int index
Definition utils.h:200
looped_elem_to_loc< dim - 1, offset_t > inner_looper
Definition utils.h:198
offset_t location(offset_t, const constant int *, const constant size_t *, int)
Definition utils.h:229
void next(int n, const constant int *shape, const constant size_t *strides)
Definition utils.h:213