MLX
Loading...
Searching...
No Matches
atomic.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include <metal_atomic>
6#include <metal_stdlib>
7
8using namespace metal;
9
11// Atomic utils
13
14#pragma METAL internals : enable
15template <typename T>
16constexpr constant bool is_metal_atomic = _disjunction<
17 is_same<T, int>,
18 is_same<T, uint>,
19 is_same<T, ulong>,
20 is_same<T, float>>::value;
21
22#pragma METAL internals : disable
23
24template <typename T, typename = void>
25struct mlx_atomic {
26 atomic<uint> val;
27};
28
29template <typename T>
30struct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {
31 atomic<T> val;
32};
33
35// Native metal atomics
37
38template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
39METAL_FUNC T
40mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
41 return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);
42}
43
44template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
45METAL_FUNC void
46mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
47 atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);
48}
49
50template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
52 device mlx_atomic<T>* object,
53 T val,
54 uint offset) {
55 atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);
56}
57
58template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
59METAL_FUNC void
60mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
61 atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);
62}
63
64template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
66 device mlx_atomic<T>* object,
67 T val,
68 uint offset) {
69 atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);
70}
71
72template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
74 device mlx_atomic<T>* object,
75 T val,
76 uint offset) {
77 atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);
78}
79
80template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
82 device mlx_atomic<T>* object,
83 T val,
84 uint offset) {
85 atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);
86}
87
88template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
90 device mlx_atomic<T>* object,
91 T val,
92 uint offset) {
93 T expected = mlx_atomic_load_explicit(object, offset);
95 object, &expected, val * expected, offset)) {
96 }
97}
98
99template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
101 device mlx_atomic<T>* object,
102 thread T* expected,
103 T val,
104 uint offset) {
105 return atomic_compare_exchange_weak_explicit(
106 &(object[offset].val),
107 expected,
108 val,
109 memory_order_relaxed,
110 memory_order_relaxed);
111}
112
113// Specialization for float since it does not atomic_fetch_min_explicit
114template <>
116 device mlx_atomic<float>* object,
117 float val,
118 uint offset) {
119 float expected = mlx_atomic_load_explicit(object, offset);
120 while (val < expected) {
122 object, &expected, val, offset)) {
123 return;
124 }
125 }
126}
127
128// Specialization for float since it does not atomic_fetch_max_explicit
129template <>
131 device mlx_atomic<float>* object,
132 float val,
133 uint offset) {
134 float expected = mlx_atomic_load_explicit(object, offset);
135 while (val > expected) {
137 object, &expected, val, offset)) {
138 return;
139 }
140 }
141}
142
144// Custom atomics
146
147namespace {
148
149template <typename T>
150constexpr constant uint packing_size = sizeof(uint) / sizeof(T);
151
152template <typename T>
153union uint_or_packed {
154 T val[packing_size<T>];
155 uint bits;
156};
157
158template <typename T, typename Op>
159struct mlx_atomic_update_helper {
160 uint operator()(uint_or_packed<T> init, T update, uint elem_offset) {
161 Op op;
162 init.val[elem_offset] = op(update, init.val[elem_offset]);
163 return init.bits;
164 }
165};
166
167template <typename T, typename Op>
168METAL_FUNC void mlx_atomic_update_and_store(
169 device mlx_atomic<T>* object,
170 T update,
171 uint offset) {
172 uint pack_offset = offset / packing_size<T>;
173 uint elem_offset = offset % packing_size<T>;
174
175 mlx_atomic_update_helper<T, Op> helper;
176 uint_or_packed<T> expected;
177 expected.bits =
178 atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
179
180 while (Op::condition(update, expected.val[elem_offset]) &&
182 object,
183 &(expected.bits),
184 helper(expected, update, elem_offset),
185 pack_offset)) {
186 }
187}
188
// Op for a plain store: the update is always applied and the result is the
// new value `a`; the second argument is ignored.
template <typename T>
struct __None {
  static bool condition(T /*a*/, T /*b*/) {
    return true;
  }

  T operator()(T a, T /*b*/) {
    return a;
  }
};
202
// Op for addition: the update is always applied and the result is `a + b`.
template <typename T>
struct __Add {
  static bool condition(T /*a*/, T /*b*/) {
    return true;
  }

  T operator()(T a, T b) {
    const T sum = a + b;
    return sum;
  }
};
215
// Op for multiplication: the result is `a * b`. The update is skipped when
// the second argument (the currently stored element, as passed by
// mlx_atomic_update_and_store) is already zero.
template <typename T>
struct __Mul {
  static bool condition(T /*a*/, T b) {
    return b != 0;
  }

  T operator()(T a, T b) {
    const T product = a * b;
    return product;
  }
};
227
// Op for maximum: the update is applied only while `a` is greater than `b`,
// and the result is max(a, b).
template <typename T>
struct __Max {
  static bool condition(T a, T b) {
    const bool update_is_larger = a > b;
    return update_is_larger;
  }

  T operator()(T a, T b) {
    const T larger = max(a, b);
    return larger;
  }
};
238
// Op for minimum: the update is applied only while `a` is less than `b`,
// and the result is min(a, b).
template <typename T>
struct __Min {
  static bool condition(T a, T b) {
    const bool update_is_smaller = a < b;
    return update_is_smaller;
  }

  T operator()(T a, T b) {
    const T smaller = min(a, b);
    return smaller;
  }
};
249
250} // namespace
251
252template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
253METAL_FUNC T
254mlx_atomic_load_explicit(device mlx_atomic<T>* object, uint offset) {
255 uint pack_offset = offset / sizeof(T);
256 uint elem_offset = offset % sizeof(T);
257 uint_or_packed<T> packed_val;
258 packed_val.bits =
259 atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
260 return packed_val.val[elem_offset];
261}
262
263template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
264METAL_FUNC void
265mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, uint offset) {
266 mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
267}
268
269template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
270METAL_FUNC void mlx_atomic_fetch_and_explicit(
271 device mlx_atomic<T>* object,
272 T val,
273 uint offset) {
274 uint pack_offset = offset / packing_size<T>;
275 uint elem_offset = offset % packing_size<T>;
276 uint_or_packed<T> identity;
277 identity.bits = __UINT32_MAX__;
278 identity.val[elem_offset] = val;
279
280 atomic_fetch_and_explicit(
281 &(object[pack_offset].val), identity.bits, memory_order_relaxed);
282}
283
284template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
285METAL_FUNC void
286mlx_atomic_fetch_or_explicit(device mlx_atomic<T>* object, T val, uint offset) {
287 uint pack_offset = offset / packing_size<T>;
288 uint elem_offset = offset % packing_size<T>;
289 uint_or_packed<T> identity;
290 identity.bits = 0;
291 identity.val[elem_offset] = val;
292
293 atomic_fetch_or_explicit(
294 &(object[pack_offset].val), identity.bits, memory_order_relaxed);
295}
296
297template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
298METAL_FUNC void mlx_atomic_fetch_min_explicit(
299 device mlx_atomic<T>* object,
300 T val,
301 uint offset) {
302 mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
303}
304
305template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
306METAL_FUNC void mlx_atomic_fetch_max_explicit(
307 device mlx_atomic<T>* object,
308 T val,
309 uint offset) {
310 mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
311}
312
313template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
314METAL_FUNC void mlx_atomic_fetch_add_explicit(
315 device mlx_atomic<T>* object,
316 T val,
317 uint offset) {
318 mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
319}
320
321template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
322METAL_FUNC void mlx_atomic_fetch_mul_explicit(
323 device mlx_atomic<T>* object,
324 T val,
325 uint offset) {
326 mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
327}
328
329template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
331 device mlx_atomic<T>* object,
332 thread uint* expected,
333 uint val,
334 uint offset) {
335 return atomic_compare_exchange_weak_explicit(
336 &(object[offset].val),
337 expected,
338 val,
339 memory_order_relaxed,
340 memory_order_relaxed);
341}
METAL_FUNC void mlx_atomic_fetch_add_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:81
METAL_FUNC void mlx_atomic_fetch_max_explicit< float >(device mlx_atomic< float > *object, float val, uint offset)
Definition atomic.h:130
METAL_FUNC void mlx_atomic_fetch_and_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:51
METAL_FUNC T mlx_atomic_load_explicit(device mlx_atomic< T > *object, uint offset)
Definition atomic.h:40
METAL_FUNC void mlx_atomic_store_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:46
constexpr constant bool is_metal_atomic
Definition atomic.h:16
METAL_FUNC void mlx_atomic_fetch_or_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:60
METAL_FUNC void mlx_atomic_fetch_min_explicit< float >(device mlx_atomic< float > *object, float val, uint offset)
Definition atomic.h:115
METAL_FUNC void mlx_atomic_fetch_max_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:73
METAL_FUNC void mlx_atomic_fetch_min_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:65
METAL_FUNC void mlx_atomic_fetch_mul_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:89
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic< T > *object, thread T *expected, T val, uint offset)
Definition atomic.h:100
Op op
Definition binary.h:141
array identity(int n, Dtype dtype, StreamOrDevice s={})
Create a square matrix of shape (n,n) of zeros, and ones in the major diagonal.
Definition bf16.h:265
METAL_FUNC bfloat16_t min(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:234
METAL_FUNC bfloat16_t max(bfloat16_t x, bfloat16_t y)
Definition bf16_math.h:234
Group init(bool strict=false)
Initialize the distributed backend and return the group containing all discoverable processes.
array bits(const std::vector< int > &shape, int width, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate an array with type uint32 filled with random bits.
Definition atomic.h:25
atomic< uint > val
Definition atomic.h:26