9 const device T* src [[buffer(0)]],
10 device T* out [[buffer(1)]],
11 const constant
int* src_shape [[buffer(2)]],
12 const constant
size_t* src_strides [[buffer(3)]],
13 const constant
size_t& src_ndim [[buffer(4)]],
14 const constant
int* slice_sizes [[buffer(5)]],
15 const constant
int* axes [[buffer(6)]],
17 uint3 index [[thread_position_in_grid]],
18 uint3 grid_dim [[threads_per_grid]]) {
20 for (
int i = 0; i < NIDX; ++i) {
24 }
else if (IDX_NDIM == 1) {
25 idx_loc = index.x *
static_cast<LocT
>(indices.strides[indices.ndim * i]);
27 idx_loc = index.x *
static_cast<LocT
>(indices.strides[indices.ndim * i]);
28 idx_loc += indices.row_contiguous[i]
32 &indices.shapes[indices.ndim * i + 1],
33 &indices.strides[indices.ndim * i + 1],
37 auto idx_val =
offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
38 src_idx +=
static_cast<LocT
>(idx_val) *
static_cast<LocT
>(src_strides[ax]);
44 LocT out_idx = index.z;
46 out_idx +=
static_cast<LocT
>(grid_dim.z) * index.x;
47 }
else if (IDX_NDIM >= 2) {
48 out_idx += grid_dim.z * (index.x *
static_cast<LocT
>(grid_dim.y) + index.y);
50 out[out_idx] = src[src_offset + src_idx];
METAL_FUNC void gather_impl(const device T *src, device T *out, const constant int *src_shape, const constant size_t *src_strides, const constant size_t &src_ndim, const constant int *slice_sizes, const constant int *axes, const thread Indices< IdxT, NIDX > &indices, uint3 index, uint3 grid_dim)
Definition gather.h:8