MLX
Loading...
Searching...
No Matches
scatter.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
6
// Scatter `updates` into `out`, combining each value with the destination
// through `Op::atomic_update`.
//
// Template parameters:
//   T              - element type of `updates` / `out`.
//   IdxT           - element type stored in the index buffers.
//   Op             - default-constructible functor exposing
//                    `atomic_update(out, value, offset)`.
//   NIDX           - number of index arrays held in `indices`.
//   UPD_ROW_CONTIG - when true, `updates` is addressed directly with the
//                    linear update index; otherwise that index is mapped
//                    through `upd_shape` / `upd_strides`.
//   NWORK          - maximum number of consecutive index entries processed by
//                    one thread.
//   LocT           - integer type used for computed offsets.
//
// Thread mapping:
//   gid.x - position inside one update slice (one slice per index entry).
//   gid.y - selects a group of NWORK consecutive index entries, starting at
//           `gid.y * NWORK` and clipped to `idx_size`.
//
// NOTE(review): `offset_neg_idx` is declared elsewhere; judging by its name
// and its (idx, size) signature it presumably wraps negative indices using
// the axis extent - confirm against indexing.h.
template <
    typename T,
    typename IdxT,
    typename Op,
    int NIDX,
    bool UPD_ROW_CONTIG,
    int NWORK,
    typename LocT>
METAL_FUNC void scatter_impl(
    const device T* updates,
    device mlx_atomic<T>* out,
    const constant int* upd_shape,
    const constant size_t* upd_strides,
    const constant size_t& upd_ndim,
    const constant size_t& upd_size,
    const constant int* out_shape,
    const constant size_t* out_strides,
    const constant size_t& out_ndim,
    const constant int* axes,
    const constant size_t& idx_size,
    const thread Indices<IdxT, NIDX>& indices,
    uint2 gid [[thread_position_in_grid]]) {
  Op op;

  // First index entry handled by this thread.
  auto ind_idx = gid.y * NWORK;

  // Offset into `out` contributed by the position inside the update slice.
  // It does not depend on the index entry, so it is computed once. The slice
  // shape is read from `upd_shape` after skipping the first `indices.ndim`
  // entries. With a single-element slice the offset stays 0.
  LocT out_offset = 0;
  if (upd_size > 1) {
    out_offset = elem_to_loc<size_t, LocT>(
        gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
  }

  for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
    // Accumulate the destination offset selected by every index array.
    LocT out_idx = out_offset;
    for (int i = 0; i < NIDX; ++i) {
      // Row-contiguous index arrays are read at `ind_idx` directly; any other
      // layout is resolved through that array's own shape and strides.
      auto idx_loc = indices.row_contiguous[i]
          ? ind_idx
          : elem_to_loc<size_t, LocT>(
                ind_idx,
                &indices.shapes[indices.ndim * i],
                &indices.strides[indices.ndim * i],
                indices.ndim);
      auto ax = axes[i];
      auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
      out_idx +=
          static_cast<LocT>(idx_val) * static_cast<LocT>(out_strides[ax]);
    }

    // Linear position of the update element for this (index entry, slice
    // position) pair; remapped through the strides for non-contiguous updates.
    auto upd_idx = ind_idx * static_cast<LocT>(upd_size) + gid.x;
    if constexpr (!UPD_ROW_CONTIG) {
      upd_idx =
          elem_to_loc<size_t, LocT>(upd_idx, upd_shape, upd_strides, upd_ndim);
    }

    op.atomic_update(out, updates[upd_idx], out_idx);
  }
}
METAL_FUNC IdxT elem_to_loc(uint elem, constant const int *shape, constant const StrideT *strides, int ndim)
Definition utils.h:93
Op op
Definition binary.h:129
METAL_FUNC size_t offset_neg_idx(IdxT idx, int size)
Definition indexing.h:17
METAL_FUNC void scatter_impl(const device T *updates, device mlx_atomic< T > *out, const constant int *upd_shape, const constant size_t *upd_strides, const constant size_t &upd_ndim, const constant size_t &upd_size, const constant int *out_shape, const constant size_t *out_strides, const constant size_t &out_ndim, const constant int *axes, const constant size_t &idx_size, const thread Indices< IdxT, NIDX > &indices, uint2 gid)
Definition scatter.h:15
Definition indexing.h:8
Definition atomic.h:25