// Internal-only region: the METAL internals pragma is enabled on the first
// line and disabled on the last. In between is the tail of the compile-time
// flag is_metal_atomic<T> (presumably; its head is not shown in this
// listing). Only its final alternative, is_same<T, float>, is visible here.
// The rest of the file uses is_metal_atomic<T> via enable_if_t to choose
// between native Metal atomics and packed 32-bit-word emulation.
14#pragma METAL internals : enable
20 is_same<T, float>>::value;
22#pragma METAL internals : disable
// Template header with a defaulted second parameter (`typename = void`).
// This is the usual hook for enable_if-style partial specialization.
// Presumably it opens the mlx_atomic<T> wrapper whose `.val` member the
// functions below pass to the atomic_* builtins. The declaration body is not
// shown in this listing — confirm against the full header.
24template <
typename T,
typename =
void>
// Load for types with native Metal atomic support (is_metal_atomic<T>).
// Returns object[offset].val, read with memory_order_relaxed: the access is
// atomic but imposes no ordering on surrounding memory operations.
38template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
41 return atomic_load_explicit(&(
object[offset].val), memory_order_relaxed);
// Store for natively atomic T: writes `val` into object[offset].val with
// memory_order_relaxed.
44template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
47 atomic_store_explicit(&(
object[offset].val), val, memory_order_relaxed);
// Bitwise AND for natively atomic T: atomically ANDs `val` into
// object[offset].val with relaxed ordering. The previous value returned by
// atomic_fetch_and_explicit is not used by this statement.
50template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
55 atomic_fetch_and_explicit(&(
object[offset].val), val, memory_order_relaxed);
// Bitwise OR for natively atomic T: atomically ORs `val` into
// object[offset].val with relaxed ordering. The previous value returned by
// atomic_fetch_or_explicit is not used by this statement.
58template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
63 atomic_fetch_or_explicit(&(
object[offset].val), val, memory_order_relaxed);
// Minimum for natively atomic T: atomically replaces object[offset].val with
// min(current, val) using relaxed ordering.
66template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
71 atomic_fetch_min_explicit(&(
object[offset].val), val, memory_order_relaxed);
// Maximum for natively atomic T: atomically replaces object[offset].val with
// max(current, val) using relaxed ordering.
74template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
79 atomic_fetch_max_explicit(&(
object[offset].val), val, memory_order_relaxed);
// Addition for natively atomic T: atomically adds `val` to
// object[offset].val with relaxed ordering.
82template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
87 atomic_fetch_add_explicit(&(
object[offset].val), val, memory_order_relaxed);
// Multiplication for natively atomic T. There is no fetch_mul builtin, so it
// is emulated with a weak compare-exchange retry loop that tries to store
// `val * expected`. `expected` holds the last observed element value; a
// failed compare-exchange refreshes it, so each retry recomputes the product
// from the current value. The loop setup is not shown in this listing.
90template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
97 object, &expected, val * expected, offset)) {
// Weak compare-and-exchange on object[offset].val for natively atomic T.
// Both success and failure use memory_order_relaxed. Returns the bool
// result of atomic_compare_exchange_weak_explicit. A weak CAS may fail
// spuriously, so callers retry in a loop.
101template <
typename T, enable_if_t<is_metal_atomic<T>,
bool> = true>
107 return atomic_compare_exchange_weak_explicit(
108 &(
object[offset].val),
111 memory_order_relaxed,
112 memory_order_relaxed);
// Fragment of a CAS retry loop implementing a minimum update. It keeps trying
// to store `val` only while `val` is smaller than the observed value
// `expected`. Presumably this is the body of
// mlx_atomic_fetch_min_explicit<float> (indexed at atomic.h:117), which
// emulates min for float with compare-exchange.
122 while (val < expected) {
124 object, &expected, val, offset)) {
// Fragment of a CAS retry loop implementing a maximum update. It keeps trying
// to store `val` only while `val` is larger than the observed value
// `expected`. Presumably this is the body of
// mlx_atomic_fetch_max_explicit<float> (indexed at atomic.h:132), which
// emulates max for float with compare-exchange.
137 while (val > expected) {
139 object, &expected, val, offset)) {
// Number of T values packed into one 32-bit uint word:
// sizeof(uint) / sizeof(T), i.e. 2 for 16-bit T and 4 for 8-bit T.
// This is integer division, so it is 0 when sizeof(T) > sizeof(uint).
152constexpr constant uint packing_size =
sizeof(uint) /
sizeof(T);
// Type-punning view of one 32-bit word as packing_size<T> lanes of T.
// Other code in this file also reads a `.bits` member of this union (the
// whole word); that member is not visible in this fragment.
155union uint_or_packed {
156 T val[packing_size<T>];
// Functor used by mlx_atomic_update_and_store to compute a replacement word.
// It takes a copy of the packed word `init`, overwrites lane `elem_offset`
// with op(update, current lane value), and returns a uint. It presumably
// returns the modified word; the return statement and the declaration of
// `op` are not shown in this listing.
160template <
typename T,
typename Op>
161struct mlx_atomic_update_helper {
162 uint operator()(uint_or_packed<T> init, T update,
size_t elem_offset) {
164 init.val[elem_offset] =
op(update,
init.val[elem_offset]);
// Read-modify-write of one packed T lane stored inside 32-bit atomic words.
// Element `offset` lives in word offset / packing_size<T>, at lane
// offset % packing_size<T>.
169template <
typename T,
typename Op>
170METAL_FUNC
void mlx_atomic_update_and_store(
174 size_t pack_offset = offset / packing_size<T>;
175 size_t elem_offset = offset % packing_size<T>;
177 mlx_atomic_update_helper<T, Op> helper;
178 uint_or_packed<T> expected;
// Snapshot the containing word with a relaxed load (the assignment target is
// not visible in this listing).
180 atomic_load_explicit(&(
object[pack_offset].val), memory_order_relaxed);
// Retry while Op::condition(update, current lane) holds and the
// compare-exchange of the word built by helper(...) has not yet succeeded
// (the CAS call itself is only partially shown).
182 while (Op::condition(update, expected.val[elem_offset]) &&
186 helper(expected, update, elem_offset),
// Fragments of the per-operation functors passed as `Op` to
// mlx_atomic_update_and_store. The file uses __None, __Min, __Max, __Add
// and __Mul in this role; their struct headers are not shown, so which
// fragment belongs to which op cannot be read from here. Each op provides:
//   static bool condition(T a, T b): the CAS retry loop keeps going only
//     while this is true for (update, current lane value);
//   T operator()(T a, T b): the new lane value, called as op(update, current).
193 static bool condition(T a, T b) {
199 T operator()(T a, T b) {
207 static bool condition(T a, T b) {
213 T operator()(T a, T b) {
220 static bool condition(T a, T b) {
225 T operator()(T a, T b) {
232 static bool condition(T a, T b) {
236 T operator()(T a, T b) {
243 static bool condition(T a, T b) {
247 T operator()(T a, T b) {
// Load for a non-atomic T (no native Metal atomic support), stored packed
// inside 32-bit atomic words. Element `offset` lives in word
// offset / packing_size<T>, at lane offset % packing_size<T>. The word is
// read with a relaxed atomic load and the requested lane is returned.
// Fix: word and lane were previously computed with sizeof(T). That equals
// packing_size<T> (= sizeof(uint) / sizeof(T)) only when sizeof(T) == 2.
// For 1-byte T it selected word `offset`, lane 0, instead of word
// offset / 4, lane offset % 4. Using packing_size<T> matches
// mlx_atomic_update_and_store and the non-atomic fetch_and / fetch_or paths.
254template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
257 size_t pack_offset = offset /
packing_size<T>;
258 size_t elem_offset = offset %
packing_size<T>;
259 uint_or_packed<T> packed_val;
// NOTE(review): the loaded word must be stored into packed_val.bits; that
// assignment is not visible in this listing — confirm against the full file.
261 atomic_load_explicit(&(
object[pack_offset].val), memory_order_relaxed);
262 return packed_val.val[elem_offset];
// Store for a non-atomic T packed in 32-bit words. It is delegated to the
// generic CAS-based mlx_atomic_update_and_store with the __None<T> op,
// presumably an op that replaces the lane with `val` (its definition is not
// identifiable in this listing — confirm).
265template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
268 mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
// Bitwise AND on one packed lane of a non-atomic T, done as a single atomic
// AND on the whole 32-bit word that contains element `offset`.
271template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
276 size_t pack_offset = offset / packing_size<T>;
277 size_t elem_offset = offset % packing_size<T>;
// `identity` is built on lines not shown here. Presumably it holds `val` in
// lane elem_offset and all-ones bits in the other lanes, so neighbouring
// lanes are left unchanged by the AND — confirm.
282 atomic_fetch_and_explicit(
283 &(
object[pack_offset].val),
identity.bits, memory_order_relaxed);
// Bitwise OR on one packed lane of a non-atomic T, done as a single atomic
// OR on the whole 32-bit word that contains element `offset`.
286template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
291 size_t pack_offset = offset / packing_size<T>;
292 size_t elem_offset = offset % packing_size<T>;
// `identity` is built on lines not shown here. Presumably it holds `val` in
// lane elem_offset and zero bits in the other lanes, so neighbouring lanes
// are left unchanged by the OR — confirm.
297 atomic_fetch_or_explicit(
298 &(
object[pack_offset].val),
identity.bits, memory_order_relaxed);
// Minimum for a non-atomic T packed in 32-bit words. It is delegated to the
// CAS-based mlx_atomic_update_and_store with the __Min<T> op.
301template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
306 mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
// Maximum for a non-atomic T packed in 32-bit words. It is delegated to the
// CAS-based mlx_atomic_update_and_store with the __Max<T> op.
309template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
314 mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
// Addition for a non-atomic T packed in 32-bit words. It is delegated to the
// CAS-based mlx_atomic_update_and_store with the __Add<T> op.
317template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
322 mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
// Multiplication for a non-atomic T packed in 32-bit words. It is delegated
// to the CAS-based mlx_atomic_update_and_store with the __Mul<T> op.
325template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
330 mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
// Weak compare-and-exchange for non-atomic T. `expected` is a thread uint*,
// so the comparison and exchange act on the whole 32-bit word
// object[offset].val, not on a single packed T lane. `offset` is used
// directly as the word index. Both success and failure use
// memory_order_relaxed, and the bool result of the weak CAS is returned.
333template <
typename T, enable_if_t<!is_metal_atomic<T>,
bool> = true>
336 thread uint* expected,
339 return atomic_compare_exchange_weak_explicit(
340 &(
object[offset].val),
343 memory_order_relaxed,
344 memory_order_relaxed);
METAL_FUNC void mlx_atomic_store_explicit(device mlx_atomic< T > *object, T val, size_t offset)
Definition atomic.h:46
METAL_FUNC void mlx_atomic_fetch_max_explicit< float >(device mlx_atomic< float > *object, float val, size_t offset)
Definition atomic.h:132
METAL_FUNC T mlx_atomic_load_explicit(device mlx_atomic< T > *object, size_t offset)
Definition atomic.h:40
METAL_FUNC void mlx_atomic_fetch_and_explicit(device mlx_atomic< T > *object, T val, size_t offset)
Definition atomic.h:51
METAL_FUNC void mlx_atomic_fetch_min_explicit(device mlx_atomic< T > *object, T val, size_t offset)
Definition atomic.h:67
constexpr constant bool is_metal_atomic
Definition atomic.h:16
METAL_FUNC void mlx_atomic_fetch_add_explicit(device mlx_atomic< T > *object, T val, size_t offset)
Definition atomic.h:83
METAL_FUNC void mlx_atomic_fetch_or_explicit(device mlx_atomic< T > *object, T val, size_t offset)
Definition atomic.h:59
METAL_FUNC void mlx_atomic_fetch_min_explicit< float >(device mlx_atomic< float > *object, float val, size_t offset)
Definition atomic.h:117
METAL_FUNC void mlx_atomic_fetch_max_explicit(device mlx_atomic< T > *object, T val, size_t offset)
Definition atomic.h:75
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic< T > *object, thread T *expected, T val, size_t offset)
Definition atomic.h:102
METAL_FUNC void mlx_atomic_fetch_mul_explicit(device mlx_atomic< T > *object, T val, size_t offset)
Definition atomic.h:91
Op op
Definition binary.h:129
array identity(int n, Dtype dtype, StreamOrDevice s={})
Create a square matrix of shape (n,n) of zeros, and ones in the major diagonal.
Group init(bool strict=false)
Initialize the distributed backend and return the group containing all discoverable processes.
array bits(const std::vector< int > &shape, int width, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate an array with type uint32 filled with random bits.
atomic< uint > val
Definition atomic.h:26