3template <
typename T, 
typename U, 
typename Op>
 
    9    uint index [[thread_position_in_grid]]) {
 
   10  auto out = Op()(a[0], b[0]);
 
 
   15template <
typename T, 
typename U, 
typename Op>
 
   21    uint index [[thread_position_in_grid]]) {
 
   22  auto out = Op()(a[0], b[index]);
 
 
   27template <
typename T, 
typename U, 
typename Op>
 
   33    uint index [[thread_position_in_grid]]) {
 
   34  auto out = Op()(a[index], b[0]);
 
 
   39template <
typename T, 
typename U, 
typename Op>
 
   45    uint index [[thread_position_in_grid]]) {
 
   46  auto out = Op()(a[index], b[index]);
 
 
   51template <
typename T, 
typename U, 
typename Op>
 
   57    uint2 index [[thread_position_in_grid]],
 
   58    uint2 grid_dim [[threads_per_grid]]) {
 
   59  size_t offset = index.x + grid_dim.x * size_t(index.y);
 
   60  auto out = Op()(a[0], b[offset]);
 
 
   65template <
typename T, 
typename U, 
typename Op>
 
   71    uint2 index [[thread_position_in_grid]],
 
   72    uint2 grid_dim [[threads_per_grid]]) {
 
   73  size_t offset = index.x + grid_dim.x * size_t(index.y);
 
   74  auto out = Op()(a[offset], b[0]);
 
 
   79template <
typename T, 
typename U, 
typename Op>
 
   85    uint2 index [[thread_position_in_grid]],
 
   86    uint2 grid_dim [[threads_per_grid]]) {
 
   87  size_t offset = index.x + grid_dim.x * size_t(index.y);
 
   88  auto out = Op()(a[offset], b[offset]);
 
 
   93template <
typename T, 
typename U, 
typename Op>
 
   99    constant 
const size_t& a_stride,
 
  100    constant 
const size_t& b_stride,
 
  101    uint index [[thread_position_in_grid]]) {
 
  104  auto out = Op()(a[a_idx], b[b_idx]);
 
 
  109template <
typename T, 
typename U, 
typename Op>
 
  115    constant 
const size_t a_strides[2],
 
  116    constant 
const size_t b_strides[2],
 
  117    uint2 index [[thread_position_in_grid]],
 
  118    uint2 grid_dim [[threads_per_grid]]) {
 
  121  size_t out_idx = index.x + size_t(grid_dim.x) * index.y;
 
  122  auto out = Op()(a[a_idx], b[b_idx]);
 
 
  127template <
typename T, 
typename U, 
typename Op>
 
  133    constant 
const size_t a_strides[3],
 
  134    constant 
const size_t b_strides[3],
 
  135    uint3 index [[thread_position_in_grid]],
 
  136    uint3 grid_dim [[threads_per_grid]]) {
 
  140      index.x + grid_dim.x * (index.y + size_t(grid_dim.y) * index.z);
 
  141  auto out = Op()(a[a_idx], b[b_idx]);
 
 
  146template <
typename T, 
typename U, 
typename Op, 
int N = 1>
 
  152    constant 
const int* shape,
 
  153    constant 
const size_t* a_strides,
 
  154    constant 
const size_t* b_strides,
 
  155    constant 
const int& ndim,
 
  156    uint3 index [[thread_position_in_grid]],
 
  157    uint3 grid_dim [[threads_per_grid]]) {
 
  159      {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim);
 
  160  auto xshape = shape[ndim - 1];
 
  162      N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
 
  163  auto a_xstride = a_strides[ndim - 1];
 
  164  auto b_xstride = b_strides[ndim - 1];
 
  165  for (
int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
 
  166    auto out = Op()(a[idx.x], b[idx.y]);
 
  168    d[out_idx++] = out[1];