13 const device T* upd [[buffer(0)]],
14 const device IdxT* indices [[buffer(1)]],
16 const constant
int* shape [[buffer(3)]],
17 const constant int64_t* upd_strides [[buffer(4)]],
18 const constant int64_t* idx_strides [[buffer(5)]],
19 const constant
size_t& ndim [[buffer(6)]],
20 const constant
int& axis [[buffer(7)]],
21 const constant
int& out_axis_size [[buffer(8)]],
22 const constant
size_t& upd_ax_stride [[buffer(9)]],
23 const constant
size_t& idx_ax_stride [[buffer(10)]],
24 uint3 index [[thread_position_in_grid]],
25 uint3 grid_dim [[threads_per_grid]]) {
28 LocT elem_idx = index.z *
static_cast<LocT
>(grid_dim.x);
30 LocT idx_loc = index.y *
static_cast<LocT
>(idx_ax_stride);
32 idx_loc += elem_idx * grid_dim.y + index.x;
37 auto idx_val = indices[idx_loc];
38 if (is_signed_v<IdxT>) {
39 idx_val = (idx_val < 0) ? idx_val + out_axis_size : idx_val;
42 LocT upd_idx = index.y *
static_cast<LocT
>(upd_ax_stride);
44 upd_idx += elem_idx * grid_dim.y + index.x;
49 LocT out_idx = elem_idx *
static_cast<LocT
>(out_axis_size) +
50 idx_val * grid_dim.x + index.x;
51 op.atomic_update(out, upd[upd_idx], out_idx);
void scatter_axis(const device T *upd, const device IdxT *indices, device mlx_atomic< T > *out, const constant int *shape, const constant int64_t *upd_strides, const constant int64_t *idx_strides, const constant size_t &ndim, const constant int &axis, const constant int &out_axis_size, const constant size_t &upd_ax_stride, const constant size_t &idx_ax_stride, uint3 index, uint3 grid_dim)
Definition scatter_axis.h:12