18  static const constant U 
max = metal::numeric_limits<U>::max();
 
   19  static const constant U 
min = metal::numeric_limits<U>::min();
 
   20  static const constant U 
finite_max = metal::numeric_limits<U>::max();
 
   21  static const constant U 
finite_min = metal::numeric_limits<U>::min();
 
 
   24#define instantiate_default_limit(type)                                      \ 
   26  struct Limits<type> {                                                      \ 
   27    static constexpr constant type max = metal::numeric_limits<type>::max(); \ 
   28    static constexpr constant type min = metal::numeric_limits<type>::min(); \ 
   29    static constexpr constant type finite_max =                              \ 
   30        metal::numeric_limits<type>::max();                                  \ 
   31    static constexpr constant type finite_min =                              \ 
   32        metal::numeric_limits<type>::min();                                  \ 
 
   44#define instantiate_float_limit(type)             \ 
   46  struct Limits<type> {                           \ 
   47    static constexpr constant type max =          \ 
   48        metal::numeric_limits<type>::infinity();  \ 
   49    static constexpr constant type min =          \ 
   50        -metal::numeric_limits<type>::infinity(); \ 
   51    static constexpr constant type finite_max =   \ 
   52        metal::numeric_limits<type>::max();       \ 
   53    static constexpr constant type finite_min =   \ 
   54        -metal::numeric_limits<type>::max();      \ 
 
   63  static constexpr constant 
bool max = 
true;
 
   64  static constexpr constant 
bool min = 
false;
 
 
   71#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") 
   76template <
typename str
ide_t>
 
   79    device 
const int* shape,
 
   80    device 
const stride_t* strides,
 
   83  for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
 
   84    loc += (elem % shape[i]) * strides[i];
 
 
   90template <
typename str
ide_t>
 
   93    constant 
const int* shape,
 
   94    constant 
const stride_t* strides,
 
   97  for (
int i = ndim - 1; i >= 0 && elem > 0; --i) {
 
   98    loc += (elem % shape[i]) * strides[i];
 
 
  105template <
typename str
ide_t>
 
  108    constant 
const int* shape,
 
  109    constant 
const stride_t* strides,
 
  111  stride_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
 
  112  for (
int d = ndim - 3; d >= 0; --d) {
 
  113    loc += (elem.z % shape[d]) * strides[d];
 
 
  122template <
typename str
ide_t>
 
  123METAL_FUNC stride_t 
elem_to_loc_1(uint elem, constant 
const stride_t& stride) {
 
  124  return elem * stride;
 
 
  127template <
typename str
ide_t>
 
  130  return elem.x * strides[1] + elem.y * strides[0];
 
 
  133template <
typename str
ide_t>
 
  136  return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
 
 
  142    device 
const int* shape,
 
  143    device 
const size_t* strides) {
 
  144  size_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
 
  147  for (
int d = NDIM - 2; d >= 0; --d) {
 
  148    elem /= shape[d + 1];
 
  149    loc += (elem % shape[d]) * strides[d];
 
 
  158    constant 
const int shape[NDIM],
 
  159    constant 
const size_t strides[NDIM]) {
 
  160  size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
 
  161  for (
int d = NDIM - 3; d >= 0; --d) {
 
  162    loc += (elem.z % shape[d]) * strides[d];
 
 
  171    constant 
const int shape[NDIM],
 
  172    constant 
const int64_t strides[NDIM]) {
 
  173  int64_t loc = (elem % shape[NDIM - 1]) * strides[NDIM - 1];
 
  176  for (
int d = NDIM - 2; d >= 0; --d) {
 
  177    elem /= shape[d + 1];
 
  178    loc += (elem % shape[d]) * strides[d];
 
 
  187    constant 
const int shape[NDIM],
 
  188    constant 
const int64_t strides[NDIM]) {
 
  189  int64_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
 
  190  for (
int d = NDIM - 3; d >= 0; --d) {
 
  191    loc += (elem.z % shape[d]) * strides[d];
 
 
  202    constant 
const int* shape,
 
  203    constant 
const size_t* a_strides,
 
  204    constant 
const size_t* b_strides,
 
  208          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
 
  210          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
 
  211  for (
int d = ndim - 3; d >= 0; --d) {
 
  212    uint l = elem.z % shape[d];
 
  213    loc.x += l * a_strides[d];
 
  214    loc.y += l * b_strides[d];
 
 
  222    constant 
const int* shape,
 
  223    constant 
const size_t* a_strides,
 
  224    constant 
const size_t* b_strides,
 
  225    constant 
const size_t* c_strides,
 
  229          elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
 
  231          elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2]),
 
  233          elem.x * c_strides[ndim - 1] + elem.y * c_strides[ndim - 2])};
 
  234  for (
int d = ndim - 3; d >= 0; --d) {
 
  235    uint l = elem.z % shape[d];
 
  236    loc.x += l * a_strides[d];
 
  237    loc.y += l * b_strides[d];
 
  238    loc.z += l * c_strides[d];
 
 
  250    constant 
const int shape[NDIM],
 
  251    constant 
const size_t a_strides[NDIM],
 
  252    constant 
const size_t b_strides[NDIM]) {
 
  255          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
 
  257          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
 
  258  for (
int d = NDIM - 3; d >= 0; --d) {
 
  259    uint l = elem.z % shape[d];
 
  260    loc.x += l * a_strides[d];
 
  261    loc.y += l * b_strides[d];
 
 
  270    constant 
const int shape[NDIM],
 
  271    constant 
const size_t a_strides[NDIM],
 
  272    constant 
const size_t b_strides[NDIM],
 
  273    constant 
const size_t c_strides[NDIM]) {
 
  276          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
 
  278          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2]),
 
  280          elem.x * c_strides[NDIM - 1] + elem.y * c_strides[NDIM - 2])};
 
  281  for (
int d = NDIM - 3; d >= 0; --d) {
 
  282    uint l = elem.z % shape[d];
 
  283    loc.x += l * a_strides[d];
 
  284    loc.y += l * b_strides[d];
 
  285    loc.z += l * c_strides[d];
 
 
  297  return (N + M - 1) / M;
 
 
  302  float xp1 = 1.0f + x;
 
 
  314  float xp1 = 1.0f + 
static_cast<float>(x);
 
 
  330  return as_type<uint64_t>(
 
 
  335  return as_type<int64_t>(
 
 
static const constant U max
Definition utils.h:18
 
static const constant U finite_max
Definition utils.h:20
 
static const constant U min
Definition utils.h:19
 
static const constant U finite_min
Definition utils.h:21