// Fragment of a function template parameterized on T, U and Op.
// Each thread uses its grid position `index` as the output slot and stores
// Op()(a[0], b[0]) into c[index]; both operands are always read at element 0.
// NOTE(review): the function name and the declarations of a, b and c are
// missing from this listing — restore them from the original file.
3template <
typename T, 
typename U, 
typename Op>
 
    8    uint index [[thread_position_in_grid]]) {
 
    9  c[index] = Op()(a[0], b[0]);
 
 
   // Fragment of a function template parameterized on T, U and Op.
   // Each thread stores Op()(a[0], b[index]) into c[index]: the left operand
   // is always element 0 of a, the right operand is read at the thread's
   // grid position.
   // NOTE(review): the function name and the declarations of a, b and c are
   // missing from this listing — restore them from the original file.
   12template <
typename T, 
typename U, 
typename Op>
 
   17    uint index [[thread_position_in_grid]]) {
 
   18  c[index] = Op()(a[0], b[index]);
 
 
   // Fragment of a function template parameterized on T, U and Op.
   // Each thread stores Op()(a[index], b[0]) into c[index]: the left operand
   // is read at the thread's grid position, the right operand is always
   // element 0 of b.
   // NOTE(review): the function name and the declarations of a, b and c are
   // missing from this listing — restore them from the original file.
   21template <
typename T, 
typename U, 
typename Op>
 
   26    uint index [[thread_position_in_grid]]) {
 
   27  c[index] = Op()(a[index], b[0]);
 
 
   // Fragment of a function template parameterized on T, U and Op.
   // Each thread stores Op()(a[index], b[index]) into c[index]: both operands
   // and the output are addressed by the thread's grid position.
   // NOTE(review): the function name and the declarations of a, b and c are
   // missing from this listing — restore them from the original file.
   30template <
typename T, 
typename U, 
typename Op>
 
   35    uint index [[thread_position_in_grid]]) {
 
   36  c[index] = Op()(a[index], b[index]);
 
 
   // Fragment of a function template parameterized on T, U and Op.
   // Takes one size_t stride per input (a_stride, b_stride) by reference from
   // the constant address space, and a 1-D grid position `index`.
   // The output is written at c[index]; the inputs are read at a_idx / b_idx.
   // NOTE(review): the statements that define a_idx and b_idx are missing from
   // this listing (presumably derived from index and the matching stride —
   // confirm), as are the function name and the declarations of a, b and c.
   39template <
typename T, 
typename U, 
typename Op>
 
   44    constant 
const size_t& a_stride,
 
   45    constant 
const size_t& b_stride,
 
   46    uint index [[thread_position_in_grid]]) {
 
   49  c[index] = Op()(a[a_idx], b[b_idx]);
 
 
   // Fragment of a function template parameterized on T, U and Op.
   // Takes two-element size_t stride arrays for each input and a 2-D grid
   // position (`index`) together with the 2-D grid size (`grid_dim`).
   // NOTE(review): the statements that define a_idx and b_idx are missing from
   // this listing, as are the function name and the declarations of a, b, c.
   52template <
typename T, 
typename U, 
typename Op>
 
   57    constant 
const size_t a_strides[2],
 
   58    constant 
const size_t b_strides[2],
 
   59    uint2 index [[thread_position_in_grid]],
 
   60    uint2 grid_dim [[threads_per_grid]]) {
 
   // Linear output offset: x plus y times the grid width. grid_dim.x is cast
   // to size_t first so the multiplication is carried out in size_t.
   63  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
 
   64  c[out_idx] = Op()(a[a_idx], b[b_idx]);
 
 
   // Fragment of a function template parameterized on T, U and Op.
   // Takes three-element size_t stride arrays for each input and a 3-D grid
   // position (`index`) together with the 3-D grid size (`grid_dim`).
   // NOTE(review): the statements that define a_idx and b_idx, and the start of
   // the statement continued below (presumably `size_t out_idx =` — confirm),
   // are missing from this listing, as are the function name and the
   // declarations of a, b and c.
   67template <
typename T, 
typename U, 
typename Op>
 
   72    constant 
const size_t a_strides[3],
 
   73    constant 
const size_t b_strides[3],
 
   74    uint3 index [[thread_position_in_grid]],
 
   75    uint3 grid_dim [[threads_per_grid]]) {
 
   // x + width * (y + height * z), with both grid extents cast to size_t
   // before they are multiplied.
   79      index.x + (size_t)grid_dim.x * (index.y + (
size_t)grid_dim.y * index.z);
 
   80  c[out_idx] = Op()(a[a_idx], b[b_idx]);
 
 
   // Fragment of a function template parameterized on T, U, Op and a
   // compile-time integer DIM that sizes the shape and stride arrays.
   // The input offsets come from a single elem_to_loc_2_nd<DIM> call; the
   // result's .x member indexes a and its .y member indexes b.
   // NOTE(review): elem_to_loc_2_nd is defined outside this listing. The start
   // of the statement continued below (presumably `size_t out_idx =` —
   // confirm), the function name and the declarations of a, b and c are
   // missing from this listing.
   83template <
typename T, 
typename U, 
typename Op, 
int DIM>
 
   88    constant 
const int shape[DIM],
 
   89    constant 
const size_t a_strides[DIM],
 
   90    constant 
const size_t b_strides[DIM],
 
   91    uint3 index [[thread_position_in_grid]],
 
   92    uint3 grid_dim [[threads_per_grid]]) {
 
   93  auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
 
   // x + width * (y + height * z), with both grid extents cast to size_t
   // before they are multiplied.
   95      index.x + (size_t)grid_dim.x * (index.y + (
size_t)grid_dim.y * index.z);
 
   96  c[out_idx] = Op()(a[idx.x], b[idx.y]);
 
 
   // Fragment of a function template parameterized on T, U and Op.
   // Unlike the fixed-size variants, shape and strides arrive as pointers and
   // the number of dimensions is passed at run time through `ndim`.
   // The output is written at the linearized 3-D grid position; a is read at
   // idx.x and b at idx.y.
   // NOTE(review): the statement that defines idx is missing from this
   // listing, as are the function name and the declarations of a, b and c.
   99template <
typename T, 
typename U, 
typename Op>
 
  104    constant 
const int* shape,
 
  105    constant 
const size_t* a_strides,
 
  106    constant 
const size_t* b_strides,
 
  107    constant 
const int& ndim,
 
  108    uint3 index [[thread_position_in_grid]],
 
  109    uint3 grid_dim [[threads_per_grid]]) {
 
  // Cast the grid extents to size_t before multiplying, as the other strided
  // kernels in this file do. Without the casts the whole expression is
  // evaluated in uint and can wrap before it is converted to size_t.
  111  size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
 
  112  c[out_idx] = Op()(a[idx.x], b[idx.y]);