14#pragma METAL internals : enable
20 is_same<T, float>>::value;
22#pragma METAL internals : disable
24template <
typename T,
typename =
void>
38template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
41 return atomic_load_explicit(&(
object[offset].val), memory_order_relaxed);
44template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
47 atomic_store_explicit(&(
object[offset].val), val, memory_order_relaxed);
50template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
55 atomic_fetch_and_explicit(&(
object[offset].val), val, memory_order_relaxed);
58template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
61 atomic_fetch_or_explicit(&(
object[offset].val), val, memory_order_relaxed);
64template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
69 atomic_fetch_min_explicit(&(
object[offset].val), val, memory_order_relaxed);
72template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
77 atomic_fetch_max_explicit(&(
object[offset].val), val, memory_order_relaxed);
80template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
85 atomic_fetch_add_explicit(&(
object[offset].val), val, memory_order_relaxed);
88template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
95 object, &expected, val * expected, offset)) {
99template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
105 return atomic_compare_exchange_weak_explicit(
106 &(
object[offset].val),
109 memory_order_relaxed,
110 memory_order_relaxed);
120 while (val < expected) {
122 object, &expected, val, offset)) {
135 while (val > expected) {
137 object, &expected, val, offset)) {
150constexpr constant uint packing_size =
sizeof(uint) /
sizeof(T);
153union uint_or_packed {
154 T val[packing_size<T>];
158template <
typename T,
typename Op>
159struct mlx_atomic_update_helper {
160 uint operator()(uint_or_packed<T> init, T update, uint elem_offset) {
162 init.val[elem_offset] =
op(update,
init.val[elem_offset]);
167template <
typename T,
typename Op>
168METAL_FUNC
void mlx_atomic_update_and_store(
172 uint pack_offset = offset / packing_size<T>;
173 uint elem_offset = offset % packing_size<T>;
175 mlx_atomic_update_helper<T, Op> helper;
176 uint_or_packed<T> expected;
178 atomic_load_explicit(&(
object[pack_offset].val), memory_order_relaxed);
180 while (Op::condition(update, expected.val[elem_offset]) &&
184 helper(expected, update, elem_offset),
191 static bool condition(T a, T b) {
197 T operator()(T a, T b) {
205 static bool condition(T a, T b) {
211 T operator()(T a, T b) {
218 static bool condition(T a, T b) {
223 T operator()(T a, T b) {
230 static bool condition(T a, T b) {
234 T operator()(T a, T b) {
241 static bool condition(T a, T b) {
245 T operator()(T a, T b) {
252template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
255 uint pack_offset = offset /
sizeof(T);
256 uint elem_offset = offset %
sizeof(T);
257 uint_or_packed<T> packed_val;
259 atomic_load_explicit(&(
object[pack_offset].val), memory_order_relaxed);
260 return packed_val.val[elem_offset];
263template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
266 mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
269template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
274 uint pack_offset = offset / packing_size<T>;
275 uint elem_offset = offset % packing_size<T>;
280 atomic_fetch_and_explicit(
281 &(
object[pack_offset].val),
identity.bits, memory_order_relaxed);
284template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
287 uint pack_offset = offset / packing_size<T>;
288 uint elem_offset = offset % packing_size<T>;
293 atomic_fetch_or_explicit(
294 &(
object[pack_offset].val),
identity.bits, memory_order_relaxed);
297template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
302 mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
305template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
310 mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
313template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
318 mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
321template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
326 mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
329template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
332 thread uint* expected,
335 return atomic_compare_exchange_weak_explicit(
336 &(
object[offset].val),
339 memory_order_relaxed,
340 memory_order_relaxed);
METAL_FUNC void mlx_atomic_fetch_add_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:81
METAL_FUNC void mlx_atomic_fetch_max_explicit< float >(device mlx_atomic< float > *object, float val, uint offset)
Definition atomic.h:130
METAL_FUNC void mlx_atomic_fetch_and_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:51
METAL_FUNC T mlx_atomic_load_explicit(device mlx_atomic< T > *object, uint offset)
Definition atomic.h:40
METAL_FUNC void mlx_atomic_store_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:46
constexpr constant bool is_metal_atomic
Definition atomic.h:16
METAL_FUNC void mlx_atomic_fetch_or_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:60
METAL_FUNC void mlx_atomic_fetch_min_explicit< float >(device mlx_atomic< float > *object, float val, uint offset)
Definition atomic.h:115
METAL_FUNC void mlx_atomic_fetch_max_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:73
METAL_FUNC void mlx_atomic_fetch_min_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:65
METAL_FUNC void mlx_atomic_fetch_mul_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:89
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic< T > *object, thread T *expected, T val, uint offset)
Definition atomic.h:100
Op op
Definition binary.h:141
array identity(int n, Dtype dtype, StreamOrDevice s={})
Create a square matrix of shape (n,n) of zeros, and ones in the major diagonal.
Group init(bool strict=false)
Initialize the distributed backend and return the group containing all discoverable processes.
array bits(const std::vector< int > &shape, int width, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate an array with type uint32 filled with random bits.
atomic< uint > val
Definition atomic.h:26