15#pragma METAL internals : enable
21 is_same<T, float>>::value;
23#pragma METAL internals : disable
25template <
typename T,
typename =
void>
39template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
42 return atomic_load_explicit(&(
object[offset].val), memory_order_relaxed);
45template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
48 atomic_store_explicit(&(
object[offset].val), val, memory_order_relaxed);
51template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
56 atomic_fetch_and_explicit(&(
object[offset].val), val, memory_order_relaxed);
59template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
62 atomic_fetch_or_explicit(&(
object[offset].val), val, memory_order_relaxed);
65template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
70 atomic_fetch_min_explicit(&(
object[offset].val), val, memory_order_relaxed);
73template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
78 atomic_fetch_max_explicit(&(
object[offset].val), val, memory_order_relaxed);
81template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
86 atomic_fetch_add_explicit(&(
object[offset].val), val, memory_order_relaxed);
89template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
96 object, &expected, val * expected, offset)) {
100template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
106 return atomic_compare_exchange_weak_explicit(
107 &(
object[offset].val),
110 memory_order_relaxed,
111 memory_order_relaxed);
121 while (val < expected) {
123 object, &expected, val, offset)) {
136 while (val > expected) {
138 object, &expected, val, offset)) {
151constexpr constant uint packing_size =
sizeof(uint) /
sizeof(T);
154union uint_or_packed {
155 T val[packing_size<T>];
159template <
typename T,
typename Op>
160struct mlx_atomic_update_helper {
161 uint operator()(uint_or_packed<T> init, T update, uint elem_offset) {
163 init.val[elem_offset] =
op(update, init.val[elem_offset]);
168template <
typename T,
typename Op>
169METAL_FUNC
void mlx_atomic_update_and_store(
173 uint pack_offset = offset / packing_size<T>;
174 uint elem_offset = offset % packing_size<T>;
176 mlx_atomic_update_helper<T, Op> helper;
177 uint_or_packed<T> expected;
179 atomic_load_explicit(&(
object[pack_offset].val), memory_order_relaxed);
181 while (Op::condition(update, expected.val[elem_offset]) &&
185 helper(expected, update, elem_offset),
192 static bool condition(T a, T b) {
198 T operator()(T a, T b) {
206 static bool condition(T a, T b) {
212 T operator()(T a, T b) {
219 static bool condition(T a, T b) {
224 T operator()(T a, T b) {
231 static bool condition(T a, T b) {
235 T operator()(T a, T b) {
242 static bool condition(T a, T b) {
246 T operator()(T a, T b) {
253template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
256 uint pack_offset = offset /
sizeof(T);
257 uint elem_offset = offset %
sizeof(T);
258 uint_or_packed<T> packed_val;
260 atomic_load_explicit(&(
object[pack_offset].val), memory_order_relaxed);
261 return packed_val.val[elem_offset];
264template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
267 mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
270template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
275 uint pack_offset = offset / packing_size<T>;
276 uint elem_offset = offset % packing_size<T>;
281 atomic_fetch_and_explicit(
282 &(
object[pack_offset].val),
identity.bits, memory_order_relaxed);
285template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
288 uint pack_offset = offset / packing_size<T>;
289 uint elem_offset = offset % packing_size<T>;
294 atomic_fetch_or_explicit(
295 &(
object[pack_offset].val),
identity.bits, memory_order_relaxed);
298template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
303 mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
306template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
311 mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
314template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
319 mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
322template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
327 mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
330template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
333 thread uint* expected,
336 return atomic_compare_exchange_weak_explicit(
337 &(
object[offset].val),
340 memory_order_relaxed,
341 memory_order_relaxed);
METAL_FUNC void mlx_atomic_fetch_add_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:82
METAL_FUNC void mlx_atomic_fetch_max_explicit< float >(device mlx_atomic< float > *object, float val, uint offset)
Definition atomic.h:131
METAL_FUNC void mlx_atomic_fetch_and_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:52
METAL_FUNC T mlx_atomic_load_explicit(device mlx_atomic< T > *object, uint offset)
Definition atomic.h:41
METAL_FUNC void mlx_atomic_store_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:47
constexpr constant bool is_metal_atomic
Definition atomic.h:17
METAL_FUNC void mlx_atomic_fetch_or_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:61
METAL_FUNC void mlx_atomic_fetch_min_explicit< float >(device mlx_atomic< float > *object, float val, uint offset)
Definition atomic.h:116
METAL_FUNC void mlx_atomic_fetch_max_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:74
METAL_FUNC void mlx_atomic_fetch_min_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:66
METAL_FUNC void mlx_atomic_fetch_mul_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:90
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic< T > *object, thread T *expected, T val, uint offset)
Definition atomic.h:101
Op op
Definition binary.h:139
array identity(int n, Dtype dtype, StreamOrDevice s={})
Create a square matrix of shape (n,n) of zeros, and ones in the major diagonal.
array bits(const std::vector< int > &shape, int width, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate an array with type uint32 filled with random bits.
atomic< uint > val
Definition atomic.h:27