5    device {1}* out [[buffer(0)]], 
    6    uint tid [[thread_position_in_grid]]) {{ 
    7  out[tid] = {2}<{1}>::init; 
   // Explicit instantiation of the all_reduce kernel template. The host-visible
   // name is "all_" followed by format field 0; fields 1, 2 and 3 supply the
   // template arguments, with field 3 itself parameterised on field 2.
   // NOTE(review): fields presumably map to input type, output type and
   // reduction op -- confirm against the all_reduce template definition.
   // Keep comments here free of curly braces: this text appears to be a format
   // template with numbered fields.
   12template [[host_name("all_{0}")]] [[kernel]] void 
   13all_reduce<{1}, {2}, {3}<{2}>>( 
   14    const device {1}* in [[buffer(0)]], 
         // Output is bound as an mlx_atomic pointer, unlike the plain device
         // pointer used by the no_atomics instantiation of this kernel.
   15    device mlx_atomic<{2}>* out [[buffer(1)]], 
         // in_size is read from the device address space; the col_/row_
         // instantiations in view bind their sizes from constant instead.
   16    const device size_t& in_size [[buffer(2)]], 
   17    uint gid [[thread_position_in_grid]], 
   18    uint lid [[thread_position_in_threadgroup]], 
   19    uint grid_size [[threads_per_grid]], 
   20    uint simd_per_group [[simdgroups_per_threadgroup]], 
   21    uint simd_lane_id [[thread_index_in_simdgroup]], 
   22    uint simd_group_id [[simdgroup_index_in_threadgroup]]); 
   // Explicit instantiation of the col_reduce_general kernel template, exposed
   // to the host as "colGeneral_" followed by format field 0. Template
   // arguments come from fields 1, 2 and 3 (field 3 parameterised on field 2).
   // NOTE(review): the meaning of each field is not visible here -- confirm
   // against the template definition and the code that formats this text.
   23template [[host_name("colGeneral_{0}")]] [[kernel]] void 
   24col_reduce_general<{1}, {2}, {3}<{2}>>( 
   25    const device {1}* in [[buffer(0)]], 
         // Atomic output pointer; the no_atomics variant takes a plain pointer.
   26    device mlx_atomic<{2}>* out [[buffer(1)]], 
   27    const constant size_t& reduction_size [[buffer(2)]], 
   28    const constant size_t& reduction_stride [[buffer(3)]], 
   29    const constant size_t& out_size [[buffer(4)]], 
   30    const constant int* shape [[buffer(5)]], 
   31    const constant size_t* strides [[buffer(6)]], 
   32    const constant int& ndim [[buffer(7)]], 
         // Scratch array in the threadgroup address space, bound at
         // threadgroup slot 0; its size is not specified in this declaration.
   33    threadgroup {2}* local_data [[threadgroup(0)]], 
         // Position and size arguments are three-component (uint3) here.
   34    uint3 tid [[threadgroup_position_in_grid]], 
   35    uint3 lid [[thread_position_in_threadgroup]], 
   36    uint3 lsize [[threads_per_threadgroup]]); 
   // Explicit instantiation of the col_reduce_small kernel template, exposed to
   // the host as "colSmall_" followed by format field 0. The output is a plain
   // device pointer (no mlx_atomic wrapper).
   // NOTE(review): an identical instantiation with the same host name appears
   // again later in this text; presumably the two belong to separate template
   // strings -- confirm they are never compiled into the same library.
   37template [[host_name("colSmall_{0}")]] [[kernel]] void 
   38col_reduce_small<{1}, {2}, {3}<{2}>>( 
   39    const device {1}* in [[buffer(0)]], 
   40    device {2}* out [[buffer(1)]], 
   41    const constant size_t& reduction_size [[buffer(2)]], 
   42    const constant size_t& reduction_stride [[buffer(3)]], 
   43    const constant size_t& out_size [[buffer(4)]], 
   44    const constant int* shape [[buffer(5)]], 
   45    const constant size_t* strides [[buffer(6)]], 
   46    const constant int& ndim [[buffer(7)]], 
         // Buffers 8 through 11 carry a second count/shape/strides/ndim set
         // with the non_col prefix, in addition to the set at buffers 5-7.
   47    const constant size_t& non_col_reductions [[buffer(8)]], 
   48    const constant int* non_col_shapes [[buffer(9)]], 
   49    const constant size_t* non_col_strides [[buffer(10)]], 
   50    const constant int& non_col_ndim [[buffer(11)]], 
         // Only a scalar grid position is taken; no threadgroup or simdgroup
         // arguments are declared for this kernel.
   51    uint tid [[thread_position_in_grid]]); 
   // Explicit instantiation of the row_reduce_general_small kernel template,
   // exposed to the host as "rowGeneralSmall_" followed by format field 0.
   // The output is a plain device pointer (no mlx_atomic wrapper).
   // NOTE(review): the same instantiation and host name appear again later in
   // this text; presumably in a separate template string -- confirm.
   52template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void 
   53row_reduce_general_small<{1}, {2}, {3}<{2}>>( 
   54    const device {1}* in [[buffer(0)]], 
   55    device {2}* out [[buffer(1)]], 
   56    const constant size_t& reduction_size [[buffer(2)]], 
   57    const constant size_t& out_size [[buffer(3)]], 
   58    const constant size_t& non_row_reductions [[buffer(4)]], 
   59    const constant int* shape [[buffer(5)]], 
   60    const constant size_t* strides [[buffer(6)]], 
   61    const constant int& ndim [[buffer(7)]], 
         // NOTE(review): the parameter is named lid but is bound to the
         // position in the grid, not the position in the threadgroup.
   62    uint lid [[thread_position_in_grid]]); 
   // Explicit instantiation of the row_reduce_general_med kernel template,
   // exposed to the host as "rowGeneralMed_" followed by format field 0.
   // The output is a plain device pointer (no mlx_atomic wrapper).
   63template [[host_name("rowGeneralMed_{0}")]] [[kernel]] void 
   64row_reduce_general_med<{1}, {2}, {3}<{2}>>( 
   65    const device {1}* in [[buffer(0)]], 
   66    device {2}* out [[buffer(1)]], 
   67    const constant size_t& reduction_size [[buffer(2)]], 
   68    const constant size_t& out_size [[buffer(3)]], 
   69    const constant size_t& non_row_reductions [[buffer(4)]], 
   70    const constant int* shape [[buffer(5)]], 
   71    const constant size_t* strides [[buffer(6)]], 
   72    const constant int& ndim [[buffer(7)]], 
         // Here tid is a scalar threadgroup position in the grid.
   73    uint tid [[threadgroup_position_in_grid]], 
   74    uint simd_lane_id [[thread_index_in_simdgroup]], 
         // NOTE(review): this instantiation binds simd_per_group to the
         // dispatch_ variant of the attribute, while the other instantiations
         // in view use simdgroups_per_threadgroup -- confirm this is intended
         // and matches the template definition.
   75    uint simd_per_group [[dispatch_simdgroups_per_threadgroup]], 
   76    uint simd_group_id [[simdgroup_index_in_threadgroup]]); 
   // Explicit instantiation of the row_reduce_general kernel template, exposed
   // to the host as "rowGeneral_" followed by format field 0. Template
   // arguments come from fields 1, 2 and 3 (field 3 parameterised on field 2).
   77template [[host_name("rowGeneral_{0}")]] [[kernel]] void 
   78row_reduce_general<{1}, {2}, {3}<{2}>>( 
   79    const device {1}* in [[buffer(0)]], 
         // Atomic output pointer; the no_atomics variant takes a plain pointer.
   80    device mlx_atomic<{2}>* out [[buffer(1)]], 
   81    const constant size_t& reduction_size [[buffer(2)]], 
   82    const constant size_t& out_size [[buffer(3)]], 
   83    const constant size_t& non_row_reductions [[buffer(4)]], 
   84    const constant int* shape [[buffer(5)]], 
   85    const constant size_t* strides [[buffer(6)]], 
   86    const constant int& ndim [[buffer(7)]], 
         // Thread and threadgroup positions are uint3; the simdgroup
         // arguments below are scalar.
   87    uint3 lid [[thread_position_in_threadgroup]], 
   88    uint3 lsize [[threads_per_threadgroup]], 
   89    uint3 tid [[threadgroup_position_in_grid]], 
   90    uint simd_lane_id [[thread_index_in_simdgroup]], 
   91    uint simd_per_group [[simdgroups_per_threadgroup]], 
   92    uint simd_group_id [[simdgroup_index_in_threadgroup]]); 
   // Explicit instantiation of the all_reduce_no_atomics kernel template,
   // exposed to the host as "allNoAtomics_" followed by format field 0.
   // Compared with the all_reduce instantiation, the output is a plain device
   // pointer and an extra threadgroup position argument is declared.
   // Keep comments here free of curly braces: this text appears to be a format
   // template with numbered fields.
   96template [[host_name("allNoAtomics_{0}")]] [[kernel]] void 
   97all_reduce_no_atomics<{1}, {2}, {3}<{2}>>( 
   98    const device {1}* in [[buffer(0)]], 
   99    device {2}* out [[buffer(1)]], 
         // in_size is read from the device address space; the col_/row_
         // instantiations in view bind their sizes from constant instead.
  100    const device size_t& in_size [[buffer(2)]], 
  101    uint gid [[thread_position_in_grid]], 
  102    uint lid [[thread_position_in_threadgroup]], 
  103    uint grid_size [[threads_per_grid]], 
  104    uint simd_per_group [[simdgroups_per_threadgroup]], 
  105    uint simd_lane_id [[thread_index_in_simdgroup]], 
  106    uint simd_group_id [[simdgroup_index_in_threadgroup]], 
  107    uint thread_group_id [[threadgroup_position_in_grid]]); 
  // Explicit instantiation of the col_reduce_general_no_atomics kernel
  // template, exposed to the host as "colGeneralNoAtomics_" followed by format
  // field 0. Compared with the col_reduce_general instantiation, the output is
  // a plain device pointer and grid-level position/size arguments are added.
  109template [[host_name("colGeneralNoAtomics_{0}")]] [[kernel]] void 
  110  col_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>( 
  111      const device {1}* in [[buffer(0)]], 
  112      device {2}* out [[buffer(1)]], 
  113      const constant size_t& reduction_size [[buffer(2)]], 
  114      const constant size_t& reduction_stride [[buffer(3)]], 
  115      const constant size_t& out_size [[buffer(4)]], 
  116      const constant int* shape [[buffer(5)]], 
  117      const constant size_t* strides [[buffer(6)]], 
  118      const constant int& ndim [[buffer(7)]], 
           // Scratch array in the threadgroup address space, bound at
           // threadgroup slot 0; its size is not specified in this declaration.
  119      threadgroup {2}* local_data [[threadgroup(0)]], 
  120      uint3 tid [[threadgroup_position_in_grid]], 
  121      uint3 lid [[thread_position_in_threadgroup]], 
           // gid and gsize are not present on the atomic col_reduce_general
           // instantiation.
  122      uint3 gid [[thread_position_in_grid]], 
  123      uint3 lsize [[threads_per_threadgroup]], 
  124      uint3 gsize [[threads_per_grid]]); 
  // Explicit instantiation of the col_reduce_small kernel template, exposed to
  // the host as "colSmall_" followed by format field 0. The output is a plain
  // device pointer (no mlx_atomic wrapper).
  // NOTE(review): an identical instantiation with the same host name appears
  // earlier in this text; presumably the two belong to separate template
  // strings -- confirm they are never compiled into the same library.
  125template [[host_name("colSmall_{0}")]] [[kernel]] void 
  126col_reduce_small<{1}, {2}, {3}<{2}>>( 
  127    const device {1}* in [[buffer(0)]], 
  128    device {2}* out [[buffer(1)]], 
  129    const constant size_t& reduction_size [[buffer(2)]], 
  130    const constant size_t& reduction_stride [[buffer(3)]], 
  131    const constant size_t& out_size [[buffer(4)]], 
  132    const constant int* shape [[buffer(5)]], 
  133    const constant size_t* strides [[buffer(6)]], 
  134    const constant int& ndim [[buffer(7)]], 
         // Buffers 8 through 11 carry a second count/shape/strides/ndim set
         // with the non_col prefix, in addition to the set at buffers 5-7.
  135    const constant size_t& non_col_reductions [[buffer(8)]], 
  136    const constant int* non_col_shapes [[buffer(9)]], 
  137    const constant size_t* non_col_strides [[buffer(10)]], 
  138    const constant int& non_col_ndim [[buffer(11)]], 
         // Only a scalar grid position is taken; no threadgroup or simdgroup
         // arguments are declared for this kernel.
  139    uint tid [[thread_position_in_grid]]); 
  // Explicit instantiation of the row_reduce_general_small kernel template,
  // exposed to the host as "rowGeneralSmall_" followed by format field 0.
  // The output is a plain device pointer (no mlx_atomic wrapper).
  // NOTE(review): the same instantiation and host name appear earlier in this
  // text; presumably in a separate template string -- confirm.
  140template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void 
  141row_reduce_general_small<{1}, {2}, {3}<{2}>>( 
  142    const device {1}* in [[buffer(0)]], 
  143    device {2}* out [[buffer(1)]], 
  144    const constant size_t& reduction_size [[buffer(2)]], 
  145    const constant size_t& out_size [[buffer(3)]], 
  146    const constant size_t& non_row_reductions [[buffer(4)]], 
  147    const constant int* shape [[buffer(5)]], 
  148    const constant size_t* strides [[buffer(6)]], 
  149    const constant int& ndim [[buffer(7)]], 
         // NOTE(review): the parameter is named lid but is bound to the
         // position in the grid, not the position in the threadgroup.
  150    uint lid [[thread_position_in_grid]]); 
  // Explicit instantiation of the row_reduce_general_no_atomics kernel
  // template, exposed to the host as "rowGeneralNoAtomics_" followed by format
  // field 0. Compared with the row_reduce_general instantiation, the output is
  // a plain device pointer and a threads_per_grid argument is added.
  151template [[host_name("rowGeneralNoAtomics_{0}")]] [[kernel]] void 
  152row_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>( 
  153    const device {1}* in [[buffer(0)]], 
  154    device {2}* out [[buffer(1)]], 
  155    const constant size_t& reduction_size [[buffer(2)]], 
  156    const constant size_t& out_size [[buffer(3)]], 
  157    const constant size_t& non_row_reductions [[buffer(4)]], 
  158    const constant int* shape [[buffer(5)]], 
  159    const constant size_t* strides [[buffer(6)]], 
  160    const constant int& ndim [[buffer(7)]], 
  161    uint3 lid [[thread_position_in_threadgroup]], 
  162    uint3 lsize [[threads_per_threadgroup]], 
         // gsize is not present on the atomic row_reduce_general instantiation.
  163    uint3 gsize [[threads_per_grid]], 
  164    uint3 tid [[threadgroup_position_in_grid]], 
  165    uint simd_lane_id [[thread_index_in_simdgroup]], 
  166    uint simd_per_group [[simdgroups_per_threadgroup]], 
  167    uint simd_group_id [[simdgroup_index_in_threadgroup]]);