MLX
Loading...
Searching...
No Matches
indexing.h
Go to the documentation of this file.
// Copyright © 2023-2024 Apple Inc.

// Metal source template for a JIT-compiled gather kernel.
//
// This is a format string. Single-brace `{N}` fields are positional
// placeholders, and doubled braces `{{` / `}}` are escapes that become literal
// braces in the emitted Metal code. The formatter is presumably fmt /
// std::format at the call site; confirm against the caller.
//
// How each placeholder is used in the template below:
//   {0} - part of the kernel name: gather{0}_{3}_{6}.
//   {1} - element type of `src` and `out`; 1st template argument of gather_impl.
//   {2} - 1st template argument of Indices and 2nd of gather_impl
//         (presumably the index element type; confirm).
//   {3} - 2nd template argument of Indices and 3rd of gather_impl; also part of
//         the kernel name (presumably the number of index arrays; confirm).
//   {4} - text spliced into the parameter list just before `index`. It must be
//         empty or end with a comma to keep the list well-formed. Buffer slots
//         0-9 are already bound by the fixed parameters.
//   {5} - contents of the brace-initializer for the first Indices member
//         (presumably the names of the index buffers declared via {4}; confirm).
//   {6} - 4th template argument of gather_impl; also part of the kernel name.
//
// The raw string must not contain stray characters: everything between `R"(`
// and `)"` is emitted verbatim into the Metal source.
constexpr std::string_view gather_kernels = R"(
[[kernel]] void gather{0}_{3}_{6}(
    const device {1}* src [[buffer(0)]],
    device {1}* out [[buffer(1)]],
    const constant int* src_shape [[buffer(2)]],
    const constant size_t* src_strides [[buffer(3)]],
    const constant size_t& src_ndim [[buffer(4)]],
    const constant int* slice_sizes [[buffer(5)]],
    const constant int* axes [[buffer(6)]],
    const constant int* idx_shapes [[buffer(7)]],
    const constant size_t* idx_strides [[buffer(8)]],
    const constant int& idx_ndim [[buffer(9)]],
    {4}
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {{
  Indices<{2}, {3}> idxs{{
    {{ {5} }}, idx_shapes, idx_strides, idx_ndim}};

  return gather_impl<{1}, {2}, {3}, {6}>(
      src,
      out,
      src_shape,
      src_strides,
      src_ndim,
      slice_sizes,
      axes,
      idxs,
      index,
      grid_dim);
}}
)";

// Metal source template for the JIT-compiled scatter kernels.
//
// One format string defines two kernels: `scatter_1d_index{0}_{4}` and the
// general `scatter{0}_{4}`. A single formatting pass therefore fills both with
// the same arguments. Single-brace `{N}` fields are positional placeholders,
// and doubled braces `{{` / `}}` are escapes that become literal braces. The
// formatter is presumably fmt / std::format at the call site; confirm.
//
// How each placeholder is used in the template below:
//   {0} - part of both kernel names.
//   {1} - element type of `updates`. `out` is accessed as mlx_atomic<{1}>.
//         It is the 1st template argument of both *_impl calls.
//   {2} - element type of the index buffers (`const device {2}*`) and the 1st
//         template argument of Indices.
//   {3} - 3rd template argument of both *_impl calls
//         (presumably the update/reduction op; confirm).
//   {4} - length of `idx_buffers`, i.e. the number of index buffers. It is
//         also the 2nd template argument of Indices and part of both names.
//   {5} - text spliced into each parameter list just before `gid`. It must be
//         empty or end with a comma. The fixed parameters bind slots 1-5
//         (1d kernel) and 1-13 (general kernel); slot 0 is not bound by any
//         fixed parameter.
//   {6} - contents of the brace-initializer for `idx_buffers` and for the
//         first Indices member (presumably the index buffer names; confirm).
//
// The raw string must not contain stray characters: everything between `R"(`
// and `)"` is emitted verbatim into the Metal source.
constexpr std::string_view scatter_kernels = R"(
[[kernel]] void scatter_1d_index{0}_{4}(
    const device {1}* updates [[buffer(1)]],
    device mlx_atomic<{1}>* out [[buffer(2)]],
    const constant int* out_shape [[buffer(3)]],
    const constant size_t* out_strides [[buffer(4)]],
    const constant size_t& upd_size [[buffer(5)]],
    {5}
    uint2 gid [[thread_position_in_grid]]) {{
  const array<const device {2}*, {4}> idx_buffers = {{ {6} }};
  return scatter_1d_index_impl<{1}, {2}, {3}, {4}>(
      updates, out, out_shape, out_strides, upd_size, idx_buffers, gid);
}}

[[kernel]] void scatter{0}_{4}(
    const device {1}* updates [[buffer(1)]],
    device mlx_atomic<{1}>* out [[buffer(2)]],
    const constant int* upd_shape [[buffer(3)]],
    const constant size_t* upd_strides [[buffer(4)]],
    const constant size_t& upd_ndim [[buffer(5)]],
    const constant size_t& upd_size [[buffer(6)]],
    const constant int* out_shape [[buffer(7)]],
    const constant size_t* out_strides [[buffer(8)]],
    const constant size_t& out_ndim [[buffer(9)]],
    const constant int* axes [[buffer(10)]],
    const constant int* idx_shapes [[buffer(11)]],
    const constant size_t* idx_strides [[buffer(12)]],
    const constant int& idx_ndim [[buffer(13)]],
    {5}
    uint2 gid [[thread_position_in_grid]]) {{
  Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_ndim}};

  return scatter_impl<{1}, {2}, {3}, {4}>(
      updates,
      out,
      upd_shape,
      upd_strides,
      upd_ndim,
      upd_size,
      out_shape,
      out_strides,
      out_ndim,
      axes,
      idxs,
      gid);
}}
)";
constexpr std::string_view gather_kernels
Definition indexing.h:3
constexpr std::string_view scatter_kernels
Definition indexing.h:35