14#pragma METAL internals : enable 
   20    is_same<T, float>>::value;
 
   22#pragma METAL internals : disable 
   24template <
typename T, 
typename = 
void>
 
   38template <
typename T, enable_if_t<is_metal_atomic<T>, 
bool> = true>
 
   41  return atomic_load_explicit(&(
object[offset].val), memory_order_relaxed);
 
 
   44template <
typename T, enable_if_t<is_metal_atomic<T>, 
bool> = true>
 
   47  atomic_store_explicit(&(
object[offset].val), val, memory_order_relaxed);
 
 
   50template <
typename T, enable_if_t<is_metal_atomic<T>, 
bool> = true>
 
   55  atomic_fetch_and_explicit(&(
object[offset].val), val, memory_order_relaxed);
 
 
   58template <
typename T, enable_if_t<is_metal_atomic<T>, 
bool> = true>
 
   63  atomic_fetch_or_explicit(&(
object[offset].val), val, memory_order_relaxed);
 
 
   66template <
typename T, enable_if_t<is_metal_atomic<T>, 
bool> = true>
 
   71  atomic_fetch_min_explicit(&(
object[offset].val), val, memory_order_relaxed);
 
 
   74template <
typename T, enable_if_t<is_metal_atomic<T>, 
bool> = true>
 
   79  atomic_fetch_max_explicit(&(
object[offset].val), val, memory_order_relaxed);
 
 
   82template <
typename T, enable_if_t<is_metal_atomic<T>, 
bool> = true>
 
   87  atomic_fetch_add_explicit(&(
object[offset].val), val, memory_order_relaxed);
 
 
   90template <
typename T, enable_if_t<is_metal_atomic<T>, 
bool> = true>
 
   97      object, &expected, val * expected, offset)) {
 
 
  101template <
typename T, enable_if_t<is_metal_atomic<T>, 
bool> = true>
 
  107  return atomic_compare_exchange_weak_explicit(
 
  108      &(
object[offset].val),
 
  111      memory_order_relaxed,
 
  112      memory_order_relaxed);
 
 
  122  while (val < expected) {
 
  124            object, &expected, val, offset)) {
 
 
  137  while (val > expected) {
 
  139            object, &expected, val, offset)) {
 
 
  152constexpr constant uint packing_size = 
sizeof(uint) / 
sizeof(T);
 
  155union uint_or_packed {
 
  156  T val[packing_size<T>];
 
  160template <
typename T, 
typename Op>
 
  161struct mlx_atomic_update_helper {
 
  162  uint operator()(uint_or_packed<T> init, T update, 
size_t elem_offset) {
 
  164    init.val[elem_offset] = 
op(update, 
init.val[elem_offset]);
 
  169template <
typename T, 
typename Op>
 
  170METAL_FUNC 
void mlx_atomic_update_and_store(
 
  174  size_t pack_offset = offset / packing_size<T>;
 
  175  size_t elem_offset = offset % packing_size<T>;
 
  177  mlx_atomic_update_helper<T, Op> helper;
 
  178  uint_or_packed<T> expected;
 
  180      atomic_load_explicit(&(
object[pack_offset].val), memory_order_relaxed);
 
  182  while (Op::condition(update, expected.val[elem_offset]) &&
 
  186             helper(expected, update, elem_offset),
 
  193  static bool condition(T a, T b) {
 
  199  T operator()(T a, T b) {
 
  207  static bool condition(T a, T b) {
 
  213  T operator()(T a, T b) {
 
  220  static bool condition(T a, T b) {
 
  225  T operator()(T a, T b) {
 
  232  static bool condition(T a, T b) {
 
  236  T operator()(T a, T b) {
 
  243  static bool condition(T a, T b) {
 
  247  T operator()(T a, T b) {
 
  254template <
typename T, enable_if_t<!is_metal_atomic<T>, 
bool> = true>
 
  257  size_t pack_offset = offset / 
sizeof(T);
 
  258  size_t elem_offset = offset % 
sizeof(T);
 
  259  uint_or_packed<T> packed_val;
 
  261      atomic_load_explicit(&(
object[pack_offset].val), memory_order_relaxed);
 
  262  return packed_val.val[elem_offset];
 
  265template <
typename T, enable_if_t<!is_metal_atomic<T>, 
bool> = true>
 
  268  mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
 
  271template <
typename T, enable_if_t<!is_metal_atomic<T>, 
bool> = true>
 
  276  size_t pack_offset = offset / packing_size<T>;
 
  277  size_t elem_offset = offset % packing_size<T>;
 
  282  atomic_fetch_and_explicit(
 
  283      &(
object[pack_offset].val), 
identity.bits, memory_order_relaxed);
 
  286template <
typename T, enable_if_t<!is_metal_atomic<T>, 
bool> = true>
 
  291  size_t pack_offset = offset / packing_size<T>;
 
  292  size_t elem_offset = offset % packing_size<T>;
 
  297  atomic_fetch_or_explicit(
 
  298      &(
object[pack_offset].val), 
identity.bits, memory_order_relaxed);
 
  301template <
typename T, enable_if_t<!is_metal_atomic<T>, 
bool> = true>
 
  306  mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
 
  309template <
typename T, enable_if_t<!is_metal_atomic<T>, 
bool> = true>
 
  314  mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
 
  317template <
typename T, enable_if_t<!is_metal_atomic<T>, 
bool> = true>
 
  322  mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
 
  325template <
typename T, enable_if_t<!is_metal_atomic<T>, 
bool> = true>
 
  330  mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
 
  333template <
typename T, enable_if_t<!is_metal_atomic<T>, 
bool> = true>
 
  336    thread uint* expected,
 
  339  return atomic_compare_exchange_weak_explicit(
 
  340      &(
object[offset].val),
 
  343      memory_order_relaxed,
 
  344      memory_order_relaxed);
 
 
METAL_FUNC void mlx_atomic_store_explicit(device mlx_atomic< T > *object, T val, size_t offset)
Definition atomic.h:46
 
METAL_FUNC void mlx_atomic_fetch_max_explicit< float >(device mlx_atomic< float > *object, float val, size_t offset)
Definition atomic.h:132
 
METAL_FUNC T mlx_atomic_load_explicit(device mlx_atomic< T > *object, size_t offset)
Definition atomic.h:40
 
METAL_FUNC void mlx_atomic_fetch_and_explicit(device mlx_atomic< T > *object, T val, size_t offset)
Definition atomic.h:51
 
METAL_FUNC void mlx_atomic_fetch_min_explicit(device mlx_atomic< T > *object, T val, size_t offset)
Definition atomic.h:67
 
constexpr constant bool is_metal_atomic
Definition atomic.h:16
 
METAL_FUNC void mlx_atomic_fetch_add_explicit(device mlx_atomic< T > *object, T val, size_t offset)
Definition atomic.h:83
 
METAL_FUNC void mlx_atomic_fetch_or_explicit(device mlx_atomic< T > *object, T val, size_t offset)
Definition atomic.h:59
 
METAL_FUNC void mlx_atomic_fetch_min_explicit< float >(device mlx_atomic< float > *object, float val, size_t offset)
Definition atomic.h:117
 
METAL_FUNC void mlx_atomic_fetch_max_explicit(device mlx_atomic< T > *object, T val, size_t offset)
Definition atomic.h:75
 
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(device mlx_atomic< T > *object, thread T *expected, T val, size_t offset)
Definition atomic.h:102
 
METAL_FUNC void mlx_atomic_fetch_mul_explicit(device mlx_atomic< T > *object, T val, size_t offset)
Definition atomic.h:91
 
Op op
Definition binary.h:129
 
array identity(int n, Dtype dtype, StreamOrDevice s={})
Create a square matrix of shape (n,n) of zeros, and ones in the major diagonal.
 
Group init(bool strict=false)
Initialize the distributed backend and return the group containing all discoverable processes.
 
array bits(const std::vector< int > &shape, int width, const std::optional< array > &key=std::nullopt, StreamOrDevice s={})
Generate an array with type uint32 filled with random bits.
 
atomic< uint > val
Definition atomic.h:26