4[[kernel]] void gather{0}_{3}_{6}(
5 const device {1}* src [[buffer(0)]],
6 device {1}* out [[buffer(1)]],
7 const constant int* src_shape [[buffer(2)]],
8 const constant size_t* src_strides [[buffer(3)]],
9 const constant size_t& src_ndim [[buffer(4)]],
10 const constant int* slice_sizes [[buffer(5)]],
11 const constant int* axes [[buffer(6)]],
12 const constant int* idx_shapes [[buffer(7)]],
13 const constant size_t* idx_strides [[buffer(8)]],
14 const constant int& idx_ndim [[buffer(9)]],
16 uint2 index [[thread_position_in_grid]],
17 uint2 grid_dim [[threads_per_grid]]) {{
18 Indices<{2}, {3}> idxs{{
19 {{ {5} }}, idx_shapes, idx_strides, idx_ndim}};
21 return gather_impl<{1}, {2}, {3}, {6}>(
36[[kernel]] void scatter_1d_index{0}_{4}(
37 const device {1}* updates [[buffer(1)]],
38 device mlx_atomic<{1}>* out [[buffer(2)]],
39 const constant int* out_shape [[buffer(3)]],
40 const constant size_t* out_strides [[buffer(4)]],
41 const constant size_t& out_ndim [[buffer(5)]],
42 const constant int* upd_shape [[buffer(6)]],
43 const constant size_t& upd_ndim [[buffer(7)]],
44 const constant size_t& upd_size [[buffer(8)]],
46 uint2 gid [[thread_position_in_grid]]) {{
47 const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
48 return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
61[[kernel]] void scatter{0}_{4}(
62 const device {1}* updates [[buffer(1)]],
63 device mlx_atomic<{1}>* out [[buffer(2)]],
64 const constant int* upd_shape [[buffer(3)]],
65 const constant size_t* upd_strides [[buffer(4)]],
66 const constant size_t& upd_ndim [[buffer(5)]],
67 const constant size_t& upd_size [[buffer(6)]],
68 const constant int* out_shape [[buffer(7)]],
69 const constant size_t* out_strides [[buffer(8)]],
70 const constant size_t& out_ndim [[buffer(9)]],
71 const constant int* axes [[buffer(10)]],
72 const constant int* idx_shapes [[buffer(11)]],
73 const constant size_t* idx_strides [[buffer(12)]],
74 const constant int& idx_ndim [[buffer(13)]],
76 uint2 gid [[thread_position_in_grid]]) {{
77 Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};
79 return scatter_impl<{1}, {2}, {3}, {4}>(
constexpr std::string_view gather_kernels
Definition indexing.h:3
constexpr std::string_view scatter_kernels
Definition indexing.h:35