9    const device T* updates [[buffer(1)]],
 
   11    const constant 
int* out_shape [[buffer(3)]],
 
   12    const constant 
size_t* out_strides [[buffer(4)]],
 
   13    const constant 
size_t& out_ndim [[buffer(5)]],
 
   14    const constant 
int* upd_shape [[buffer(6)]],
 
   15    const constant 
size_t& upd_ndim [[buffer(7)]],
 
   16    const constant 
size_t& upd_size [[buffer(8)]],
 
   17    const thread array<const device IdxT*, NIDX>& idx_buffers,
 
   18    uint2 gid [[thread_position_in_grid]]) {
 
   22  for (
int i = 0; i < NIDX; i++) {
 
   23    auto idx_val = 
offset_neg_idx(idx_buffers[i][gid.y], out_shape[i]);
 
   24    out_idx += idx_val * out_strides[i];
 
   28    auto out_offset = 
elem_to_loc(gid.x, upd_shape + 1, out_strides, out_ndim);
 
   29    out_idx += out_offset;
 
   34  op.atomic_update(out, updates[gid.y * upd_size + gid.x], out_idx);
 
 
   39    const device T* updates [[buffer(1)]],
 
   41    const constant 
int* upd_shape [[buffer(3)]],
 
   42    const constant 
size_t* upd_strides [[buffer(4)]],
 
   43    const constant 
size_t& upd_ndim [[buffer(5)]],
 
   44    const constant 
size_t& upd_size [[buffer(6)]],
 
   45    const constant 
int* out_shape [[buffer(7)]],
 
   46    const constant 
size_t* out_strides [[buffer(8)]],
 
   47    const constant 
size_t& out_ndim [[buffer(9)]],
 
   48    const constant 
int* axes [[buffer(10)]],
 
   50    uint2 gid [[thread_position_in_grid]]) {
 
   53  auto ind_offset = gid.x;
 
   56  for (
int i = 0; i < NIDX; ++i) {
 
   59        &indices.shapes[indices.ndim * i],
 
   60        &indices.strides[indices.ndim * i],
 
   63    auto idx_val = 
offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
 
   64    out_idx += idx_val * out_strides[ax];
 
   69        ind_offset, upd_shape + indices.ndim, out_strides, out_ndim);
 
   70    out_idx += out_offset;
 
   74      elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);
 
   75  op.atomic_update(out, updates[upd_idx], out_idx);
 
 
METAL_FUNC void scatter_impl(const device T *updates, device mlx_atomic< T > *out, const constant int *upd_shape, const constant size_t *upd_strides, const constant size_t &upd_ndim, const constant size_t &upd_size, const constant int *out_shape, const constant size_t *out_strides, const constant size_t &out_ndim, const constant int *axes, const thread Indices< IdxT, NIDX > &indices, uint2 gid)
Definition scatter.h:38
 
METAL_FUNC void scatter_1d_index_impl(const device T *updates, device mlx_atomic< T > *out, const constant int *out_shape, const constant size_t *out_strides, const constant size_t &out_ndim, const constant int *upd_shape, const constant size_t &upd_ndim, const constant size_t &upd_size, const thread array< const device IdxT *, NIDX > &idx_buffers, uint2 gid)
Definition scatter.h:8