// Copyright © 2023 Apple Inc.

#pragma once

#include <metal_atomic>
#include <metal_stdlib>

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// Atomic utils
///////////////////////////////////////////////////////////////////////////////

#pragma METAL internals : enable
template <typename T>
constexpr constant bool is_metal_atomic = _disjunction<
    is_same<T, int>,
    is_same<T, uint>,
    is_same<T, ulong>,
    is_same<T, float>>::value;

#pragma METAL internals : disable

// Types without native atomic support are packed into an atomic<uint>.
template <typename T, typename = void>
struct mlx_atomic {
  atomic<uint> val;
};

template <typename T>
struct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {
  atomic<T> val;
};

///////////////////////////////////////////////////////////////////////////////
// Native metal atomics
///////////////////////////////////////////////////////////////////////////////

template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC T
mlx_atomic_load_explicit(device mlx_atomic<T>* object, size_t offset) {
  return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);
}

template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, size_t offset) {
  atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);
}

template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_and_explicit(
    device mlx_atomic<T>* object,
    T val,
    size_t offset) {
  atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);
}

template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_or_explicit(
    device mlx_atomic<T>* object,
    T val,
    size_t offset) {
  atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);
}

template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_min_explicit(
    device mlx_atomic<T>* object,
    T val,
    size_t offset) {
  atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);
}

template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_max_explicit(
    device mlx_atomic<T>* object,
    T val,
    size_t offset) {
  atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);
}

template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_add_explicit(
    device mlx_atomic<T>* object,
    T val,
    size_t offset) {
  atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);
}

// There is no native fetch_mul, so emulate it with a compare-exchange loop.
template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_mul_explicit(
    device mlx_atomic<T>* object,
    T val,
    size_t offset) {
  T expected = mlx_atomic_load_explicit(object, offset);
  while (!mlx_atomic_compare_exchange_weak_explicit(
      object, &expected, val * expected, offset)) {
  }
}

template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
    device mlx_atomic<T>* object,
    thread T* expected,
    T val,
    size_t offset) {
  return atomic_compare_exchange_weak_explicit(
      &(object[offset].val),
      expected,
      val,
      memory_order_relaxed,
      memory_order_relaxed);
}

// Specialization for float since it does not support atomic_fetch_min_explicit
template <>
METAL_FUNC void mlx_atomic_fetch_min_explicit<float>(
    device mlx_atomic<float>* object,
    float val,
    size_t offset) {
  float expected = mlx_atomic_load_explicit(object, offset);
  while (val < expected) {
    if (mlx_atomic_compare_exchange_weak_explicit(
            object, &expected, val, offset)) {
      return;
    }
  }
}

// Specialization for float since it does not support atomic_fetch_max_explicit
template <>
METAL_FUNC void mlx_atomic_fetch_max_explicit<float>(
    device mlx_atomic<float>* object,
    float val,
    size_t offset) {
  float expected = mlx_atomic_load_explicit(object, offset);
  while (val > expected) {
    if (mlx_atomic_compare_exchange_weak_explicit(
            object, &expected, val, offset)) {
      return;
    }
  }
}
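// Usage sketch (hypothetical kernel, assumed buffer layout of one
// mlx_atomic<float> per output element): for float, fetch_max resolves to the
// CAS-loop specialization above, while int/uint/ulong map straight onto
// hardware atomics.
//
// [[kernel]] void atomic_max_example(
//     device mlx_atomic<float>* out [[buffer(0)]],
//     const device float* in [[buffer(1)]],
//     uint gid [[thread_position_in_grid]]) {
//   // Relaxed ordering: callers need atomicity only, not synchronization.
//   mlx_atomic_fetch_max_explicit(out, in[gid], 0);
// }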
///////////////////////////////////////////////////////////////////////////////
// Custom atomics
///////////////////////////////////////////////////////////////////////////////

namespace {

// Number of T values that fit in one 32-bit atomic word.
template <typename T>
constexpr constant uint packing_size = sizeof(uint) / sizeof(T);

template <typename T>
union uint_or_packed {
  T val[packing_size<T>];
  uint bits;
};

template <typename T, typename Op>
struct mlx_atomic_update_helper {
  uint operator()(uint_or_packed<T> init, T update, size_t elem_offset) {
    Op op;
    init.val[elem_offset] = op(update, init.val[elem_offset]);
    return init.bits;
  }
};

// CAS loop over the whole 32-bit word that holds the target element; only the
// selected lane is rewritten. The loop exits early once Op::condition says
// the update can no longer change the stored value.
template <typename T, typename Op>
METAL_FUNC void mlx_atomic_update_and_store(
    device mlx_atomic<T>* object,
    T update,
    size_t offset) {
  size_t pack_offset = offset / packing_size<T>;
  size_t elem_offset = offset % packing_size<T>;

  mlx_atomic_update_helper<T, Op> helper;
  uint_or_packed<T> expected;
  expected.bits =
      atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);

  while (Op::condition(update, expected.val[elem_offset]) &&
         !mlx_atomic_compare_exchange_weak_explicit(
             object,
             &(expected.bits),
             helper(expected, update, elem_offset),
             pack_offset)) {
  }
}

template <typename T>
struct __None {
  static bool condition(T a, T b) {
#pragma unused(a)
#pragma unused(b)
    return true;
  }

  T operator()(T a, T b) {
#pragma unused(b)
    return a;
  }
};

template <typename T>
struct __Add {
  static bool condition(T a, T b) {
#pragma unused(a)
#pragma unused(b)
    return true;
  }

  T operator()(T a, T b) {
    return a + b;
  }
};

template <typename T>
struct __Mul {
  static bool condition(T a, T b) {
#pragma unused(a)
    return b != 0;
  }

  T operator()(T a, T b) {
    return a * b;
  }
};

template <typename T>
struct __Max {
  static bool condition(T a, T b) {
    return a > b;
  }

  T operator()(T a, T b) {
    return max(a, b);
  }
};

template <typename T>
struct __Min {
  static bool condition(T a, T b) {
    return a < b;
  }

  T operator()(T a, T b) {
    return min(a, b);
  }
};

} // namespace
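// Worked example of the packing arithmetic: for a 16-bit element type such as
// half (a stand-in here; any non-natively-atomic 16-bit type behaves the
// same), packing_size<half> = sizeof(uint) / sizeof(half) = 4 / 2 = 2, so two
// elements share one atomic<uint>. Logical offset 5 then maps to word
// 5 / 2 = 2, lane 5 % 2 = 1.
static_assert(
    packing_size<half> == 2,
    "two 16-bit values pack into one 32-bit atomic word");
static_assert(
    sizeof(uint_or_packed<half>) == sizeof(uint),
    "uint_or_packed must alias exactly one 32-bit word");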
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC T
mlx_atomic_load_explicit(device mlx_atomic<T>* object, size_t offset) {
  size_t pack_offset = offset / packing_size<T>;
  size_t elem_offset = offset % packing_size<T>;
  uint_or_packed<T> packed_val;
  packed_val.bits =
      atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
  return packed_val.val[elem_offset];
}

template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void
mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, size_t offset) {
  mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
}

// For bitwise AND, the other lanes in the word are masked with all ones so
// only the addressed element changes.
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_and_explicit(
    device mlx_atomic<T>* object,
    T val,
    size_t offset) {
  size_t pack_offset = offset / packing_size<T>;
  size_t elem_offset = offset % packing_size<T>;
  uint_or_packed<T> identity;
  identity.bits = __UINT32_MAX__;
  identity.val[elem_offset] = val;

  atomic_fetch_and_explicit(
      &(object[pack_offset].val), identity.bits, memory_order_relaxed);
}

// For bitwise OR, the other lanes are masked with zeros.
template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_or_explicit(
    device mlx_atomic<T>* object,
    T val,
    size_t offset) {
  size_t pack_offset = offset / packing_size<T>;
  size_t elem_offset = offset % packing_size<T>;
  uint_or_packed<T> identity;
  identity.bits = 0;
  identity.val[elem_offset] = val;

  atomic_fetch_or_explicit(
      &(object[pack_offset].val), identity.bits, memory_order_relaxed);
}

template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_min_explicit(
    device mlx_atomic<T>* object,
    T val,
    size_t offset) {
  mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
}

template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_max_explicit(
    device mlx_atomic<T>* object,
    T val,
    size_t offset) {
  mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
}

template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_add_explicit(
    device mlx_atomic<T>* object,
    T val,
    size_t offset) {
  mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
}

template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC void mlx_atomic_fetch_mul_explicit(
    device mlx_atomic<T>* object,
    T val,
    size_t offset) {
  mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
}

template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
    device mlx_atomic<T>* object,
    thread uint* expected,
    uint val,
    size_t offset) {
  return atomic_compare_exchange_weak_explicit(
      &(object[offset].val),
      expected,
      val,
      memory_order_relaxed,
      memory_order_relaxed);
}
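// Usage sketch (hypothetical kernel) for the packed path: callers index by
// logical element offset and the overloads above derive the word and lane
// internally.
//
// [[kernel]] void scatter_add_example(
//     device mlx_atomic<half>* out [[buffer(0)]],
//     const device half* updates [[buffer(1)]],
//     const device uint* indices [[buffer(2)]],
//     uint gid [[thread_position_in_grid]]) {
//   // __Add<half>::condition is always true, so the CAS loop retries until
//   // the addition lands in its lane of the shared 32-bit word.
//   mlx_atomic_fetch_add_explicit(out, updates[gid], indices[gid]);
// }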