9    const device T* src [[buffer(0)]],
 
   10    device T* out [[buffer(1)]],
 
   11    const constant 
int* src_shape [[buffer(2)]],
 
   12    const constant 
size_t* src_strides [[buffer(3)]],
 
   13    const constant 
size_t& src_ndim [[buffer(4)]],
 
   14    const constant 
int* slice_sizes [[buffer(5)]],
 
   15    const constant 
int* axes [[buffer(6)]],
 
   17    uint3 index [[thread_position_in_grid]],
 
   18    uint3 grid_dim [[threads_per_grid]]) {
 
   20  for (
int i = 0; i < NIDX; ++i) {
 
   24    } 
else if (IDX_NDIM == 1) {
 
   25      idx_loc = index.x * indices.strides[indices.ndim * i];
 
   27      idx_loc = index.x * indices.strides[indices.ndim * i];
 
   28      idx_loc += indices.row_contiguous[i]
 
   32                &indices.shapes[indices.ndim * i + 1],
 
   33                &indices.strides[indices.ndim * i + 1],
 
   37    auto idx_val = 
offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
 
   38    src_idx += idx_val * src_strides[ax];
 
   41  auto src_offset = 
elem_to_loc(index.z, slice_sizes, src_strides, src_ndim);
 
   43  size_t out_idx = index.z;
 
   45    out_idx += 
static_cast<size_t>(grid_dim.z) * index.x;
 
   46  } 
else if (IDX_NDIM >= 2) {
 
   48        grid_dim.z * (index.x * 
static_cast<size_t>(grid_dim.y) + index.y);
 
   50  out[out_idx] = src[src_offset + src_idx];
 
 
METAL_FUNC void gather_impl(const device T *src, device T *out, const constant int *src_shape, const constant size_t *src_strides, const constant size_t &src_ndim, const constant int *slice_sizes, const constant int *axes, const thread Indices< IdxT, NIDX > &indices, uint3 index, uint3 grid_dim)
Definition gather.h:8