15  static const constant U 
max = metal::numeric_limits<U>::max();
 
   16  static const constant U 
min = metal::numeric_limits<U>::min();
 
   17  static const constant U 
finite_max = metal::numeric_limits<U>::max();
 
   18  static const constant U 
finite_min = metal::numeric_limits<U>::min();
 
 
   21#define instantiate_default_limit(type)                                      \ 
   23  struct Limits<type> {                                                      \ 
   24    static constexpr constant type max = metal::numeric_limits<type>::max(); \ 
   25    static constexpr constant type min = metal::numeric_limits<type>::min(); \ 
   26    static constexpr constant type finite_max =                              \ 
   27        metal::numeric_limits<type>::max();                                  \ 
   28    static constexpr constant type finite_min =                              \ 
   29        metal::numeric_limits<type>::min();                                  \ 
 
   41#define instantiate_float_limit(type)             \ 
   43  struct Limits<type> {                           \ 
   44    static constexpr constant type max =          \ 
   45        metal::numeric_limits<type>::infinity();  \ 
   46    static constexpr constant type min =          \ 
   47        -metal::numeric_limits<type>::infinity(); \ 
   48    static constexpr constant type finite_max =   \ 
   49        metal::numeric_limits<type>::max();       \ 
   50    static constexpr constant type finite_min =   \ 
   51        -metal::numeric_limits<type>::max();      \ 
 
   60  static constexpr constant 
bool max = 
true;
 
   61  static constexpr constant 
bool min = 
false;
 
 
   68#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") 
   73template <
typename str
ide_t>
 
   76    device 
const int* shape,
 
   77    device 
const stride_t* strides,
 
   80  for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
 
   81    loc += (elem % shape[i]) * strides[i];
 
 
   87template <
typename str
ide_t>
 
   90    constant 
const int* shape,
 
   91    constant 
const stride_t* strides,
 
   94  for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
 
   95    loc += (elem % shape[i]) * strides[i];
 
 
  102template <
typename str
ide_t>
 
  105    constant 
const int* shape,
 
  106    constant 
const stride_t* strides,
 
  108  stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
 
  109  for (
int d = ndim - 3; d >= 0; --d) {
 
  110    loc += (elem.z % shape[d]) * strides[d];
 
 
  119template <
typename str
ide_t>
 
  120METAL_FUNC stride_t 
elem_to_loc_1(uint elem, constant 
const stride_t& stride) {
 
  121  return elem * stride;
 
 
  124template <
typename str
ide_t>
 
  127  return elem.x * strides[1] + elem.y * strides[0];
 
 
  130template <
typename str
ide_t>
 
  133  return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
 
 
  139    device 
const int* shape,
 
  140    device 
const size_t* strides) {
 
  141  size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
 
  144  for (
int d = NDIM - 2; d >= 0; --d) {
 
  145    elem /= shape[d + 1];
 
  146    loc += (elem % shape[d]) * strides[d];
 
 
  155    constant 
const int shape[NDIM],
 
  156    constant 
const size_t strides[NDIM]) {
 
  157  size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
 
  158  for (
int d = NDIM - 3; d >= 0; --d) {
 
  159    loc += (elem.z % shape[d]) * strides[d];
 
 
  168    constant 
const int shape[NDIM],
 
  169    constant 
const int64_t strides[NDIM]) {
 
  170  int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
 
  173  for (
int d = NDIM - 2; d >= 0; --d) {
 
  174    elem /= shape[d + 1];
 
  175    loc += (elem % shape[d]) * strides[d];
 
 
  184    constant 
const int shape[NDIM],
 
  185    constant 
const int64_t strides[NDIM]) {
 
  186  int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
 
  187  for (
int d = NDIM - 3; d >= 0; --d) {
 
  188    loc += (elem.z % shape[d]) * strides[d];
 
 
  199    constant 
const int* shape,
 
  200    constant 
const size_t* a_strides,
 
  201    constant 
const size_t* b_strides,
 
  205          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
 
  207          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
 
  208  for (
int d = ndim - 3; d >= 0; --d) {
 
  209    uint l = elem.z % shape[d];
 
  210    loc.x += l * a_strides[d];
 
  211    loc.y += l * b_strides[d];
 
 
  219    constant 
const int* shape,
 
  220    constant 
const size_t* a_strides,
 
  221    constant 
const size_t* b_strides,
 
  222    constant 
const size_t* c_strides,
 
  226          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
 
  228          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]),
 
  230          elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])};
 
  231  for (
int d = ndim - 3; d >= 0; --d) {
 
  232    uint l = elem.z % shape[d];
 
  233    loc.x += l * a_strides[d];
 
  234    loc.y += l * b_strides[d];
 
  235    loc.z += l * c_strides[d];
 
 
  247    constant 
const int shape[NDIM],
 
  248    constant 
const size_t a_strides[NDIM],
 
  249    constant 
const size_t b_strides[NDIM]) {
 
  252          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
 
  254          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
 
  255  for (
int d = NDIM - 3; d >= 0; --d) {
 
  256    uint l = elem.z % shape[d];
 
  257    loc.x += l * a_strides[d];
 
  258    loc.y += l * b_strides[d];
 
 
  267    constant 
const int shape[NDIM],
 
  268    constant 
const size_t a_strides[NDIM],
 
  269    constant 
const size_t b_strides[NDIM],
 
  270    constant 
const size_t c_strides[NDIM]) {
 
  273          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
 
  275          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
 
  277          elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
 
  278  for (
int d = NDIM - 3; d >= 0; --d) {
 
  279    uint l = elem.z % shape[d];
 
  280    loc.x += l * a_strides[d];
 
  281    loc.y += l * b_strides[d];
 
  282    loc.z += l * c_strides[d];
 
 
  294  return (N + M - 1) / M;
 
 
  299  float xp1 = 1.0f + x;
 
 
  311  float xp1 = 1.0f + 
static_cast<float>(x);
 
 
  327  return as_type<uint64_t>(
 
 
  332  return as_type<int64_t>(
 
 
static const constant U max
Definition utils.h:15
 
static const constant U finite_max
Definition utils.h:17
 
static const constant U min
Definition utils.h:16
 
static const constant U finite_min
Definition utils.h:18