16 const device T* updates,
18 const constant
int* upd_shape,
19 const constant
size_t* upd_strides,
20 const constant
size_t& upd_ndim,
21 const constant
size_t& upd_size,
22 const constant
int* out_shape,
23 const constant
size_t* out_strides,
24 const constant
size_t& out_ndim,
25 const constant
int* axes,
26 const constant
size_t& idx_size,
28 uint2 gid [[thread_position_in_grid]]) {
31 auto ind_idx = gid.y * NWORK;
35 gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
38 for (
int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
39 LocT out_idx = out_offset;
40 for (
int i = 0; i < NIDX; ++i) {
41 auto idx_loc = indices.row_contiguous[i]
45 &indices.shapes[indices.ndim * i],
46 &indices.strides[indices.ndim * i],
49 auto idx_val =
offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
51 static_cast<LocT
>(idx_val) *
static_cast<LocT
>(out_strides[ax]);
53 auto upd_idx = ind_idx *
static_cast<LocT
>(upd_size) + gid.x;
54 if constexpr (!UPD_ROW_CONTIG) {
58 op.atomic_update(out, updates[upd_idx], out_idx);
METAL_FUNC void scatter_impl(const device T *updates, device mlx_atomic< T > *out, const constant int *upd_shape, const constant size_t *upd_strides, const constant size_t &upd_ndim, const constant size_t &upd_size, const constant int *out_shape, const constant size_t *out_strides, const constant size_t &out_ndim, const constant int *axes, const constant size_t &idx_size, const thread Indices< IdxT, NIDX > &indices, uint2 gid)
Definition scatter.h:15