MLX
 
Loading...
Searching...
No Matches
scatter_axis.h
Go to the documentation of this file.
1// Copyright © 2025 Apple Inc.
2
3#pragma once
4
// scatter_axis: each thread reads one element of `upd` plus the matching
// entry of `indices`, treats that entry as a position along the output's
// scatter axis, and combines the update into `out` via `Op::atomic_update`.
//
// Thread mapping (as established by the offset arithmetic below):
//   index.y          -> position along the scatter axis in `upd` / `indices`
//                       (scaled by upd_ax_stride / idx_ax_stride).
//   index.z, index.x -> outer / inner coordinates of the remaining dims;
//                       grid_dim.x is the inner extent.
// The output offset is
//   index.z * (out_axis_size * grid_dim.x) + idx * grid_dim.x + index.x,
// i.e. `out` is addressed as a row-contiguous [outer, out_axis_size, inner]
// array.
//
// Template parameters:
//   T     element type of `upd` / `out`.
//   IdxT  integer type of `indices`.
//   LocT  integer type used for offset arithmetic.
//   Op    update functor; must provide atomic_update(out, value, offset).
//   UpdC  if true, the `upd` offset is the flat expression
//         index.y * upd_ax_stride + (index.z * grid_dim.x) * grid_dim.y + index.x
//         (presumably for row-contiguous `upd`); otherwise it is computed with
//         elem_to_loc over `shape` / `upd_strides`.
//   IdxC  same as UpdC, for `indices` with idx_ax_stride / `idx_strides`.
//
// Notes:
//   - For signed IdxT, negative indices are wrapped by adding out_axis_size.
//   - Indices are NOT bounds-checked; out-of-range values address memory
//     outside the intended output slice.
//   - `axis` is not read by this kernel; presumably it is kept so the buffer
//     bindings match the host-side encoder (TODO confirm before removing).
//   - `shape` and the stride arrays are indexed with
//     index.z * grid_dim.x + index.x, so they presumably describe the dims
//     with the scatter axis removed (TODO confirm against the host code).
template <
    typename T,
    typename IdxT,
    typename LocT,
    typename Op,
    bool UpdC,
    bool IdxC>
[[kernel]] void scatter_axis(
    const device T* upd [[buffer(0)]],
    const device IdxT* indices [[buffer(1)]],
    device mlx_atomic<T>* out [[buffer(2)]],
    const constant int* shape [[buffer(3)]],
    const constant int64_t* upd_strides [[buffer(4)]],
    const constant int64_t* idx_strides [[buffer(5)]],
    const constant size_t& ndim [[buffer(6)]],
    const constant int& axis [[buffer(7)]],
    const constant int& out_axis_size [[buffer(8)]],
    const constant size_t& upd_ax_stride [[buffer(9)]],
    const constant size_t& idx_ax_stride [[buffer(10)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  Op op;

  // Linear position of this thread's (index.z, index.x) pair within the
  // non-axis dims, before accounting for index.x.
  LocT elem_idx = index.z * static_cast<LocT>(grid_dim.x);

  // Offset into `indices`.
  LocT idx_loc = index.y * static_cast<LocT>(idx_ax_stride);
  if (IdxC) {
    idx_loc += elem_idx * grid_dim.y + index.x;
  } else {
    idx_loc += elem_to_loc<LocT>(elem_idx + index.x, shape, idx_strides, ndim);
  }

  // Widen the index to LocT *before* wrapping negatives. Wrapping inside IdxT
  // overflows for narrow index types: e.g. an int8 index of -1 with
  // out_axis_size == 200 yields 199, which does not fit in int8 and would
  // come back negative, producing an out-of-bounds write. The sign test uses
  // the original IdxT value so the wrap is correct regardless of LocT's
  // signedness.
  auto raw_idx = indices[idx_loc];
  LocT idx_val = static_cast<LocT>(raw_idx);
  if (is_signed_v<IdxT>) {
    idx_val = (raw_idx < 0) ? idx_val + static_cast<LocT>(out_axis_size)
                            : idx_val;
  }

  // Offset into `upd`.
  LocT upd_idx = index.y * static_cast<LocT>(upd_ax_stride);
  if (UpdC) {
    upd_idx += elem_idx * grid_dim.y + index.x;
  } else {
    upd_idx += elem_to_loc<LocT>(elem_idx + index.x, shape, upd_strides, ndim);
  }

  // Destination: row-contiguous [outer, out_axis_size, inner] addressing,
  // with the selected index replacing this thread's axis position.
  LocT out_idx = elem_idx * static_cast<LocT>(out_axis_size) +
      idx_val * static_cast<LocT>(grid_dim.x) + index.x;
  op.atomic_update(out, upd[upd_idx], out_idx);
}
METAL_FUNC IdxT elem_to_loc(IdxT elem, constant const int *shape, constant const int64_t *strides, int ndim)
Definition utils.h:93
void scatter_axis(const device T *upd, const device IdxT *indices, device mlx_atomic< T > *out, const constant int *shape, const constant int64_t *upd_strides, const constant int64_t *idx_strides, const constant size_t &ndim, const constant int &axis, const constant int &out_axis_size, const constant size_t &upd_ax_stride, const constant size_t &idx_ax_stride, uint3 index, uint3 grid_dim)
Definition scatter_axis.h:12
Definition atomic.h:25