3template <
typename T, 
typename U, 
typename Op>
 
    8    uint index [[thread_position_in_grid]]) {
 
    9  c[index] = Op()(a[0], b[0]);
 
 
   12template <
typename T, 
typename U, 
typename Op>
 
   17    uint index [[thread_position_in_grid]]) {
 
   18  c[index] = Op()(a[0], b[index]);
 
 
   21template <
typename T, 
typename U, 
typename Op>
 
   26    uint index [[thread_position_in_grid]]) {
 
   27  c[index] = Op()(a[index], b[0]);
 
 
   30template <
typename T, 
typename U, 
typename Op>
 
   35    uint index [[thread_position_in_grid]]) {
 
   36  c[index] = Op()(a[index], b[index]);
 
 
   39template <
typename T, 
typename U, 
typename Op>
 
   44    uint2 index [[thread_position_in_grid]],
 
   45    uint2 grid_dim [[threads_per_grid]]) {
 
   46  size_t offset = index.x + grid_dim.x * size_t(index.y);
 
   47  c[offset] = Op()(a[0], b[offset]);
 
 
   50template <
typename T, 
typename U, 
typename Op>
 
   55    uint2 index [[thread_position_in_grid]],
 
   56    uint2 grid_dim [[threads_per_grid]]) {
 
   57  size_t offset = index.x + grid_dim.x * size_t(index.y);
 
   58  c[offset] = Op()(a[offset], b[0]);
 
 
   61template <
typename T, 
typename U, 
typename Op>
 
   66    uint2 index [[thread_position_in_grid]],
 
   67    uint2 grid_dim [[threads_per_grid]]) {
 
   68  size_t offset = index.x + grid_dim.x * size_t(index.y);
 
   69  c[offset] = Op()(a[offset], b[offset]);
 
 
   72template <
typename T, 
typename U, 
typename Op>
 
   77    constant 
const size_t& a_stride,
 
   78    constant 
const size_t& b_stride,
 
   79    uint index [[thread_position_in_grid]]) {
 
   82  c[index] = Op()(a[a_idx], b[b_idx]);
 
 
   85template <
typename T, 
typename U, 
typename Op>
 
   90    constant 
const size_t a_strides[2],
 
   91    constant 
const size_t b_strides[2],
 
   92    uint2 index [[thread_position_in_grid]],
 
   93    uint2 grid_dim [[threads_per_grid]]) {
 
   96  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
 
   97  c[out_idx] = Op()(a[a_idx], b[b_idx]);
 
 
  100template <
typename T, 
typename U, 
typename Op>
 
  105    constant 
const size_t a_strides[3],
 
  106    constant 
const size_t b_strides[3],
 
  107    uint3 index [[thread_position_in_grid]],
 
  108    uint3 grid_dim [[threads_per_grid]]) {
 
  112      index.x + (size_t)grid_dim.x * (index.y + (
size_t)grid_dim.y * index.z);
 
  113  c[out_idx] = Op()(a[a_idx], b[b_idx]);
 
 
  116template <
typename T, 
typename U, 
typename Op, 
int DIM>
 
  121    constant 
const int shape[DIM],
 
  122    constant 
const size_t a_strides[DIM],
 
  123    constant 
const size_t b_strides[DIM],
 
  124    uint3 index [[thread_position_in_grid]],
 
  125    uint3 grid_dim [[threads_per_grid]]) {
 
  126  auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
 
  128      index.x + (size_t)grid_dim.x * (index.y + (
size_t)grid_dim.y * index.z);
 
  129  c[out_idx] = Op()(a[idx.x], b[idx.y]);
 
 
  132template <
typename T, 
typename U, 
typename Op>
 
  137    constant 
const int* shape,
 
  138    constant 
const size_t* a_strides,
 
  139    constant 
const size_t* b_strides,
 
  140    constant 
const int& ndim,
 
  141    uint3 index [[thread_position_in_grid]],
 
  142    uint3 grid_dim [[threads_per_grid]]) {
 
  144  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
 
  145  c[out_idx] = Op()(a[idx.x], b[idx.y]);