24  static const constant U 
max = metal::numeric_limits<U>::max();
 
   25  static const constant U 
min = metal::numeric_limits<U>::min();
 
   26  static const constant U 
finite_max = metal::numeric_limits<U>::max();
 
   27  static const constant U 
finite_min = metal::numeric_limits<U>::min();
 
 
   30#define instantiate_default_limit(type)                                      \ 
   32  struct Limits<type> {                                                      \ 
   33    static constexpr constant type max = metal::numeric_limits<type>::max(); \ 
   34    static constexpr constant type min = metal::numeric_limits<type>::min(); \ 
   35    static constexpr constant type finite_max =                              \ 
   36        metal::numeric_limits<type>::max();                                  \ 
   37    static constexpr constant type finite_min =                              \ 
   38        metal::numeric_limits<type>::min();                                  \ 
 
   50#define instantiate_float_limit(type)             \ 
   52  struct Limits<type> {                           \ 
   53    static constexpr constant type max =          \ 
   54        metal::numeric_limits<type>::infinity();  \ 
   55    static constexpr constant type min =          \ 
   56        -metal::numeric_limits<type>::infinity(); \ 
   57    static constexpr constant type finite_max =   \ 
   58        metal::numeric_limits<type>::max();       \ 
   59    static constexpr constant type finite_min =   \ 
   60        -metal::numeric_limits<type>::max();      \ 
 
   69  static constexpr constant 
bool max = 
true;
 
   70  static constexpr constant 
bool min = 
false;
 
 
   76      metal::numeric_limits<float>::infinity(),
 
   77      metal::numeric_limits<float>::infinity());
 
   79      -metal::numeric_limits<float>::infinity(),
 
   80      -metal::numeric_limits<float>::infinity());
 
 
   87#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") 
   92template <
typename Str
ideT, 
typename IdxT = Str
ideT>
 
   95    constant 
const int* shape,
 
   96    constant 
const StrideT* strides,
 
   99  for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
 
  100    loc += (elem % shape[i]) * IdxT(strides[i]);
 
 
  106template <
typename Str
ideT, 
typename IdxT = Str
ideT>
 
  109    constant 
const int* shape,
 
  110    constant 
const StrideT* strides,
 
  113  for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
 
  114    loc += (elem % shape[i]) * IdxT(strides[i]);
 
 
  121template <
typename Str
ideT, 
typename IdxT = Str
ideT>
 
  124    constant 
const int* shape,
 
  125    constant 
const StrideT* strides,
 
  128      elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]);
 
  129  for (
int d = ndim - 3; d >= 0; --d) {
 
  130    loc += (elem.z % shape[d]) * IdxT(strides[d]);
 
 
  139template <
typename Str
ideT, 
typename IdxT = Str
ideT>
 
  141  return elem * IdxT(stride);
 
 
  144template <
typename Str
ideT, 
typename IdxT = Str
ideT>
 
  145METAL_FUNC IdxT 
elem_to_loc_2(uint2 elem, constant 
const StrideT strides[2]) {
 
  146  return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]);
 
 
  149template <
typename Str
ideT, 
typename IdxT = Str
ideT>
 
  150METAL_FUNC IdxT 
elem_to_loc_3(uint3 elem, constant 
const StrideT strides[3]) {
 
  151  return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) +
 
  152      elem.z * IdxT(strides[0]);
 
 
  158template <
typename Str
ideT, 
typename IdxT = Str
ideT>
 
  161    constant 
const int* shape,
 
  162    constant 
const StrideT* a_strides,
 
  163    constant 
const StrideT* b_strides,
 
  167          elem.x * IdxT(a_strides[ndim - 1]) +
 
  168          IdxT(elem.y) * IdxT(a_strides[ndim - 2])),
 
  170          elem.x * IdxT(b_strides[ndim - 1]) +
 
  171          elem.y * IdxT(b_strides[ndim - 2]))};
 
  172  for (
int d = ndim - 3; d >= 0; --d) {
 
  173    uint l = elem.z % shape[d];
 
  174    loc.x += l * IdxT(a_strides[d]);
 
  175    loc.y += l * IdxT(b_strides[d]);
 
 
  181template <
typename IdxT = 
size_t>
 
  184    constant 
const int* shape,
 
  185    constant 
const size_t* a_strides,
 
  186    constant 
const size_t* b_strides,
 
  187    constant 
const size_t* c_strides,
 
  190      elem.x * IdxT(a_strides[ndim - 1]) + elem.y * IdxT(a_strides[ndim - 2]),
 
  191      elem.x * IdxT(b_strides[ndim - 1]) + elem.y * IdxT(b_strides[ndim - 2]),
 
  192      elem.x * IdxT(c_strides[ndim - 1]) + elem.y * IdxT(c_strides[ndim - 2])};
 
  193  for (
int d = ndim - 3; d >= 0; --d) {
 
  194    uint l = elem.z % shape[d];
 
  195    loc.x += l * IdxT(a_strides[d]);
 
  196    loc.y += l * IdxT(b_strides[d]);
 
  197    loc.z += l * IdxT(c_strides[d]);
 
 
  207template <
int DIM, 
typename OffsetT = 
size_t, 
bool General = true>
 
  216  void next(
const constant 
int* shape, 
const constant 
size_t* strides) {
 
 
  229  void next(
int n, 
const constant 
int* shape, 
const constant 
size_t* strides) {
 
  238      if (extra >= shape[
dim - 1]) {
 
  240        extra = extra % shape[
dim - 1];
 
  247        next(extra, shape, strides);
 
 
 
  257template <
typename OffsetT>
 
  265  void next(
const constant 
int* shape, 
const constant 
size_t* strides) {
 
  270      offset += OffsetT(strides[0]);
 
 
  274  void next(
int n, 
const constant 
int* shape, 
const constant 
size_t* strides) {
 
 
 
  288template <
typename OffsetT>
 
  294  void next(
const constant 
int*, 
const constant 
size_t* strides) {
 
  295    offset += OffsetT(strides[0]);
 
 
  298  void next(
int n, 
const constant 
int*, 
const constant 
size_t* strides) {
 
  299    offset += n * OffsetT(strides[0]);
 
 
 
  312template <
typename T, 
typename U>
 
  314  return (N + M - 1) / M;
 
 
  319  float xp1 = 1.0f + x;
 
 
  331  float xp1 = 1.0f + 
static_cast<float>(x);
 
 
  347  return as_type<uint64_t>(
 
 
  352  return as_type<int64_t>(
 
 
  385      as_type<uint2>(data), as_type<uint2>(filling), delta));
 
 
  391      as_type<uint2>(data), as_type<uint2>(filling), delta));
 
 
  396      static_cast<uint32_t
>(data), 
static_cast<uint32_t
>(filling), delta);
 
 
  426template <
bool condition, 
typename T, 
typename U>
 
  431template <
typename T, 
typename U>
 
T type
Definition utils.h:433
 
U type
Definition utils.h:428
 
static const constant U max
Definition utils.h:24
 
static const constant U finite_max
Definition utils.h:26
 
static const constant U min
Definition utils.h:25
 
static const constant U finite_min
Definition utils.h:27
 
void next(const constant int *, const constant size_t *strides)
Definition utils.h:294
 
LoopedElemToLoc(int)
Definition utils.h:292
 
OffsetT location()
Definition utils.h:302
 
void next(int n, const constant int *, const constant size_t *strides)
Definition utils.h:298
 
OffsetT location()
Definition utils.h:283
 
int dim
Definition utils.h:259
 
void next(int n, const constant int *shape, const constant size_t *strides)
Definition utils.h:274
 
LoopedElemToLoc(int dim)
Definition utils.h:263
 
void next(const constant int *shape, const constant size_t *strides)
Definition utils.h:265
 
void next(const constant int *shape, const constant size_t *strides)
Definition utils.h:216
 
LoopedElemToLoc(int dim)
Definition utils.h:214
 
void next(int n, const constant int *shape, const constant size_t *strides)
Definition utils.h:229
 
LoopedElemToLoc< DIM - 1, OffsetT, General > inner_looper
Definition utils.h:210
 
OffsetT location()
Definition utils.h:252
 
int index
Definition utils.h:212
 
OffsetT offset
Definition utils.h:211
 
int dim
Definition utils.h:209
 
float imag
Definition complex.h:22
 
float real
Definition complex.h:21