4 [[kernel]] void gather{0}_{3}_{6}(
5 const device {1}* src [[buffer(0)]],
6 device {1}* out [[buffer(1)]],
7 const constant int* src_shape [[buffer(2)]],
8 const constant size_t* src_strides [[buffer(3)]],
9 const constant size_t& src_ndim [[buffer(4)]],
10 const constant int* slice_sizes [[buffer(5)]],
11 const constant int* axes [[buffer(6)]],
12 const constant int* idx_shapes [[buffer(7)]],
13 const constant size_t* idx_strides [[buffer(8)]],
14 const constant int& idx_ndim [[buffer(9)]],
16 uint2 index [[thread_position_in_grid]],
17 uint2 grid_dim [[threads_per_grid]]) {{
18 Indices<{2}, {3}> idxs{{
19 {{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
21 return gather_impl<{1}, {2}, {3}, {6}>(
36 [[kernel]] void scatter_1d_index{0}_{4}(
37 const device {1}* updates [[buffer(1)]],
38 device mlx_atomic<{1}>* out [[buffer(2)]],
39 const constant int* out_shape [[buffer(3)]],
40 const constant size_t* out_strides [[buffer(4)]],
41 const constant size_t& upd_size [[buffer(5)]],
43 uint2 gid [[thread_position_in_grid]]) {{
44 const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
45 return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
46 updates, out, out_shape, out_strides, upd_size, idx_buffers, gid);
49 [[kernel]] void scatter{0}_{4}(
50 const device {1}* updates [[buffer(1)]],
51 device mlx_atomic<{1}>* out [[buffer(2)]],
52 const constant int* upd_shape [[buffer(3)]],
53 const constant size_t* upd_strides [[buffer(4)]],
54 const constant size_t& upd_ndim [[buffer(5)]],
55 const constant size_t& upd_size [[buffer(6)]],
56 const constant int* out_shape [[buffer(7)]],
57 const constant size_t* out_strides [[buffer(8)]],
58 const constant size_t& out_ndim [[buffer(9)]],
59 const constant int* axes [[buffer(10)]],
60 const constant int* idx_shapes [[buffer(11)]],
61 const constant size_t* idx_strides [[buffer(12)]],
62 const constant int& idx_ndim [[buffer(13)]],
64 uint2 gid [[thread_position_in_grid]]) {{
65 Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};
67 return scatter_impl<{1}, {2}, {3}, {4}>(
constexpr std::string_view gather_kernels
Definition indexing.h:3
constexpr std::string_view scatter_kernels
Definition indexing.h:35