MLX
Loading...
Searching...
No Matches
scatter.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
3#pragma once
4
6
// Applies one update element to `out` through `Op::atomic_update`.
//
// The destination offset is built in two steps:
//   1. For each of the NIDX index buffers, the entry at `gid.y` is passed
//      through `offset_neg_idx` together with `out_shape[k]` and the result
//      is scaled by `out_strides[k]`.
//      NOTE(review): `offset_neg_idx` presumably maps negative indices into
//      range; confirm against indexing.h.
//   2. `gid.x` is added, either directly (when `upd_ndim <= 1`) or after
//      being translated by `elem_to_loc` using `upd_shape + 1` as the shape
//      and `out_strides` as the strides.
//
// The update value is read from `updates[gid.y * upd_size + gid.x]`.
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_1d_index_impl(
    const device T* updates [[buffer(1)]],
    device mlx_atomic<T>* out [[buffer(2)]],
    const constant int* out_shape [[buffer(3)]],
    const constant size_t* out_strides [[buffer(4)]],
    const constant size_t& out_ndim [[buffer(5)]],
    const constant int* upd_shape [[buffer(6)]],
    const constant size_t& upd_ndim [[buffer(7)]],
    const constant size_t& upd_size [[buffer(8)]],
    const thread array<const device IdxT*, NIDX>& idx_buffers,
    uint2 gid [[thread_position_in_grid]]) {
  // Step 1: contribution of the indexed dimensions.
  size_t dst = 0;
  for (int k = 0; k < NIDX; ++k) {
    auto resolved = offset_neg_idx(idx_buffers[k][gid.y], out_shape[k]);
    dst += resolved * out_strides[k];
  }

  // Step 2: contribution of the position inside the update slice.
  if (upd_ndim <= 1) {
    dst += gid.x;
  } else {
    dst += elem_to_loc(gid.x, upd_shape + 1, out_strides, out_ndim);
  }

  Op op;
  op.atomic_update(out, updates[gid.y * upd_size + gid.x], dst);
}
36
// Applies one update element to `out` through `Op::atomic_update`, with the
// index arrays described by `indices` (per-index shapes and strides are
// stored back to back, `indices.ndim` entries each).
//
// Destination offset:
//   * For each k in [0, NIDX): `gid.y` is translated by `elem_to_loc` into a
//     location inside index buffer k, the value read there is passed through
//     `offset_neg_idx` with `out_shape[axes[k]]`, and the result is scaled by
//     `out_strides[axes[k]]`.
//     NOTE(review): `offset_neg_idx` presumably maps negative indices into
//     range; confirm against indexing.h.
//   * When `upd_size > 1`, `gid.x` is translated by `elem_to_loc` using
//     `upd_shape + indices.ndim` as the shape and `out_strides` as the
//     strides, and the result is added.
//
// Source location: the linear position `gid.y * upd_size + gid.x` is
// translated by `elem_to_loc` with `upd_shape` / `upd_strides`.
template <typename T, typename IdxT, typename Op, int NIDX>
METAL_FUNC void scatter_impl(
    const device T* updates [[buffer(1)]],
    device mlx_atomic<T>* out [[buffer(2)]],
    const constant int* upd_shape [[buffer(3)]],
    const constant size_t* upd_strides [[buffer(4)]],
    const constant size_t& upd_ndim [[buffer(5)]],
    const constant size_t& upd_size [[buffer(6)]],
    const constant int* out_shape [[buffer(7)]],
    const constant size_t* out_strides [[buffer(8)]],
    const constant size_t& out_ndim [[buffer(9)]],
    const constant int* axes [[buffer(10)]],
    const thread Indices<IdxT, NIDX>& indices,
    uint2 gid [[thread_position_in_grid]]) {
  // Contribution of the indexed axes.
  size_t dst = 0;
  for (int k = 0; k < NIDX; ++k) {
    // Start of the k-th index array's shape/stride entries.
    auto meta_base = indices.ndim * k;
    auto src_loc = elem_to_loc(
        gid.y,
        &indices.shapes[meta_base],
        &indices.strides[meta_base],
        indices.ndim);
    auto target_axis = axes[k];
    auto resolved =
        offset_neg_idx(indices.buffers[k][src_loc], out_shape[target_axis]);
    dst += resolved * out_strides[target_axis];
  }

  // Contribution of the position inside the update slice.
  if (upd_size > 1) {
    dst += elem_to_loc(
        gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
  }

  auto src =
      elem_to_loc(gid.y * upd_size + gid.x, upd_shape, upd_strides, upd_ndim);

  Op op;
  op.atomic_update(out, updates[src], dst);
}
METAL_FUNC stride_t elem_to_loc(uint elem, constant const int *shape, constant const stride_t *strides, int ndim)
Definition utils.h:87
Op op
Definition binary.h:129
METAL_FUNC size_t offset_neg_idx(IdxT idx, size_t size)
Definition indexing.h:16
METAL_FUNC void scatter_impl(const device T *updates, device mlx_atomic< T > *out, const constant int *upd_shape, const constant size_t *upd_strides, const constant size_t &upd_ndim, const constant size_t &upd_size, const constant int *out_shape, const constant size_t *out_strides, const constant size_t &out_ndim, const constant int *axes, const thread Indices< IdxT, NIDX > &indices, uint2 gid)
Definition scatter.h:38
METAL_FUNC void scatter_1d_index_impl(const device T *updates, device mlx_atomic< T > *out, const constant int *out_shape, const constant size_t *out_strides, const constant size_t &out_ndim, const constant int *upd_shape, const constant size_t &upd_ndim, const constant size_t &upd_size, const thread array< const device IdxT *, NIDX > &idx_buffers, uint2 gid)
Definition scatter.h:8
Definition indexing.h:8
Definition atomic.h:25