MLX
 
Loading...
Searching...
No Matches
indexing.h
Go to the documentation of this file.
1// Copyright © 2023-2024 Apple Inc.
2
// Source template for a JIT-generated gather kernel.
// The template is written in Metal Shading Language ([[kernel]], `device` /
// `constant` address spaces, [[buffer(n)]] bindings) and held in a C++ raw
// string literal.
//
// Formatting conventions:
//   - `{N}` are positional replacement fields.
//   - Doubled braces `{{` / `}}` stand for literal `{` / `}`.
//   - NOTE(review): this implies the caller formats the string with a
//     std::format / fmt::format-style formatter before compiling it. The
//     caller is not visible here -- confirm.
//   - Every byte inside R"( ... )" ends up in the generated kernel source.
//
// Replacement fields, as far as this template shows:
//   {0}  Appears only in the kernel name.
//   {1}  Element type of `src` and `out`; first gather_impl template argument.
//   {2}  First template argument of Indices<> and second of gather_impl.
//        Presumably the index element type -- confirm.
//   {3}  Second template argument of Indices<>, third of gather_impl, and
//        part of the kernel name. Presumably the number of index arrays --
//        confirm.
//   {4}  Text spliced into the parameter list between `idx_ndim` and
//        `index`. If non-empty, each declaration must end with a comma.
//        Presumably the per-index-array buffer parameters, bound after
//        buffer(10) -- confirm.
//   {5}  Spliced inside `{ ... }` as the first initializer of `idxs`.
//        Presumably the names of the buffers declared via {4} -- confirm.
//   {6}, {7}  Last two gather_impl template arguments, also part of the
//        kernel name. Their meaning is defined by gather_impl, which is not
//        shown here.
//
// The generated kernel is named gather{0}_{3}_{6}_{7}. Whatever looks the
// function up presumably has to build the same name from the same values.
//
// Buffers 0-10 are fixed bindings. The declared scalar types decide how many
// bytes are read from each binding: size_t for `src_ndim`, int for
// `idx_ndim`.
//
// The per-thread grid position and the total grid size (both uint3) are
// forwarded unchanged to gather_impl.
3constexpr std::string_view gather_kernels = R"(
4[[kernel]] void gather{0}_{3}_{6}_{7}(
5 const device {1}* src [[buffer(0)]],
6 device {1}* out [[buffer(1)]],
7 const constant int* src_shape [[buffer(2)]],
8 const constant int64_t* src_strides [[buffer(3)]],
9 const constant size_t& src_ndim [[buffer(4)]],
10 const constant int* slice_sizes [[buffer(5)]],
11 const constant int* axes [[buffer(6)]],
12 const constant int* idx_shapes [[buffer(7)]],
13 const constant int64_t* idx_strides [[buffer(8)]],
14 const constant bool* idx_contigs [[buffer(9)]],
15 const constant int& idx_ndim [[buffer(10)]],
16 {4}
17 uint3 index [[thread_position_in_grid]],
18 uint3 grid_dim [[threads_per_grid]]) {{
19 Indices<{2}, {3}> idxs{{
20 {{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
21
22 return gather_impl<{1}, {2}, {3}, {6}, {7}>(
23 src,
24 out,
25 src_shape,
26 src_strides,
27 src_ndim,
28 slice_sizes,
29 axes,
30 idxs,
31 index,
32 grid_dim);
33}}
34)";
35
// Source template for a JIT-generated scatter kernel.
// Like the gather template, it is Metal Shading Language inside a C++ raw
// string literal.
//
// Formatting conventions:
//   - `{N}` are positional replacement fields.
//   - `{{` / `}}` are literal braces after formatting.
//   - NOTE(review): the formatter is presumably std::format / fmt::format
//     style; the caller is not visible here -- confirm.
//   - Every byte inside R"( ... )" ends up in the generated kernel source.
//
// Replacement fields, as far as this template shows:
//   {0}  Appears only in the kernel name.
//   {1}  Element type of `updates`. `out` is declared as
//        `mlx_atomic<{1}>*`; mlx_atomic is defined elsewhere, presumably so
//        that concurrent threads can combine into the same output element.
//   {2}  First template argument of Indices<> and second of scatter_impl.
//        Presumably the index element type -- confirm.
//   {3}  Third scatter_impl template argument only; it is not part of the
//        kernel name. Presumably the combine/reduction operation -- confirm
//        against scatter_impl.
//   {4}  Second template argument of Indices<>, fourth of scatter_impl, and
//        part of the kernel name. Presumably the number of index arrays --
//        confirm.
//   {5}  Text spliced into the parameter list between `idx_size` and `gid`.
//        If non-empty, each declaration must end with a comma. Presumably the
//        per-index-array buffers, bound after buffer(15) -- confirm.
//   {6}  First initializer of `idxs`. Presumably the names of the buffers
//        declared via {5} -- confirm.
//   {7}, {8}, {9}  Last three scatter_impl template arguments. They appear
//        in the kernel name as `updc_{7}`, `nwork{8}` and `_{9}`. The labels
//        hint at update contiguity and work per thread, but scatter_impl
//        (not shown) defines the actual meaning.
//
// The generated kernel is named scatter{0}_{4}_updc_{7}_nwork{8}_{9}.
//
// Fixed bindings start at buffer(1); this signature declares nothing at
// buffer(0). NOTE(review): presumably slot 0 is reserved or unused by the
// host-side encoder -- confirm before renumbering.
//
// Unlike gather, the thread position here is a uint2, and no grid-size
// argument is taken.
36constexpr std::string_view scatter_kernels = R"(
37[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}_{9}(
38 const device {1}* updates [[buffer(1)]],
39 device mlx_atomic<{1}>* out [[buffer(2)]],
40 const constant int* upd_shape [[buffer(3)]],
41 const constant int64_t* upd_strides [[buffer(4)]],
42 const constant size_t& upd_ndim [[buffer(5)]],
43 const constant size_t& upd_size [[buffer(6)]],
44 const constant int* out_shape [[buffer(7)]],
45 const constant int64_t* out_strides [[buffer(8)]],
46 const constant size_t& out_ndim [[buffer(9)]],
47 const constant int* axes [[buffer(10)]],
48 const constant int* idx_shapes [[buffer(11)]],
49 const constant int64_t* idx_strides [[buffer(12)]],
50 const constant bool* idx_contigs [[buffer(13)]],
51 const constant int& idx_ndim [[buffer(14)]],
52 const constant size_t& idx_size [[buffer(15)]],
53 {5}
54 uint2 gid [[thread_position_in_grid]]) {{
55 Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
56
57 return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}, {9}>(
58 updates,
59 out,
60 upd_shape,
61 upd_strides,
62 upd_ndim,
63 upd_size,
64 out_shape,
65 out_strides,
66 out_ndim,
67 axes,
68 idx_size,
69 idxs,
70 gid);
71}}
72)";
constexpr std::string_view gather_kernels
Definition indexing.h:3
constexpr std::string_view scatter_kernels
Definition indexing.h:36