5 device {1}* out [[buffer(0)]],
6 uint tid [[thread_position_in_grid]]) {{
7 out[tid] = {2}<{1}>::init;
12template [[host_name("all_{0}")]] [[kernel]] void
13all_reduce<{1}, {2}, {3}<{2}>>(
14 const device {1}* in [[buffer(0)]],
15 device mlx_atomic<{2}>* out [[buffer(1)]],
16 const device size_t& in_size [[buffer(2)]],
17 uint gid [[thread_position_in_grid]],
18 uint lid [[thread_position_in_threadgroup]],
19 uint grid_size [[threads_per_grid]],
20 uint simd_per_group [[simdgroups_per_threadgroup]],
21 uint simd_lane_id [[thread_index_in_simdgroup]],
22 uint simd_group_id [[simdgroup_index_in_threadgroup]]);
23template [[host_name("colGeneral_{0}")]] [[kernel]] void
24col_reduce_general<{1}, {2}, {3}<{2}>>(
25 const device {1}* in [[buffer(0)]],
26 device mlx_atomic<{2}>* out [[buffer(1)]],
27 const constant size_t& reduction_size [[buffer(2)]],
28 const constant size_t& reduction_stride [[buffer(3)]],
29 const constant size_t& out_size [[buffer(4)]],
30 const constant int* shape [[buffer(5)]],
31 const constant size_t* strides [[buffer(6)]],
32 const constant int& ndim [[buffer(7)]],
33 threadgroup {2}* local_data [[threadgroup(0)]],
34 uint3 tid [[threadgroup_position_in_grid]],
35 uint3 lid [[thread_position_in_threadgroup]],
36 uint3 lsize [[threads_per_threadgroup]]);
37template [[host_name("colSmall_{0}")]] [[kernel]] void
38col_reduce_small<{1}, {2}, {3}<{2}>>(
39 const device {1}* in [[buffer(0)]],
40 device {2}* out [[buffer(1)]],
41 const constant size_t& reduction_size [[buffer(2)]],
42 const constant size_t& reduction_stride [[buffer(3)]],
43 const constant size_t& out_size [[buffer(4)]],
44 const constant int* shape [[buffer(5)]],
45 const constant size_t* strides [[buffer(6)]],
46 const constant int& ndim [[buffer(7)]],
47 const constant size_t& non_col_reductions [[buffer(8)]],
48 const constant int* non_col_shapes [[buffer(9)]],
49 const constant size_t* non_col_strides [[buffer(10)]],
50 const constant int& non_col_ndim [[buffer(11)]],
51 uint tid [[thread_position_in_grid]]);
52template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
53row_reduce_general_small<{1}, {2}, {3}<{2}>>(
54 const device {1}* in [[buffer(0)]],
55 device {2}* out [[buffer(1)]],
56 const constant size_t& reduction_size [[buffer(2)]],
57 const constant size_t& out_size [[buffer(3)]],
58 const constant size_t& non_row_reductions [[buffer(4)]],
59 const constant int* shape [[buffer(5)]],
60 const constant size_t* strides [[buffer(6)]],
61 const constant int& ndim [[buffer(7)]],
62 uint lid [[thread_position_in_grid]]);
63template [[host_name("rowGeneralMed_{0}")]] [[kernel]] void
64row_reduce_general_med<{1}, {2}, {3}<{2}>>(
65 const device {1}* in [[buffer(0)]],
66 device {2}* out [[buffer(1)]],
67 const constant size_t& reduction_size [[buffer(2)]],
68 const constant size_t& out_size [[buffer(3)]],
69 const constant size_t& non_row_reductions [[buffer(4)]],
70 const constant int* shape [[buffer(5)]],
71 const constant size_t* strides [[buffer(6)]],
72 const constant int& ndim [[buffer(7)]],
73 uint tid [[threadgroup_position_in_grid]],
74 uint simd_lane_id [[thread_index_in_simdgroup]],
75 uint simd_per_group [[dispatch_simdgroups_per_threadgroup]],
76 uint simd_group_id [[simdgroup_index_in_threadgroup]]);
77template [[host_name("rowGeneral_{0}")]] [[kernel]] void
78row_reduce_general<{1}, {2}, {3}<{2}>>(
79 const device {1}* in [[buffer(0)]],
80 device mlx_atomic<{2}>* out [[buffer(1)]],
81 const constant size_t& reduction_size [[buffer(2)]],
82 const constant size_t& out_size [[buffer(3)]],
83 const constant size_t& non_row_reductions [[buffer(4)]],
84 const constant int* shape [[buffer(5)]],
85 const constant size_t* strides [[buffer(6)]],
86 const constant int& ndim [[buffer(7)]],
87 uint3 lid [[thread_position_in_threadgroup]],
88 uint3 lsize [[threads_per_threadgroup]],
89 uint3 tid [[threadgroup_position_in_grid]],
90 uint simd_lane_id [[thread_index_in_simdgroup]],
91 uint simd_per_group [[simdgroups_per_threadgroup]],
92 uint simd_group_id [[simdgroup_index_in_threadgroup]]);
96template [[host_name("allNoAtomics_{0}")]] [[kernel]] void
97all_reduce_no_atomics<{1}, {2}, {3}<{2}>>(
98 const device {1}* in [[buffer(0)]],
99 device {2}* out [[buffer(1)]],
100 const device size_t& in_size [[buffer(2)]],
101 uint gid [[thread_position_in_grid]],
102 uint lid [[thread_position_in_threadgroup]],
103 uint grid_size [[threads_per_grid]],
104 uint simd_per_group [[simdgroups_per_threadgroup]],
105 uint simd_lane_id [[thread_index_in_simdgroup]],
106 uint simd_group_id [[simdgroup_index_in_threadgroup]],
107 uint thread_group_id [[threadgroup_position_in_grid]]);
109template [[host_name("colGeneralNoAtomics_{0}")]] [[kernel]] void
110 col_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
111 const device {1}* in [[buffer(0)]],
112 device {2}* out [[buffer(1)]],
113 const constant size_t& reduction_size [[buffer(2)]],
114 const constant size_t& reduction_stride [[buffer(3)]],
115 const constant size_t& out_size [[buffer(4)]],
116 const constant int* shape [[buffer(5)]],
117 const constant size_t* strides [[buffer(6)]],
118 const constant int& ndim [[buffer(7)]],
119 threadgroup {2}* local_data [[threadgroup(0)]],
120 uint3 tid [[threadgroup_position_in_grid]],
121 uint3 lid [[thread_position_in_threadgroup]],
122 uint3 gid [[thread_position_in_grid]],
123 uint3 lsize [[threads_per_threadgroup]],
124 uint3 gsize [[threads_per_grid]]);
125template [[host_name("colSmall_{0}")]] [[kernel]] void
126col_reduce_small<{1}, {2}, {3}<{2}>>(
127 const device {1}* in [[buffer(0)]],
128 device {2}* out [[buffer(1)]],
129 const constant size_t& reduction_size [[buffer(2)]],
130 const constant size_t& reduction_stride [[buffer(3)]],
131 const constant size_t& out_size [[buffer(4)]],
132 const constant int* shape [[buffer(5)]],
133 const constant size_t* strides [[buffer(6)]],
134 const constant int& ndim [[buffer(7)]],
135 const constant size_t& non_col_reductions [[buffer(8)]],
136 const constant int* non_col_shapes [[buffer(9)]],
137 const constant size_t* non_col_strides [[buffer(10)]],
138 const constant int& non_col_ndim [[buffer(11)]],
139 uint tid [[thread_position_in_grid]]);
140template [[host_name("rowGeneralSmall_{0}")]] [[kernel]] void
141row_reduce_general_small<{1}, {2}, {3}<{2}>>(
142 const device {1}* in [[buffer(0)]],
143 device {2}* out [[buffer(1)]],
144 const constant size_t& reduction_size [[buffer(2)]],
145 const constant size_t& out_size [[buffer(3)]],
146 const constant size_t& non_row_reductions [[buffer(4)]],
147 const constant int* shape [[buffer(5)]],
148 const constant size_t* strides [[buffer(6)]],
149 const constant int& ndim [[buffer(7)]],
150 uint lid [[thread_position_in_grid]]);
151template [[host_name("rowGeneralNoAtomics_{0}")]] [[kernel]] void
152row_reduce_general_no_atomics<{1}, {2}, {3}<{2}>>(
153 const device {1}* in [[buffer(0)]],
154 device {2}* out [[buffer(1)]],
155 const constant size_t& reduction_size [[buffer(2)]],
156 const constant size_t& out_size [[buffer(3)]],
157 const constant size_t& non_row_reductions [[buffer(4)]],
158 const constant int* shape [[buffer(5)]],
159 const constant size_t* strides [[buffer(6)]],
160 const constant int& ndim [[buffer(7)]],
161 uint3 lid [[thread_position_in_threadgroup]],
162 uint3 lsize [[threads_per_threadgroup]],
163 uint3 gsize [[threads_per_grid]],
164 uint3 tid [[threadgroup_position_in_grid]],
165 uint simd_lane_id [[thread_index_in_simdgroup]],
166 uint simd_per_group [[simdgroups_per_threadgroup]],
167 uint simd_group_id [[simdgroup_index_in_threadgroup]]);