MLX
Loading...
Searching...
No Matches
binary_two.h
Go to the documentation of this file.
1// Copyright © 2024 Apple Inc.
2
// Source-text template holding explicit template instantiation declarations
// for the binary kernels that take two const input buffers (`a`, `b`) and two
// non-const buffers (`c`, `d`, presumably the outputs). The syntax used in the
// text ([[kernel]], [[host_name(...)]], `device` / `constant` qualifiers,
// [[thread_position_in_grid]]) looks like Metal Shading Language.
//
// The text is not usable as-is: it contains positional placeholders that must
// be substituted first (presumably through a format()-style call -- confirm
// against the call site):
//   {0}  suffix appended to every host-visible kernel name
//   {1}  element type of the const buffers `a` and `b`
//   {2}  element type of the non-const buffers `c` and `d`
//   {3}  third template argument (presumably the op functor -- TODO confirm)
//
// Host-name prefixes and the templates they instantiate:
//   ss, vs, sv, vv -> binary_ss / binary_vs / binary_sv / binary_vv
//                     (single `uint` thread index, no shape or strides)
//   g1, g2, g3     -> binary_g_nd1 / binary_g_nd2 / binary_g_nd3
//                     (stride arguments only, no shape argument)
//   g4, g5         -> binary_g_nd<..., 4> / binary_g_nd<..., 5>
//                     (fixed-size shape and stride arrays)
//   gn             -> binary_g (shape/stride pointers plus a runtime `ndim`)
//
// NOTE(review): everything between R"( and )" is part of the string value and
// is emitted verbatim, so comments must not be added inside the literal.
3constexpr std::string_view binary_two_kernels = R"(
4template [[host_name("ss{0}")]] [[kernel]]
5void binary_ss<{1}, {2}, {3}>(
6 device const {1}* a,
7 device const {1}* b,
8 device {2}* c,
9 device {2}* d,
10 uint index [[thread_position_in_grid]]);
11template [[host_name("vs{0}")]] [[kernel]]
12void binary_vs<{1}, {2}, {3}>(
13 device const {1}* a,
14 device const {1}* b,
15 device {2}* c,
16 device {2}* d,
17 uint index [[thread_position_in_grid]]);
18template [[host_name("sv{0}")]] [[kernel]]
19void binary_sv<{1}, {2}, {3}>(
20 device const {1}* a,
21 device const {1}* b,
22 device {2}* c,
23 device {2}* d,
24 uint index [[thread_position_in_grid]]);
25template [[host_name("vv{0}")]] [[kernel]]
26void binary_vv<{1}, {2}, {3}>(
27 device const {1}* a,
28 device const {1}* b,
29 device {2}* c,
30 device {2}* d,
31 uint index [[thread_position_in_grid]]);
32
33template [[host_name("g4{0}")]] [[kernel]] void
34binary_g_nd<{1}, {2}, {3}, 4>(
35 device const {1}* a,
36 device const {1}* b,
37 device {2}* c,
38 device {2}* d,
39 constant const int shape[4],
40 constant const size_t a_strides[4],
41 constant const size_t b_strides[4],
42 uint3 index [[thread_position_in_grid]],
43 uint3 grid_dim [[threads_per_grid]]);
44template [[host_name("g5{0}")]] [[kernel]] void
45binary_g_nd<{1}, {2}, {3}, 5>(
46 device const {1}* a,
47 device const {1}* b,
48 device {2}* c,
49 device {2}* d,
50 constant const int shape[5],
51 constant const size_t a_strides[5],
52 constant const size_t b_strides[5],
53 uint3 index [[thread_position_in_grid]],
54 uint3 grid_dim [[threads_per_grid]]);
55
56template [[host_name("g1{0}")]] [[kernel]] void
57binary_g_nd1<{1}, {2}, {3}>(
58 device const {1}* a,
59 device const {1}* b,
60 device {2}* c,
61 device {2}* d,
62 constant const size_t& a_stride,
63 constant const size_t& b_stride,
64 uint index [[thread_position_in_grid]]);
65template [[host_name("g2{0}")]] [[kernel]] void
66binary_g_nd2<{1}, {2}, {3}>(
67 device const {1}* a,
68 device const {1}* b,
69 device {2}* c,
70 device {2}* d,
71 constant const size_t a_strides[2],
72 constant const size_t b_strides[2],
73 uint2 index [[thread_position_in_grid]],
74 uint2 grid_dim [[threads_per_grid]]);
75template [[host_name("g3{0}")]] [[kernel]] void
76binary_g_nd3<{1}, {2}, {3}>(
77 device const {1}* a,
78 device const {1}* b,
79 device {2}* c,
80 device {2}* d,
81 constant const size_t a_strides[3],
82 constant const size_t b_strides[3],
83 uint3 index [[thread_position_in_grid]],
84 uint3 grid_dim [[threads_per_grid]]);
85
86template [[host_name("gn{0}")]] [[kernel]]
87void binary_g<{1}, {2}, {3}>(
88 device const {1}* a,
89 device const {1}* b,
90 device {2}* c,
91 device {2}* d,
92 constant const int* shape,
93 constant const size_t* a_strides,
94 constant const size_t* b_strides,
95 constant const int& ndim,
96 uint3 index [[thread_position_in_grid]],
97 uint3 grid_dim [[threads_per_grid]]);
98)";
constexpr std::string_view binary_two_kernels
Definition binary_two.h:3