18 static const constant U
max = metal::numeric_limits<U>::max();
19 static const constant U
min = metal::numeric_limits<U>::min();
20 static const constant U
finite_max = metal::numeric_limits<U>::max();
21 static const constant U
finite_min = metal::numeric_limits<U>::min();
24#define instantiate_default_limit(type) \
26 struct Limits<type> { \
27 static constexpr constant type max = metal::numeric_limits<type>::max(); \
28 static constexpr constant type min = metal::numeric_limits<type>::min(); \
29 static constexpr constant type finite_max = \
30 metal::numeric_limits<type>::max(); \
31 static constexpr constant type finite_min = \
32 metal::numeric_limits<type>::min(); \
44#define instantiate_float_limit(type) \
46 struct Limits<type> { \
47 static constexpr constant type max = \
48 metal::numeric_limits<type>::infinity(); \
49 static constexpr constant type min = \
50 -metal::numeric_limits<type>::infinity(); \
51 static constexpr constant type finite_max = \
52 metal::numeric_limits<type>::max(); \
53 static constexpr constant type finite_min = \
54 -metal::numeric_limits<type>::max(); \
63 static constexpr constant
bool max =
true;
64 static constexpr constant
bool min =
false;
71#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
76template <
typename str
ide_t>
79 device
const int* shape,
80 device
const stride_t* strides,
83 for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
84 loc += (elem % shape[i]) * strides[i];
90template <
typename str
ide_t>
93 constant
const int* shape,
94 constant
const stride_t* strides,
97 for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
98 loc += (elem % shape[i]) * strides[i];
105template <
typename str
ide_t>
108 constant
const int* shape,
109 constant
const stride_t* strides,
111 stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
112 for (
int d = ndim - 3; d >= 0; --d) {
113 loc += (elem.z % shape[d]) * strides[d];
122template <
typename str
ide_t>
123METAL_FUNC stride_t
elem_to_loc_1(uint elem, constant
const stride_t& stride) {
124 return elem * stride;
127template <
typename str
ide_t>
130 return elem.x * strides[1] + elem.y * strides[0];
133template <
typename str
ide_t>
136 return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
142 device
const int* shape,
143 device
const size_t* strides) {
144 size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
147 for (
int d = NDIM - 2; d >= 0; --d) {
148 elem /= shape[d + 1];
149 loc += (elem % shape[d]) * strides[d];
158 constant
const int shape[NDIM],
159 constant
const size_t strides[NDIM]) {
160 size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
161 for (
int d = NDIM - 3; d >= 0; --d) {
162 loc += (elem.z % shape[d]) * strides[d];
171 constant
const int shape[NDIM],
172 constant
const int64_t strides[NDIM]) {
173 int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
176 for (
int d = NDIM - 2; d >= 0; --d) {
177 elem /= shape[d + 1];
178 loc += (elem % shape[d]) * strides[d];
187 constant
const int shape[NDIM],
188 constant
const int64_t strides[NDIM]) {
189 int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
190 for (
int d = NDIM - 3; d >= 0; --d) {
191 loc += (elem.z % shape[d]) * strides[d];
202 constant
const int* shape,
203 constant
const size_t* a_strides,
204 constant
const size_t* b_strides,
208 elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
210 elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
211 for (
int d = ndim - 3; d >= 0; --d) {
212 uint l = elem.z % shape[d];
213 loc.x += l * a_strides[d];
214 loc.y += l * b_strides[d];
222 constant
const int* shape,
223 constant
const size_t* a_strides,
224 constant
const size_t* b_strides,
225 constant
const size_t* c_strides,
229 elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
231 elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]),
233 elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])};
234 for (
int d = ndim - 3; d >= 0; --d) {
235 uint l = elem.z % shape[d];
236 loc.x += l * a_strides[d];
237 loc.y += l * b_strides[d];
238 loc.z += l * c_strides[d];
250 constant
const int shape[NDIM],
251 constant
const size_t a_strides[NDIM],
252 constant
const size_t b_strides[NDIM]) {
255 elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
257 elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
258 for (
int d = NDIM - 3; d >= 0; --d) {
259 uint l = elem.z % shape[d];
260 loc.x += l * a_strides[d];
261 loc.y += l * b_strides[d];
270 constant
const int shape[NDIM],
271 constant
const size_t a_strides[NDIM],
272 constant
const size_t b_strides[NDIM],
273 constant
const size_t c_strides[NDIM]) {
276 elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
278 elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
280 elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
281 for (
int d = NDIM - 3; d >= 0; --d) {
282 uint l = elem.z % shape[d];
283 loc.x += l * a_strides[d];
284 loc.y += l * b_strides[d];
285 loc.z += l * c_strides[d];
297 return (N + M - 1) / M;
302 float xp1 = 1.0f + x;
314 float xp1 = 1.0f +
static_cast<float>(x);
330 return as_type<uint64_t>(
335 return as_type<int64_t>(
static const constant U max
Definition utils.h:18
static const constant U finite_max
Definition utils.h:20
static const constant U min
Definition utils.h:19
static const constant U finite_min
Definition utils.h:21