MLX
Loading...
Searching...
No Matches
scatter.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
6
// Scatter `updates` into `out` at positions selected by NIDX index arrays,
// combining each update with the existing value via `Op::atomic_update`.
//
// Template parameters:
//   T              element type of `updates` / `out`.
//   IdxT           element type of the index arrays.
//   Op             update functor; must provide
//                  `atomic_update(out, value, offset)`.
//   NIDX           number of index arrays; index array i addresses output
//                  axis `axes[i]`.
//   UPD_ROW_CONTIG when true, the flat update position is used directly as
//                  the offset into `updates`; otherwise it is mapped through
//                  `upd_shape` / `upd_strides` with `elem_to_loc`.
//   NWORK          number of consecutive index positions handled by one
//                  thread.
//
// Parameters:
//   upd_shape/upd_strides/upd_ndim  layout of `updates`. Entries from
//                                   `indices.ndim` onward are used to place
//                                   the per-index slice into `out`.
//   upd_size        number of update elements per index position. When it
//                   is 1, the slice offset into `out` is 0.
//   out_shape/out_strides/out_ndim  layout of `out`.
//   axes            output axis addressed by each index array.
//   idx_size        total number of index positions; bounds the work loop.
//   indices         the NIDX index buffers with their shapes, strides and
//                   contiguity flags.
//   gid             gid.x selects the element within an update slice;
//                   gid.y selects the block of NWORK index positions.
template <
    typename T,
    typename IdxT,
    typename Op,
    int NIDX,
    bool UPD_ROW_CONTIG,
    int NWORK>
METAL_FUNC void scatter_impl(
    const device T* updates,
    device mlx_atomic<T>* out,
    const constant int* upd_shape,
    const constant size_t* upd_strides,
    const constant size_t& upd_ndim,
    const constant size_t& upd_size,
    const constant int* out_shape,
    const constant size_t* out_strides,
    const constant size_t& out_ndim,
    const constant int* axes,
    const constant size_t& idx_size,
    const thread Indices<IdxT, NIDX>& indices,
    uint2 gid [[thread_position_in_grid]]) {
  Op op;

  // First index position handled by this thread.
  auto ind_idx = gid.y * NWORK;

  // Offset of this thread's slice element within `out`. It depends only on
  // gid.x, so it is computed once and reused for every index position below.
  size_t out_offset = 0;
  if (upd_size > 1) {
    out_offset =
        elem_to_loc(gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
  }

  for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
    size_t out_idx = out_offset;
    for (int i = 0; i < NIDX; ++i) {
      // Contiguous index arrays are addressed directly by position; otherwise
      // map the position through this array's own shape/strides. Each array's
      // shape/stride rows are packed back to back, `indices.ndim` apart.
      auto idx_loc = indices.row_contiguous[i]
          ? ind_idx
          : elem_to_loc(
                ind_idx,
                &indices.shapes[indices.ndim * i],
                &indices.strides[indices.ndim * i],
                indices.ndim);
      auto ax = axes[i];
      // offset_neg_idx (indexing.h) normalizes the raw index against the
      // axis length — presumably wrapping negative indices; confirm there.
      auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
      out_idx += idx_val * out_strides[ax];
    }

    // Flat position of this update element; remapped when `updates` is not
    // row-contiguous.
    auto upd_idx = ind_idx * upd_size + gid.x;
    if constexpr (!UPD_ROW_CONTIG) {
      upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim);
    }
    op.atomic_update(out, updates[upd_idx], out_idx);
  }
}
METAL_FUNC stride_t elem_to_loc(uint elem, constant const int *shape, constant const stride_t *strides, int ndim)
Definition utils.h:87
Op op
Definition binary.h:129
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size)
Definition indexing.h:17
METAL_FUNC void scatter_impl(const device T *updates, device mlx_atomic< T > *out, const constant int *upd_shape, const constant size_t *upd_strides, const constant size_t &upd_ndim, const constant size_t &upd_size, const constant int *out_shape, const constant size_t *out_strides, const constant size_t &out_ndim, const constant int *axes, const constant size_t &idx_size, const thread Indices< IdxT, NIDX > &indices, uint2 gid)
Definition scatter.h:14
Definition indexing.h:8
Definition atomic.h:25