14#pragma METAL internals : enable 
   20    is_same<T, float>>::value;
 
   22#pragma METAL internals : disable 
   24template <
typename T, 
typename = 
void>
 
   38template <
typename T, enable_if_t<is_metal_atomic<T>, 
bool> = true>
 
   41  return atomic_load_explicit(&(
object[offset].val), memory_order_relaxed);
 
 
   44template <
typename T, enable_if_t<is_metal_atomic<T>, 
bool> = true>
 
   47  atomic_store_explicit(&(
object[offset].val), val, memory_order_relaxed);
 
 
   50template <
typename T, enable_if_t<is_metal_atomic<T>, 
bool> = true>
 
   55  atomic_fetch_and_explicit(&(
object[offset].val), val, memory_order_relaxed);
 
 
   58template <
typename T, enable_if_t<is_metal_atomic<T>, 
bool> = true>
 
   61  atomic_fetch_or_explicit(&(
object[offset].val), val, memory_order_relaxed);
 
 
   64template <
typename T, enable_if_t<is_metal_atomic<T>, 
bool> = true>
 
   69  atomic_fetch_min_explicit(&(
object[offset].val), val, memory_order_relaxed);
 
 
   72template <
typename T, enable_if_t<is_metal_atomic<T>, 
bool> = true>
 
   77  atomic_fetch_max_explicit(&(
object[offset].val), val, memory_order_relaxed);
 
 
   80template <
typename T, enable_if_t<is_metal_atomic<T>, 
bool> = true>
 
   85  atomic_fetch_add_explicit(&(
object[offset].val), val, memory_order_relaxed);
 
 
   88template <
typename T, enable_if_t<is_metal_atomic<T>, 
bool> = true>
 
   95      object, &expected, val * expected, offset)) {
 
 
   99template <
typename T, enable_if_t<is_metal_atomic<T>, 
bool> = true>
 
  105  return atomic_compare_exchange_weak_explicit(
 
  106      &(
object[offset].val),
 
  109      memory_order_relaxed,
 
  110      memory_order_relaxed);
 
 
  120  while (val < expected) {
 
  122            object, &expected, val, offset)) {
 
 
  135  while (val > expected) {
 
  137            object, &expected, val, offset)) {
 
 
  150constexpr constant uint packing_size = 
sizeof(uint) / 
sizeof(T);
 
  153union uint_or_packed {
 
  154  T val[packing_size<T>];
 
  158template <
typename T, 
typename Op>
 
  159struct mlx_atomic_update_helper {
 
  160  uint operator()(uint_or_packed<T> init, T update, uint elem_offset) {
 
  162    init.val[elem_offset] = 
op(update, 
init.val[elem_offset]);
 
  167template <
typename T, 
typename Op>
 
  168METAL_FUNC 
void mlx_atomic_update_and_store(
 
  172  uint pack_offset = offset / packing_size<T>;
 
  173  uint elem_offset = offset % packing_size<T>;
 
  175  mlx_atomic_update_helper<T, Op> helper;
 
  176  uint_or_packed<T> expected;
 
  178      atomic_load_explicit(&(
object[pack_offset].val), memory_order_relaxed);
 
  180  while (Op::condition(update, expected.val[elem_offset]) &&
 
  184             helper(expected, update, elem_offset),
 
  191  static bool condition(T a, T b) {
 
  197  T operator()(T a, T b) {
 
  205  static bool condition(T a, T b) {
 
  211  T operator()(T a, T b) {
 
  218  static bool condition(T a, T b) {
 
  223  T operator()(T a, T b) {
 
  230  static bool condition(T a, T b) {
 
  234  T operator()(T a, T b) {
 
  241  static bool condition(T a, T b) {
 
  245  T operator()(T a, T b) {
 
  252template <
typename T, enable_if_t<!is_metal_atomic<T>, 
bool> = true>
 
  255  uint pack_offset = offset / 
sizeof(T);
 
  256  uint elem_offset = offset % 
sizeof(T);
 
  257  uint_or_packed<T> packed_val;
 
  259      atomic_load_explicit(&(
object[pack_offset].val), memory_order_relaxed);
 
  260  return packed_val.val[elem_offset];
 
  263template <
typename T, enable_if_t<!is_metal_atomic<T>, 
bool> = true>
 
  266  mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
 
  269template <
typename T, enable_if_t<!is_metal_atomic<T>, 
bool> = true>
 
  274  uint pack_offset = offset / packing_size<T>;
 
  275  uint elem_offset = offset % packing_size<T>;
 
  280  atomic_fetch_and_explicit(
 
  281      &(
object[pack_offset].val), 
identity.bits, memory_order_relaxed);
 
  284template <
typename T, enable_if_t<!is_metal_atomic<T>, 
bool> = true>
 
  287  uint pack_offset = offset / packing_size<T>;
 
  288  uint elem_offset = offset % packing_size<T>;
 
  293  atomic_fetch_or_explicit(
 
  294      &(
object[pack_offset].val), 
identity.bits, memory_order_relaxed);
 
  297template <
typename T, enable_if_t<!is_metal_atomic<T>, 
bool> = true>
 
  302  mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
 
  305template <
typename T, enable_if_t<!is_metal_atomic<T>, 
bool> = true>
 
  310  mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
 
  313template <
typename T, enable_if_t<!is_metal_atomic<T>, 
bool> = true>
 
  318  mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
 
  321template <
typename T, enable_if_t<!is_metal_atomic<T>, 
bool> = true>
 
  326  mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
 
  329template <
typename T, enable_if_t<!is_metal_atomic<T>, 
bool> = true>
 
  332    thread uint* expected,
 
  335  return atomic_compare_exchange_weak_explicit(
 
  336      &(
object[offset].val),
 
  339      memory_order_relaxed,
 
  340      memory_order_relaxed);
 
 
METAL_FUNC void mlx_atomic_fetch_add_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:81
 
METAL_FUNC void mlx_atomic_fetch_max_explicit< float >(device mlx_atomic< float > *object, float val, uint offset)
Definition atomic.h:130
 
METAL_FUNC void mlx_atomic_fetch_and_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:51
 
METAL_FUNC T mlx_atomic_load_explicit(device mlx_atomic< T > *object, uint offset)
Definition atomic.h:40
 
METAL_FUNC void mlx_atomic_store_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:46
 
constexpr constant bool is_metal_atomic
Definition atomic.h:16
 
METAL_FUNC void mlx_atomic_fetch_or_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:60
 
METAL_FUNC void mlx_atomic_fetch_min_explicit< float >(device mlx_atomic< float > *object, float val, uint offset)
Definition atomic.h:115
 
METAL_FUNC void mlx_atomic_fetch_max_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:73
 
METAL_FUNC void mlx_atomic_fetch_min_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:65
 
METAL_FUNC void mlx_atomic_fetch_mul_explicit(device mlx_atomic< T > *object, T val, uint offset)
Definition atomic.h:89
 
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic< T > *object, thread T *expected, T val, uint offset)
Definition atomic.h:100
 
Op op
Definition binary.h:141
 
array identity(int n, Dtype dtype, StreamOrDevice s={})
Create a square matrix of shape (n,n) of zeros, and ones in the major diagonal.
 
Group init(bool strict=false)
Initialize the distributed backend and return the group containing all discoverable processes.
 
array bits(const std::vector< int > &shape, int width, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate an array with type uint32 filled with random bits.
 
atomic< uint > val
Definition atomic.h:26