MLX
Loading...
Searching...
No Matches
binary.h
Go to the documentation of this file.
// Copyright © 2024 Apple Inc.

#pragma once

#include <string_view>

// Metal source template for the JIT-compiled binary kernels.
//
// Each entry below is an explicit instantiation of a kernel template
// (binary_ss, binary_vs, binary_sv, binary_vv, binary_g_nd, binary_g_nd1,
// binary_g_nd2, binary_g_nd3, binary_g). Those templates are not defined in
// this file. Each instantiation is exported under the name given in
// [[host_name(...)]]:
//   ss / vs / sv / vv -> binary_ss / binary_vs / binary_sv / binary_vv
//                        (presumably scalar/vector operand combinations --
//                        TODO confirm against the kernel definitions)
//   g1 / g2 / g3      -> binary_g_nd1 / binary_g_nd2 / binary_g_nd3
//   g4 / g5           -> binary_g_nd with a compile-time rank of 4 / 5
//   gn                -> binary_g, which takes the rank at runtime (`ndim`)
//
// Positional placeholders, substituted by the caller before compilation:
//   {0}  suffix appended to every host name (e.g. "ss{0}", "g4{0}")
//   {1}  element type of the inputs `a` and `b`
//   {2}  element type of the output `c`
//   {3}  third template argument of every kernel (presumably the binary
//        operation functor -- TODO confirm against the kernel definitions)
//
// NOTE(review): the `{N}` syntax suggests the caller fills this template with
// std::format / fmt::format, which is not visible here. If so, any literal
// brace added to this text must be written as `{{` / `}}`.
//
// `inline` (C++17) gives the variable a single definition shared by every
// translation unit that includes this header. A plain namespace-scope
// `constexpr` variable has internal linkage, so each TU would get its own copy.
inline constexpr std::string_view binary_kernels = R"(
template [[host_name("ss{0}")]] [[kernel]]
void binary_ss<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    uint index [[thread_position_in_grid]]);
template [[host_name("vs{0}")]] [[kernel]]
void binary_vs<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    uint index [[thread_position_in_grid]]);
template [[host_name("sv{0}")]] [[kernel]]
void binary_sv<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    uint index [[thread_position_in_grid]]);
template [[host_name("vv{0}")]] [[kernel]]
void binary_vv<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    uint index [[thread_position_in_grid]]);
template [[host_name("g4{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 4>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    constant const int shape[4],
    constant const size_t a_strides[4],
    constant const size_t b_strides[4],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
template [[host_name("g5{0}")]] [[kernel]] void
binary_g_nd<{1}, {2}, {3}, 5>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    constant const int shape[5],
    constant const size_t a_strides[5],
    constant const size_t b_strides[5],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);

template [[host_name("g1{0}")]] [[kernel]] void
binary_g_nd1<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    constant const size_t& a_stride,
    constant const size_t& b_stride,
    uint index [[thread_position_in_grid]]);
template [[host_name("g2{0}")]] [[kernel]] void
binary_g_nd2<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]);
template [[host_name("g3{0}")]] [[kernel]] void
binary_g_nd3<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);

template [[host_name("gn{0}")]] [[kernel]]
void binary_g<{1}, {2}, {3}>(
    device const {1}* a,
    device const {1}* b,
    device {2}* c,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]);
)";
constexpr std::string_view binary_kernels
Definition binary.h:3