MLX
Loading...
Searching...
No Matches
indexing.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
// Metal source template for the JIT-compiled gather kernel.
//
// The string is a positional format template: each `{N}` marker is replaced
// by the caller, and `{{` / `}}` are escaped literal braces that come out as
// single braces in the emitted Metal code.
// NOTE(review): presumably expanded with fmt::format in the Metal JIT path —
// confirm against the caller.
//
// Placeholder roles, as fixed by where each appears in the template:
//   {0} - kernel-name suffix.
//   {1} - element type of `src` and `out`; first template argument of
//         gather_impl.
//   {2} - first template argument of `Indices` and second of gather_impl
//         (presumably the index element type — confirm).
//   {3} - second template argument of `Indices` and third of gather_impl;
//         also part of the kernel name (presumably the number of index
//         arrays — confirm).
//   {4} - text spliced into the parameter list just before the
//         thread-position arguments. Because the next line begins a new
//         parameter, the substitution must be empty or end with a comma.
//         Presumably these are the per-index-array buffer parameters.
//   {5} - contents of the brace list that initializes the first member of
//         `idxs` (presumably the names of the {4} index buffers — confirm).
//   {6} - fourth template argument of gather_impl; also part of the kernel
//         name.
//
// The generated kernel only assembles `idxs` and forwards every argument to
// gather_impl. Neither `Indices` nor gather_impl is defined here; both must
// be available in the Metal source this template is compiled with.
constexpr std::string_view gather_kernels = R"(
[[kernel]] void gather{0}_{3}_{6}(
    const device {1}* src [[buffer(0)]],
    device {1}* out [[buffer(1)]],
    const constant int* src_shape [[buffer(2)]],
    const constant size_t* src_strides [[buffer(3)]],
    const constant size_t& src_ndim [[buffer(4)]],
    const constant int* slice_sizes [[buffer(5)]],
    const constant int* axes [[buffer(6)]],
    const constant int* idx_shapes [[buffer(7)]],
    const constant size_t* idx_strides [[buffer(8)]],
    const constant int& idx_ndim [[buffer(9)]],
    {4}
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {{
  Indices<{2}, {3}> idxs{{
    {{ {5} }}, idx_shapes, idx_strides, idx_ndim}};

  return gather_impl<{1}, {2}, {3}, {6}>(
      src,
      out,
      src_shape,
      src_strides,
      src_ndim,
      slice_sizes,
      axes,
      idxs,
      index,
      grid_dim);
}}
)";
34
// Metal source templates for the JIT-compiled scatter kernels.
//
// Like gather_kernels, this is a positional format template: each `{N}` is
// replaced by the caller, and `{{` / `}}` are escaped literal braces.
// NOTE(review): presumably expanded with fmt::format in the Metal JIT path —
// confirm against the caller.
//
// Both kernels live in one string and share the same placeholder numbering,
// so a single expansion emits both:
//   scatter_1d_index{0}_{4} - receives no index shape/stride metadata, only
//       the raw index buffer pointers. It packs them into a fixed-size
//       `array` and calls scatter_1d_index_impl.
//   scatter{0}_{4}          - receives index shapes/strides/ndim and `axes`.
//       It builds an `Indices` object and calls scatter_impl.
//
// Placeholder roles, as fixed by where each appears in the template:
//   {0} - kernel-name suffix.
//   {1} - element type of `updates`; `out` is `mlx_atomic<{1}>`.
//   {2} - pointee type of each index buffer (`const device {2}*`).
//   {3} - third template argument of both impls (presumably the scatter
//         reduction op — confirm).
//   {4} - number of index buffers (the `array` length, and the second
//         template argument of `Indices`); also part of both kernel names.
//   {5} - text spliced into each parameter list just before `gid`. It must
//         be empty or end with a comma.
//   {6} - comma-separated index buffer expressions. They initialize both
//         `idx_buffers` and the first member of `idxs`.
//
// NOTE(review): the fixed parameters of both kernels start at [[buffer(1)]],
// leaving slot 0 unbound here; presumably a {5} parameter occupies it —
// confirm.
constexpr std::string_view scatter_kernels = R"(
[[kernel]] void scatter_1d_index{0}_{4}(
    const device {1}* updates [[buffer(1)]],
    device mlx_atomic<{1}>* out [[buffer(2)]],
    const constant int* out_shape [[buffer(3)]],
    const constant size_t* out_strides [[buffer(4)]],
    const constant size_t& out_ndim [[buffer(5)]],
    const constant int* upd_shape [[buffer(6)]],
    const constant size_t& upd_ndim [[buffer(7)]],
    const constant size_t& upd_size [[buffer(8)]],
    {5}
    uint2 gid [[thread_position_in_grid]]) {{
  const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
  return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
      updates,
      out,
      out_shape,
      out_strides,
      out_ndim,
      upd_shape,
      upd_ndim,
      upd_size,
      idx_buffers,
      gid);
}}

[[kernel]] void scatter{0}_{4}(
    const device {1}* updates [[buffer(1)]],
    device mlx_atomic<{1}>* out [[buffer(2)]],
    const constant int* upd_shape [[buffer(3)]],
    const constant size_t* upd_strides [[buffer(4)]],
    const constant size_t& upd_ndim [[buffer(5)]],
    const constant size_t& upd_size [[buffer(6)]],
    const constant int* out_shape [[buffer(7)]],
    const constant size_t* out_strides [[buffer(8)]],
    const constant size_t& out_ndim [[buffer(9)]],
    const constant int* axes [[buffer(10)]],
    const constant int* idx_shapes [[buffer(11)]],
    const constant size_t* idx_strides [[buffer(12)]],
    const constant int& idx_ndim [[buffer(13)]],
    {5}
    uint2 gid [[thread_position_in_grid]]) {{
  Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};

  return scatter_impl<{1}, {2}, {3}, {4}>(
      updates,
      out,
      upd_shape,
      upd_strides,
      upd_ndim,
      upd_size,
      out_shape,
      out_strides,
      out_ndim,
      axes,
      idxs,
      gid);
}}
)";
constexpr std::string_view gather_kernels
Definition indexing.h:3
constexpr std::string_view scatter_kernels
Definition indexing.h:35