MLX
Loading...
Searching...
No Matches
atomic.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include <metal_atomic>
6#include <metal_stdlib>
8
9using namespace metal;
10
12// Atomic utils
14
15#pragma METAL internals : enable
16template <typename T>
17constexpr constant bool is_metal_atomic = _disjunction<
18 is_same<T, int>,
19 is_same<T, uint>,
20 is_same<T, ulong>,
21 is_same<T, float>>::value;
22
23#pragma METAL internals : disable
24
25template <typename T, typename = void>
26struct mlx_atomic {
27 atomic<uint> val;
28};
29
30template <typename T>
31struct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {
32 atomic<T> val;
33};
34
36// Native metal atomics
38
39template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
40METAL_FUNC T
41mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
42 return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);
43}
44
45template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
46METAL_FUNC void
47mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
48 atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);
49}
50
51template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
53 device mlx_atomic<T>* object,
54 T val,
55 uint offset) {
56 atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);
57}
58
59template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
60METAL_FUNC void
61mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
62 atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);
63}
64
65template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
67 device mlx_atomic<T>* object,
68 T val,
69 uint offset) {
70 atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);
71}
72
73template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
75 device mlx_atomic<T>* object,
76 T val,
77 uint offset) {
78 atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);
79}
80
81template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
83 device mlx_atomic<T>* object,
84 T val,
85 uint offset) {
86 atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);
87}
88
89template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
91 device mlx_atomic<T>* object,
92 T val,
93 uint offset) {
94 T expected = mlx_atomic_load_explicit(object, offset);
96 object, &expected, val * expected, offset)) {
97 }
98}
99
100template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
102 device mlx_atomic<T>* object,
103 thread T* expected,
104 T val,
105 uint offset) {
106 return atomic_compare_exchange_weak_explicit(
107 &(object[offset].val),
108 expected,
109 val,
110 memory_order_relaxed,
111 memory_order_relaxed);
112}
113
114// Specialization for float since it does not atomic_fetch_min_explicit
115template <>
117 device mlx_atomic<float>* object,
118 float val,
119 uint offset) {
120 float expected = mlx_atomic_load_explicit(object, offset);
121 while (val < expected) {
123 object, &expected, val, offset)) {
124 return;
125 }
126 }
127}
128
129// Specialization for float since it does not atomic_fetch_max_explicit
130template <>
132 device mlx_atomic<float>* object,
133 float val,
134 uint offset) {
135 float expected = mlx_atomic_load_explicit(object, offset);
136 while (val > expected) {
138 object, &expected, val, offset)) {
139 return;
140 }
141 }
142}
143
145// Custom atomics
147
148namespace {
149
150template <typename T>
151constexpr constant uint packing_size = sizeof(uint) / sizeof(T);
152
153template <typename T>
154union uint_or_packed {
155 T val[packing_size<T>];
156 uint bits;
157};
158
159template <typename T, typename Op>
160struct mlx_atomic_update_helper {
161 uint operator()(uint_or_packed<T> init, T update, uint elem_offset) {
162 Op op;
163 init.val[elem_offset] = op(update, init.val[elem_offset]);
164 return init.bits;
165 }
166};
167
168template <typename T, typename Op>
169METAL_FUNC void mlx_atomic_update_and_store(
170 device mlx_atomic<T>* object,
171 T update,
172 uint offset) {
173 uint pack_offset = offset / packing_size<T>;
174 uint elem_offset = offset % packing_size<T>;
175
176 mlx_atomic_update_helper<T, Op> helper;
177 uint_or_packed<T> expected;
178 expected.bits =
179 atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
180
181 while (Op::condition(update, expected.val[elem_offset]) &&
183 object,
184 &(expected.bits),
185 helper(expected, update, elem_offset),
186 pack_offset)) {
187 }
188}
189
// "Store" operation: always applies, and the result is simply the first
// argument (the update value as passed by mlx_atomic_update_helper).
template <typename T>
struct __None {
  // Both arguments are ignored; the update is never skipped.
  static bool condition(T, T) {
    return true;
  }

  // Returns `a`; the second argument is ignored.
  T operator()(T a, T) {
    return a;
  }
};
203
// Addition operation: always applies, result is a + b.
template <typename T>
struct __Add {
  // Both arguments are ignored; the update is never skipped.
  static bool condition(T, T) {
    return true;
  }

  T operator()(T a, T b) {
    T sum = a + b;
    return sum;
  }
};
216
// Multiplication operation: result is a * b.
template <typename T>
struct __Mul {
  // The update is skipped when the second argument (the currently stored
  // element, as passed by mlx_atomic_update_and_store) equals 0. The first
  // argument is ignored.
  static bool condition(T, T b) {
    return b != 0;
  }

  T operator()(T a, T b) {
    T product = a * b;
    return product;
  }
};
228
// Maximum operation: result is max(a, b).
template <typename T>
struct __Max {
  // Only worth applying when the first argument is strictly greater than the
  // second (update vs. currently stored element at the call site).
  static bool condition(T a, T b) {
    const bool first_is_greater = a > b;
    return first_is_greater;
  }

  T operator()(T a, T b) {
    T larger = max(a, b);
    return larger;
  }
};
239
// Minimum operation: result is min(a, b).
template <typename T>
struct __Min {
  // Only worth applying when the first argument is strictly less than the
  // second (update vs. currently stored element at the call site).
  static bool condition(T a, T b) {
    const bool first_is_less = a < b;
    return first_is_less;
  }

  T operator()(T a, T b) {
    T smaller = min(a, b);
    return smaller;
  }
};
250
251} // namespace
252
253template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
254METAL_FUNC T
255mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
256 uint pack_offset = offset / sizeof(T);
257 uint elem_offset = offset % sizeof(T);
258 uint_or_packed<T> packed_val;
259 packed_val.bits =
260 atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
261 return packed_val.val[elem_offset];
262}
263
264template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
265METAL_FUNC void
266mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
267 mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
268}
269
270template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
271METAL_FUNC void mlx_atomic_fetch_and_explicit(
272 device mlx_atomic<T>* object,
273 T val,
274 uint offset) {
275 uint pack_offset = offset / packing_size<T>;
276 uint elem_offset = offset % packing_size<T>;
277 uint_or_packed<T> identity;
278 identity.bits = __UINT32_MAX__;
279 identity.val[elem_offset] = val;
280
281 atomic_fetch_and_explicit(
282 &(object[pack_offset].val), identity.bits, memory_order_relaxed);
283}
284
285template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
286METAL_FUNC void
287mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
288 uint pack_offset = offset / packing_size<T>;
289 uint elem_offset = offset % packing_size<T>;
290 uint_or_packed<T> identity;
291 identity.bits = 0;
292 identity.val[elem_offset] = val;
293
294 atomic_fetch_or_explicit(
295 &(object[pack_offset].val), identity.bits, memory_order_relaxed);
296}
297
298template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
299METAL_FUNC void mlx_atomic_fetch_min_explicit(
300 device mlx_atomic<T>* object,
301 T val,
302 uint offset) {
303 mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
304}
305
306template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
307METAL_FUNC void mlx_atomic_fetch_max_explicit(
308 device mlx_atomic<T>* object,
309 T val,
310 uint offset) {
311 mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
312}
313
314template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
315METAL_FUNC void mlx_atomic_fetch_add_explicit(
316 device mlx_atomic<T>* object,
317 T val,
318 uint offset) {
319 mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
320}
321
322template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
323METAL_FUNC void mlx_atomic_fetch_mul_explicit(
324 device mlx_atomic<T>* object,
325 T val,
326 uint offset) {
327 mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
328}
329
330template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
332 device mlx_atomic<T>* object,
333 thread uint* expected,
334 uint val,
335 uint offset) {
336 return atomic_compare_exchange_weak_explicit(
337 &(object[offset].val),
338 expected,
339 val,
340 memory_order_relaxed,
341 memory_order_relaxed);
342}
METAL_FUNC void mlx_atomic_fetch_add_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:82
METAL_FUNC void mlx_atomic_fetch_max_explicit< float >(device mlx_atomic< float > *object, float val, uint offset)
Definition atomic.h:131
METAL_FUNC void mlx_atomic_fetch_and_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:52
METAL_FUNC T mlx_atomic_load_explicit(device mlx_atomic< T > *object, uint offset)
Definition atomic.h:41
METAL_FUNC void mlx_atomic_store_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:47
constexpr constant bool is_metal_atomic
Definition atomic.h:17
METAL_FUNC void mlx_atomic_fetch_or_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:61
METAL_FUNC void mlx_atomic_fetch_min_explicit< float >(device mlx_atomic< float > *object, float val, uint offset)
Definition atomic.h:116
METAL_FUNC void mlx_atomic_fetch_max_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:74
METAL_FUNC void mlx_atomic_fetch_min_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:66
METAL_FUNC void mlx_atomic_fetch_mul_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:90
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic< T > *object, thread T *expected, T val, uint offset)
Definition atomic.h:101
Op op
Definition binary.h:139
array identity(int n, Dtype dtype, StreamOrDevice s={})
Create a square matrix of shape (n,n) of zeros, and ones in the major diagonal.
Definition bf16.h:265
METAL_FUNC bfloat16_t min(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t max(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:234
array bits(const std::vector< int > &shape, int width, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate an array with type uint32 filled with random bits.
Definition atomic.h:26
atomic< uint > val
Definition atomic.h:27