MLX
 
Loading...
Searching...
No Matches
atomic.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include <metal_atomic>
6#include <metal_stdlib>
7
8using namespace metal;
9
11// Atomic utils
13
14#pragma METAL internals : enable
15template <typename T>
16constexpr constant bool is_metal_atomic = _disjunction<
17 is_same<T, int>,
18 is_same<T, uint>,
19 is_same<T, ulong>,
20 is_same<T, float>>::value;
21
22#pragma METAL internals : disable
23
24template <typename T, typename = void>
25struct mlx_atomic {
26 atomic<uint> val;
27};
28
29template <typename T>
30struct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {
31 atomic<T> val;
32};
33
35// Native metal atomics
37
38template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
39METAL_FUNC T
40mlx_atomic_load_explicit(device mlx_atomic<T>* object, size_t offset) {
41 return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);
42}
43
44template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
45METAL_FUNC void
46mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, size_t offset) {
47 atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);
48}
49
50template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
52 device mlx_atomic<T>* object,
53 T val,
54 size_t offset) {
55 atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);
56}
57
58template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
60 device mlx_atomic<T>* object,
61 T val,
62 size_t offset) {
63 atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);
64}
65
66template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
68 device mlx_atomic<T>* object,
69 T val,
70 size_t offset) {
71 atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);
72}
73
74template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
76 device mlx_atomic<T>* object,
77 T val,
78 size_t offset) {
79 atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);
80}
81
82template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
84 device mlx_atomic<T>* object,
85 T val,
86 size_t offset) {
87 atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);
88}
89
90template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
92 device mlx_atomic<T>* object,
93 T val,
94 size_t offset) {
95 T expected = mlx_atomic_load_explicit(object, offset);
97 object, &expected, val * expected, offset)) {
98 }
99}
100
101template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
103 device mlx_atomic<T>* object,
104 thread T* expected,
105 T val,
106 size_t offset) {
107 return atomic_compare_exchange_weak_explicit(
108 &(object[offset].val),
109 expected,
110 val,
111 memory_order_relaxed,
112 memory_order_relaxed);
113}
114
115// Specialization for float since it does not atomic_fetch_min_explicit
116template <>
118 device mlx_atomic<float>* object,
119 float val,
120 size_t offset) {
121 float expected = mlx_atomic_load_explicit(object, offset);
122 while (val < expected) {
124 object, &expected, val, offset)) {
125 return;
126 }
127 }
128}
129
130// Specialization for float since it does not atomic_fetch_max_explicit
131template <>
133 device mlx_atomic<float>* object,
134 float val,
135 size_t offset) {
136 float expected = mlx_atomic_load_explicit(object, offset);
137 while (val > expected) {
139 object, &expected, val, offset)) {
140 return;
141 }
142 }
143}
144
146// Custom atomics
148
149namespace {
150
151template <typename T>
152constexpr constant uint packing_size = sizeof(uint) / sizeof(T);
153
154template <typename T>
155union uint_or_packed {
156 T val[packing_size<T>];
157 uint bits;
158};
159
160template <typename T, typename Op>
161struct mlx_atomic_update_helper {
162 uint operator()(uint_or_packed<T> init, T update, size_t elem_offset) {
163 Op op;
164 init.val[elem_offset] = op(update, init.val[elem_offset]);
165 return init.bits;
166 }
167};
168
169template <typename T, typename Op>
170METAL_FUNC void mlx_atomic_update_and_store(
171 device mlx_atomic<T>* object,
172 T update,
173 size_t offset) {
174 size_t pack_offset = offset / packing_size<T>;
175 size_t elem_offset = offset % packing_size<T>;
176
177 mlx_atomic_update_helper<T, Op> helper;
178 uint_or_packed<T> expected;
179 expected.bits =
180 atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
181
182 while (Op::condition(update, expected.val[elem_offset]) &&
184 object,
185 &(expected.bits),
186 helper(expected, update, elem_offset),
187 pack_offset)) {
188 }
189}
190
// Op that overwrites: the result is always the update operand, and a write
// is always requested.
template <typename T>
struct __None {
  // Always true: the store is unconditional.
  static bool condition(T /*update*/, T /*current*/) {
    return true;
  }

  // Returns the first operand and ignores the second.
  T operator()(T a, T /*current*/) {
    return a;
  }
};
204
// Op that adds the two operands; a write is always requested.
template <typename T>
struct __Add {
  // Always true: the addition is applied unconditionally.
  static bool condition(T /*update*/, T /*current*/) {
    return true;
  }

  // Sum of the two operands.
  T operator()(T a, T b) {
    const T sum = a + b;
    return sum;
  }
};
217
// Op that multiplies the two operands.
template <typename T>
struct __Mul {
  // A write is requested only while the second argument is non-zero; the
  // first argument is ignored.
  static bool condition(T /*update*/, T b) {
    const bool second_is_nonzero = (b != 0);
    return second_is_nonzero;
  }

  // Product of the two operands.
  T operator()(T a, T b) {
    const T product = a * b;
    return product;
  }
};
229
// Op that keeps the larger of the two operands.
template <typename T>
struct __Max {
  // A write is requested only when the first argument is strictly greater
  // than the second.
  static bool condition(T a, T b) {
    const bool first_is_larger = (a > b);
    return first_is_larger;
  }

  // Larger of the two operands, as computed by max().
  T operator()(T a, T b) {
    const T larger = max(a, b);
    return larger;
  }
};
240
// Op that keeps the smaller of the two operands.
template <typename T>
struct __Min {
  // A write is requested only when the first argument is strictly smaller
  // than the second.
  static bool condition(T a, T b) {
    const bool first_is_smaller = (a < b);
    return first_is_smaller;
  }

  // Smaller of the two operands, as computed by min().
  T operator()(T a, T b) {
    const T smaller = min(a, b);
    return smaller;
  }
};
251
252} // namespace
253
254template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
255METAL_FUNC T
256mlx_atomic_load_explicit(device mlx_atomic<T>* object, size_t offset) {
257 size_t pack_offset = offset / sizeof(T);
258 size_t elem_offset = offset % sizeof(T);
259 uint_or_packed<T> packed_val;
260 packed_val.bits =
261 atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
262 return packed_val.val[elem_offset];
263}
264
265template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
266METAL_FUNC void
267mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, size_t offset) {
268 mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
269}
270
271template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
272METAL_FUNC void mlx_atomic_fetch_and_explicit(
273 device mlx_atomic<T>* object,
274 T val,
275 size_t offset) {
276 size_t pack_offset = offset / packing_size<T>;
277 size_t elem_offset = offset % packing_size<T>;
278 uint_or_packed<T> identity;
279 identity.bits = __UINT32_MAX__;
280 identity.val[elem_offset] = val;
281
282 atomic_fetch_and_explicit(
283 &(object[pack_offset].val), identity.bits, memory_order_relaxed);
284}
285
286template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
287METAL_FUNC void mlx_atomic_fetch_or_explicit(
288 device mlx_atomic<T>* object,
289 T val,
290 size_t offset) {
291 size_t pack_offset = offset / packing_size<T>;
292 size_t elem_offset = offset % packing_size<T>;
293 uint_or_packed<T> identity;
294 identity.bits = 0;
295 identity.val[elem_offset] = val;
296
297 atomic_fetch_or_explicit(
298 &(object[pack_offset].val), identity.bits, memory_order_relaxed);
299}
300
301template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
302METAL_FUNC void mlx_atomic_fetch_min_explicit(
303 device mlx_atomic<T>* object,
304 T val,
305 size_t offset) {
306 mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
307}
308
309template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
310METAL_FUNC void mlx_atomic_fetch_max_explicit(
311 device mlx_atomic<T>* object,
312 T val,
313 size_t offset) {
314 mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
315}
316
317template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
318METAL_FUNC void mlx_atomic_fetch_add_explicit(
319 device mlx_atomic<T>* object,
320 T val,
321 size_t offset) {
322 mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
323}
324
325template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
326METAL_FUNC void mlx_atomic_fetch_mul_explicit(
327 device mlx_atomic<T>* object,
328 T val,
329 size_t offset) {
330 mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
331}
332
333template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
335 device mlx_atomic<T>* object,
336 thread uint* expected,
337 uint val,
338 size_t offset) {
339 return atomic_compare_exchange_weak_explicit(
340 &(object[offset].val),
341 expected,
342 val,
343 memory_order_relaxed,
344 memory_order_relaxed);
345}
METAL_FUNC void mlx_atomic_store_explicit(device mlx_atomic< T > *object, T val, size_t offset)
Definition atomic.h:46
METAL_FUNC void mlx_atomic_fetch_max_explicit< float >(device mlx_atomic< float > *object, float val, size_t offset)
Definition atomic.h:132
METAL_FUNC T mlx_atomic_load_explicit(device mlx_atomic< T > *object, size_t offset)
Definition atomic.h:40
METAL_FUNC void mlx_atomic_fetch_and_explicit(device mlx_atomic< T > *object, T val, size_t offset)
Definition atomic.h:51
METAL_FUNC void mlx_atomic_fetch_min_explicit(device mlx_atomic< T > *object, T val, size_t offset)
Definition atomic.h:67
constexpr constant bool is_metal_atomic
Definition atomic.h:16
METAL_FUNC void mlx_atomic_fetch_add_explicit(device mlx_atomic< T > *object, T val, size_t offset)
Definition atomic.h:83
METAL_FUNC void mlx_atomic_fetch_or_explicit(device mlx_atomic< T > *object, T val, size_t offset)
Definition atomic.h:59
METAL_FUNC void mlx_atomic_fetch_min_explicit< float >(device mlx_atomic< float > *object, float val, size_t offset)
Definition atomic.h:117
METAL_FUNC void mlx_atomic_fetch_max_explicit(device mlx_atomic< T > *object, T val, size_t offset)
Definition atomic.h:75
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic< T > *object, thread T *expected, T val, size_t offset)
Definition atomic.h:102
METAL_FUNC void mlx_atomic_fetch_mul_explicit(device mlx_atomic< T > *object, T val, size_t offset)
Definition atomic.h:91
array identity(int n, Dtype dtype, StreamOrDevice s={})
Create a square matrix of shape (n,n) of zeros, and ones in the major diagonal.
Definition bf16_math.h:226
METAL_FUNC bfloat16_t min(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:232
METAL_FUNC bfloat16_t max(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:232
Group init(bool strict=false, const std::string &bk="any")
Initialize the distributed backend and return the group containing all discoverable processes.
Definition atomic.h:25
atomic< uint > val
Definition atomic.h:26