// Copyright © 2025 Apple Inc.

#pragma once

#include <cstdint>

#include <cuComplex.h>
#include <cuda/std/type_traits>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/reduce/reduce_utils.cuh"

namespace mlx::core::cu {

namespace detail {

// Returns true when `x` is NaN. Only floating-point element types can hold a
// NaN; for every other type (bool, integers, composite types such as complex)
// this folds to a compile-time `false`, so those reductions keep their plain
// comparison with no extra work.
template <typename T>
__device__ __forceinline__ bool reduce_isnan(T x) {
  if constexpr (
      cuda::std::is_same_v<T, __half> ||
      cuda::std::is_same_v<T, __nv_bfloat16>) {
    return __hisnan(x);
  } else if constexpr (cuda::std::is_floating_point_v<T>) {
    return isnan(x);
  } else {
    return false;
  }
}

} // namespace detail

// Reduce ops.
//
// Each op is a stateless functor:
//   - `operator()(a, b)` combines two partial results;
//   - `atomic_update(x, y)` atomically folds `y` into `*x`, either with a
//     native `atomicAdd` (Sum, selected types) or via the project helper
//     `atomic_reduce<T, Op>` (presumably a CAS loop applying `Op` — see
//     reduce_utils.cuh).

// Logical AND over bool values; identity is `true` (see ReduceInit below).
struct And {
  __device__ __forceinline__ bool operator()(bool a, bool b) const {
    return a && b;
  }

  __device__ void atomic_update(bool* x, bool y) const {
    atomic_reduce<bool, And>(x, y);
  }
};

// Logical OR over bool values; identity is `false` (see ReduceInit below).
struct Or {
  __device__ __forceinline__ bool operator()(bool a, bool b) const {
    return a || b;
  }

  __device__ void atomic_update(bool* x, bool y) const {
    atomic_reduce<bool, Or>(x, y);
  }
};

// Addition; identity is zero.
struct Sum {
  template <typename T>
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a + b;
  }

  // Generic fallback for element types without a native overload below.
  template <typename T>
  __device__ void atomic_update(T* x, T y) const {
    atomic_reduce<T, Sum>(x, y);
  }

  // Native atomicAdd overloads. Being non-templates, they are preferred over
  // the generic version above whenever the argument types match exactly.
  // NOTE(review): native bfloat16 atomicAdd needs SM80+; confirm how older
  // architectures are handled by the CUDA headers in use.
  __device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) const {
    atomicAdd(x, y);
  }

  __device__ void atomic_update(int* x, int y) const {
    atomicAdd(x, y);
  }

  __device__ void atomic_update(float* x, float y) const {
    atomicAdd(x, y);
  }
};

// Multiplication; identity is one.
struct Prod {
  template <typename T>
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a * b;
  }

  template <typename T>
  __device__ void atomic_update(T* x, T y) const {
    atomic_reduce<T, Prod>(x, y);
  }
};

// Minimum. For floating-point types NaN propagates: if either operand is NaN
// the result is NaN.
struct Min {
  template <typename T>
  __device__ __forceinline__ T operator()(T a, T b) const {
    // A bare `a < b ? a : b` yields `b` when `a` is NaN but NaN when `b` is,
    // so a parallel reduction's answer would depend on combination order.
    if (detail::reduce_isnan(a)) {
      return a;
    }
    if (detail::reduce_isnan(b)) {
      return b;
    }
    return a < b ? a : b;
  }

  template <typename T>
  __device__ void atomic_update(T* x, T y) const {
    atomic_reduce<T, Min>(x, y);
  }
};

// Maximum. For floating-point types NaN propagates: if either operand is NaN
// the result is NaN.
struct Max {
  template <typename T>
  __device__ __forceinline__ T operator()(T a, T b) const {
    // See Min: without these checks NaN is dropped or kept depending on
    // which operand it arrives in.
    if (detail::reduce_isnan(a)) {
      return a;
    }
    if (detail::reduce_isnan(b)) {
      return b;
    }
    return a > b ? a : b;
  }

  template <typename T>
  __device__ void atomic_update(T* x, T y) const {
    atomic_reduce<T, Max>(x, y);
  }
};

// Traits to get the result (accumulator) type of reduce op `Op` applied to
// elements of type `T`.
template <typename Op, typename T>
struct ReduceResult;

template <typename T>
struct ReduceResult<And, T> {
  using type = bool;
};

template <typename T>
struct ReduceResult<Or, T> {
  using type = bool;
};

// Integral inputs of at most 4 bytes (including bool) accumulate in int32_t;
// every other type accumulates in T itself.
template <typename T>
struct ReduceResult<Sum, T> {
  using type = cuda::std::conditional_t<
      (cuda::std::is_integral_v<T> && sizeof(T) <= 4),
      int32_t,
      T>;
};

// Same widening rule as Sum.
template <typename T>
struct ReduceResult<Prod, T> {
  using type = cuda::std::conditional_t<
      (cuda::std::is_integral_v<T> && sizeof(T) <= 4),
      int32_t,
      T>;
};

template <typename T>
struct ReduceResult<Min, T> {
  using type = T;
};

template <typename T>
struct ReduceResult<Max, T> {
  using type = T;
};

// Traits to get the init (identity) value of reduce op `Op` for element type
// `T`.
template <typename Op, typename T>
struct ReduceInit;

template <typename T>
struct ReduceInit<And, T> {
  static constexpr __host__ __device__ bool value() {
    return true;
  }
};

template <typename T>
struct ReduceInit<Or, T> {
  static constexpr __host__ __device__ bool value() {
    return false;
  }
};

// Zero of the accumulator type; complex sums start at 0 + 0i.
template <typename T>
struct ReduceInit<Sum, T> {
  static constexpr __host__ __device__ auto value() {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return T{0, 0};
    } else {
      return typename ReduceResult<Sum, T>::type{0};
    }
  }
};

// One of the accumulator type; complex products start at 1 + 0i.
template <typename T>
struct ReduceInit<Prod, T> {
  static constexpr __host__ __device__ auto value() {
    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
      return T{1, 0};
    } else {
      return typename ReduceResult<Prod, T>::type{1};
    }
  }
};

// Largest value per the project's Limits<T> (presumably +inf for floating
// types — confirm in device/utils.cuh).
template <typename T>
struct ReduceInit<Min, T> {
  static constexpr __host__ __device__ T value() {
    return Limits<T>::max();
  }
};

// Lowest value per the project's Limits<T> (presumably -inf for floating
// types — confirm in device/utils.cuh).
template <typename T>
struct ReduceInit<Max, T> {
  static constexpr __host__ __device__ T value() {
    return Limits<T>::min();
  }
};

} // namespace mlx::core::cu