// Element-wise unary kernel (Metal): each thread reads `in[index]`, applies a
// default-constructed `Op` functor to it, and stores the result in
// `out[index]`, where `index` is the thread's position in the grid.
//
// NOTE(review): this listing looks damaged by extraction — the leading digits
// (`3`, `7`, `8`) appear to be line numbers fused into the code, and the
// line(s) carrying the kernel name and the `in` / `out` parameter
// declarations, as well as the closing brace, are not present here. Restore
// them from the original file before compiling.
// NOTE(review): no bounds check is visible; presumably the launch grid is
// sized to the element count — confirm against the dispatch code.
3template <
typename T, 
typename U, 
typename Op>
 
    7    uint index [[thread_position_in_grid]]) {
 
    // One element per thread: out[i] = Op()(in[i]).
    8  out[index] = Op()(in[index]);
 
 
   // Element-wise unary kernel (Metal) launched over a 2-D grid: each thread
   // flattens its (x, y) grid position into a single linear offset and writes
   // `Op()(in[offset])` to `out[offset]`.
   //
   // NOTE(review): this listing looks damaged by extraction — the leading
   // digits (`11`, `15`..`18`) appear to be line numbers fused into the code,
   // and the line(s) carrying the kernel name and the `in` / `out` parameter
   // declarations, as well as the closing brace, are not present here. Restore
   // them from the original file before compiling.
   11template <
typename T, 
typename U, 
typename Op>
 
   15    uint2 index [[thread_position_in_grid]],
 
   16    uint2 grid_dim [[threads_per_grid]]) {
 
   // Linear offset with x varying fastest: x + grid_width * y. `index.y` is
   // cast to size_t so the product and sum are evaluated in size_t rather
   // than uint (presumably to avoid 32-bit wrap-around on large inputs).
   17  size_t offset = index.x + grid_dim.x * size_t(index.y);
 
   // One element per thread at the flattened offset.
   18  out[offset] = Op()(in[offset]);
 
 
   21template <
typename T, 
typename U, 
typename Op, 
int N = 1>
 
   25    constant 
const int* in_shape,
 
   26    constant 
const size_t* in_strides,
 
   27    device 
const int& ndim,
 
   28    uint3 index [[thread_position_in_grid]],
 
   29    uint3 grid_dim [[threads_per_grid]]) {
 
   31      elem_to_loc({N * index.x, index.y, index.z}, in_shape, in_strides, ndim);
 
   32  auto xshape = in_shape[ndim - 1];
 
   33  auto xstride = in_strides[ndim - 1];
 
   35      N * index.x + xshape * (index.y + size_t(grid_dim.y) * index.z);
 
   36  for (
int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) {
 
   37    out[out_idx++] = Op()(in[idx]);