7 const device T* src [[buffer(0)]],
8 const device IdxT* indices [[buffer(1)]],
9 device T* out [[buffer(2)]],
10 const constant
int* shape [[buffer(3)]],
11 const constant int64_t* src_strides [[buffer(4)]],
12 const constant int64_t* idx_strides [[buffer(5)]],
13 const constant
size_t& ndim [[buffer(6)]],
14 const constant
int& axis [[buffer(7)]],
15 const constant
int& axis_size [[buffer(8)]],
16 const constant
size_t& src_ax_stride [[buffer(9)]],
17 const constant
size_t& idx_ax_stride [[buffer(10)]],
18 uint3 index [[thread_position_in_grid]],
19 uint3 grid_dim [[threads_per_grid]]) {
20 LocT elem_idx = index.z *
static_cast<LocT
>(grid_dim.x);
21 LocT out_idx = elem_idx * grid_dim.y + index.x;
23 LocT idx_loc = index.y *
static_cast<LocT
>(idx_ax_stride);
30 auto idx_val = indices[idx_loc];
31 if (is_signed_v<IdxT>) {
32 idx_val = (idx_val < 0) ? idx_val + axis_size : idx_val;
35 LocT src_idx = idx_val *
static_cast<LocT
>(src_ax_stride);
37 src_idx += elem_idx * axis_size + index.x;
42 out_idx += index.y *
static_cast<LocT
>(grid_dim.x);
43 out[out_idx] = src[src_idx];
void gather_axis(const device T *src, const device IdxT *indices, device T *out, const constant int *shape, const constant int64_t *src_strides, const constant int64_t *idx_strides, const constant size_t &ndim, const constant int &axis, const constant int &axis_size, const constant size_t &src_ax_stride, const constant size_t &idx_ax_stride, uint3 index, uint3 grid_dim)
Definition gather_axis.h:6