15 const device T* updates,
17 const constant
int* upd_shape,
18 const constant
size_t* upd_strides,
19 const constant
size_t& upd_ndim,
20 const constant
size_t& upd_size,
21 const constant
int* out_shape,
22 const constant
size_t* out_strides,
23 const constant
size_t& out_ndim,
24 const constant
int* axes,
25 const constant
size_t& idx_size,
27 uint2 gid [[thread_position_in_grid]]) {
30 auto ind_idx = gid.y * NWORK;
31 size_t out_offset = 0;
34 elem_to_loc(gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
37 for (
int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
38 size_t out_idx = out_offset;
39 for (
int i = 0; i < NIDX; ++i) {
40 auto idx_loc = indices.row_contiguous[i]
44 &indices.shapes[indices.ndim * i],
45 &indices.strides[indices.ndim * i],
48 auto idx_val =
offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
49 out_idx += idx_val * out_strides[ax];
51 auto upd_idx = ind_idx * upd_size + gid.x;
52 if constexpr (!UPD_ROW_CONTIG) {
53 upd_idx =
elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim);
55 op.atomic_update(out, updates[upd_idx], out_idx);
METAL_FUNC void scatter_impl(const device T *updates, device mlx_atomic< T > *out, const constant int *upd_shape, const constant size_t *upd_strides, const constant size_t &upd_ndim, const constant size_t &upd_size, const constant int *out_shape, const constant size_t *out_strides, const constant size_t &out_ndim, const constant int *axes, const constant size_t &idx_size, const thread Indices< IdxT, NIDX > &indices, uint2 gid)
Definition scatter.h:14