15 static const constant U
max = metal::numeric_limits<U>::max();
16 static const constant U
min = metal::numeric_limits<U>::min();
17 static const constant U
finite_max = metal::numeric_limits<U>::max();
18 static const constant U
finite_min = metal::numeric_limits<U>::min();
21#define instantiate_default_limit(type) \
23 struct Limits<type> { \
24 static constexpr constant type max = metal::numeric_limits<type>::max(); \
25 static constexpr constant type min = metal::numeric_limits<type>::min(); \
26 static constexpr constant type finite_max = \
27 metal::numeric_limits<type>::max(); \
28 static constexpr constant type finite_min = \
29 metal::numeric_limits<type>::min(); \
41#define instantiate_float_limit(type) \
43 struct Limits<type> { \
44 static constexpr constant type max = \
45 metal::numeric_limits<type>::infinity(); \
46 static constexpr constant type min = \
47 -metal::numeric_limits<type>::infinity(); \
48 static constexpr constant type finite_max = \
49 metal::numeric_limits<type>::max(); \
50 static constexpr constant type finite_min = \
51 -metal::numeric_limits<type>::max(); \
60 static constexpr constant
bool max =
true;
61 static constexpr constant
bool min =
false;
// Shorthand for clang's loop hint `_Pragma("clang loop unroll(full)")`, which
// asks the compiler to fully unroll the loop that immediately follows the macro.
68#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
73template <
typename str
ide_t>
76 device
const int* shape,
77 device
const stride_t* strides,
80 for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
81 loc += (elem % shape[i]) * strides[i];
87template <
typename str
ide_t>
90 constant
const int* shape,
91 constant
const stride_t* strides,
94 for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
95 loc += (elem % shape[i]) * strides[i];
102template <
typename str
ide_t>
105 constant
const int* shape,
106 constant
const stride_t* strides,
108 stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
109 for (
int d = ndim - 3; d >= 0; --d) {
110 loc += (elem.z % shape[d]) * strides[d];
119template <
typename str
ide_t>
120METAL_FUNC stride_t
elem_to_loc_1(uint elem, constant
const stride_t& stride) {
121 return elem * stride;
124template <
typename str
ide_t>
127 return elem.x * strides[1] + elem.y * strides[0];
130template <
typename str
ide_t>
133 return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
139 device
const int* shape,
140 device
const size_t* strides) {
141 size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
144 for (
int d = NDIM - 2; d >= 0; --d) {
145 elem /= shape[d + 1];
146 loc += (elem % shape[d]) * strides[d];
155 constant
const int shape[NDIM],
156 constant
const size_t strides[NDIM]) {
157 size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
158 for (
int d = NDIM - 3; d >= 0; --d) {
159 loc += (elem.z % shape[d]) * strides[d];
168 constant
const int shape[NDIM],
169 constant
const int64_t strides[NDIM]) {
170 int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
173 for (
int d = NDIM - 2; d >= 0; --d) {
174 elem /= shape[d + 1];
175 loc += (elem % shape[d]) * strides[d];
184 constant
const int shape[NDIM],
185 constant
const int64_t strides[NDIM]) {
186 int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
187 for (
int d = NDIM - 3; d >= 0; --d) {
188 loc += (elem.z % shape[d]) * strides[d];
199 constant
const int* shape,
200 constant
const size_t* a_strides,
201 constant
const size_t* b_strides,
205 elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
207 elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
208 for (
int d = ndim - 3; d >= 0; --d) {
209 uint l = elem.z % shape[d];
210 loc.x += l * a_strides[d];
211 loc.y += l * b_strides[d];
219 constant
const int* shape,
220 constant
const size_t* a_strides,
221 constant
const size_t* b_strides,
222 constant
const size_t* c_strides,
226 elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
228 elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]),
230 elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])};
231 for (
int d = ndim - 3; d >= 0; --d) {
232 uint l = elem.z % shape[d];
233 loc.x += l * a_strides[d];
234 loc.y += l * b_strides[d];
235 loc.z += l * c_strides[d];
247 constant
const int shape[NDIM],
248 constant
const size_t a_strides[NDIM],
249 constant
const size_t b_strides[NDIM]) {
252 elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
254 elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
255 for (
int d = NDIM - 3; d >= 0; --d) {
256 uint l = elem.z % shape[d];
257 loc.x += l * a_strides[d];
258 loc.y += l * b_strides[d];
267 constant
const int shape[NDIM],
268 constant
const size_t a_strides[NDIM],
269 constant
const size_t b_strides[NDIM],
270 constant
const size_t c_strides[NDIM]) {
273 elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
275 elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
277 elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
278 for (
int d = NDIM - 3; d >= 0; --d) {
279 uint l = elem.z % shape[d];
280 loc.x += l * a_strides[d];
281 loc.y += l * b_strides[d];
282 loc.z += l * c_strides[d];
294 return (N + M - 1) / M;
299 float xp1 = 1.0f + x;
311 float xp1 = 1.0f +
static_cast<float>(x);
327 return as_type<uint64_t>(
332 return as_type<int64_t>(
static const constant U max
Definition utils.h:15
static const constant U finite_max
Definition utils.h:17
static const constant U min
Definition utils.h:16
static const constant U finite_min
Definition utils.h:18