// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/device/atomic_ops.cuh"
#include "mlx/backend/cuda/device/cast_op.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/reduce/reduce_utils.cuh"

// NOTE(review): the template parameter/argument lists in this header were
// restored after being lost. Confirm the explicit arguments passed to
// `atomic_reduce`, `cast_to`, `Limits` and `numeric_limits` against the
// declarations in the headers included above.

namespace mlx::core::cu {

// Reduce ops.
//
// Every op is a stateless functor exposing:
//   * operator()(a, b)    - combines two values and returns the result.
//   * atomic_update(x, y) - folds `y` into the value stored at `x` by
//                           forwarding to `atomic_reduce` / `atomic_add`
//                           (not defined in this file; presumably provided
//                           by atomic_ops.cuh - confirm).

// Logical AND over bool.
struct And {
  // Returns `a && b`.
  __device__ __forceinline__ bool operator()(bool a, bool b) {
    return a && b;
  }

  // Folds `y` into `*x` using this op.
  __device__ void atomic_update(bool* x, bool y) {
    atomic_reduce<bool, And>(x, y);
  }
};

// Logical OR over bool.
struct Or {
  // Returns `a || b`.
  __device__ __forceinline__ bool operator()(bool a, bool b) {
    return a || b;
  }

  // Folds `y` into `*x` using this op.
  __device__ void atomic_update(bool* x, bool y) {
    atomic_reduce<bool, Or>(x, y);
  }
};

// Addition.
struct Sum {
  // Returns `a + b`.
  template <typename T>
  __device__ __forceinline__ T operator()(T a, T b) {
    return a + b;
  }

  // Generic path: folds `y` into `*x` through `atomic_reduce`.
  template <typename T>
  __device__ void atomic_update(T* x, T y) {
    atomic_reduce<T, Sum>(x, y);
  }

  // The non-template overloads below are preferred by overload resolution
  // for exactly these element types and forward to `atomic_add` instead.
  __device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {
    atomic_add(x, y);
  }

  __device__ void atomic_update(int* x, int y) {
    atomic_add(x, y);
  }

  __device__ void atomic_update(float* x, float y) {
    atomic_add(x, y);
  }
};

// Multiplication.
struct Prod {
  // Returns `a * b`.
  template <typename T>
  __device__ __forceinline__ T operator()(T a, T b) {
    return a * b;
  }

  // Folds `y` into `*x` using this op.
  template <typename T>
  __device__ void atomic_update(T* x, T y) {
    atomic_reduce<T, Prod>(x, y);
  }
};

// Minimum that propagates NaN.
struct Min {
  // Returns the smaller of `a` and `b` according to `operator<`, except:
  //   * complex T: the first operand (checking `a`, then `b`) whose real or
  //     imaginary part is NaN is returned unchanged;
  //   * other non-integral T: a quiet NaN is returned if either operand is
  //     NaN.
  template <typename T>
  __device__ __forceinline__ T operator()(T a, T b) {
    if constexpr (is_complex_v<T>) {
      if (isnan(a.real()) || isnan(a.imag())) {
        return a;
      }
      if (isnan(b.real()) || isnan(b.imag())) {
        return b;
      }
    } else if constexpr (!cuda::std::is_integral_v<T>) {
      if (isnan(a) || isnan(b)) {
        // NOTE(review): float NaN converted to T on return - confirm the
        // intended numeric_limits argument (float vs. T).
        return cuda::std::numeric_limits<float>::quiet_NaN();
      }
    }
    return a < b ? a : b;
  }

  // Folds `y` into `*x` using this op.
  template <typename T>
  __device__ void atomic_update(T* x, T y) {
    atomic_reduce<T, Min>(x, y);
  }
};

// Maximum that propagates NaN.
struct Max {
  // Returns the larger of `a` and `b` according to `operator>`, with the
  // same NaN handling as `Min::operator()`.
  template <typename T>
  __device__ __forceinline__ T operator()(T a, T b) {
    if constexpr (is_complex_v<T>) {
      if (isnan(a.real()) || isnan(a.imag())) {
        return a;
      }
      if (isnan(b.real()) || isnan(b.imag())) {
        return b;
      }
    } else if constexpr (!cuda::std::is_integral_v<T>) {
      if (isnan(a) || isnan(b)) {
        // NOTE(review): float NaN converted to T on return - confirm the
        // intended numeric_limits argument (float vs. T).
        return cuda::std::numeric_limits<float>::quiet_NaN();
      }
    }
    return a > b ? a : b;
  }

  // Folds `y` into `*x` using this op.
  template <typename T>
  __device__ void atomic_update(T* x, T y) {
    atomic_reduce<T, Max>(x, y);
  }
};

// Traits to get the result type of reduce op.
//
// ReduceResult<Op, T>::type is the type produced when reducing elements of
// type T with Op.
template <typename Op, typename T>
struct ReduceResult;

// And / Or always produce bool.
template <typename T>
struct ReduceResult<And, T> {
  using type = bool;
};

template <typename T>
struct ReduceResult<Or, T> {
  using type = bool;
};

// Sum / Prod: integral inputs of at most 4 bytes produce int32_t; every
// other input type produces T itself.
template <typename T>
struct ReduceResult<Sum, T> {
  using type = cuda::std::conditional_t<
      (cuda::std::is_integral_v<T> && sizeof(T) <= 4),
      int32_t,
      T>;
};

template <typename T>
struct ReduceResult<Prod, T> {
  using type = cuda::std::conditional_t<
      (cuda::std::is_integral_v<T> && sizeof(T) <= 4),
      int32_t,
      T>;
};

// Min / Max keep the input type.
template <typename T>
struct ReduceResult<Min, T> {
  using type = T;
};

template <typename T>
struct ReduceResult<Max, T> {
  using type = T;
};

// Traits to get the init value of reduce op.
//
// ReduceInit<Op, T>::value() is the value a reduction with Op starts from.
template <typename Op, typename T>
struct ReduceInit;

// And starts from true.
template <typename T>
struct ReduceInit<And, T> {
  static constexpr __host__ __device__ bool value() {
    return true;
  }
};

// Or starts from false.
template <typename T>
struct ReduceInit<Or, T> {
  static constexpr __host__ __device__ bool value() {
    return false;
  }
};

// Sum starts from zero: T{0, 0} for complex T, otherwise 0 cast to the
// ReduceResult type.
template <typename T>
struct ReduceInit<Sum, T> {
  static constexpr __host__ __device__ auto value() {
    if constexpr (is_complex_v<T>) {
      return T{0, 0};
    } else {
      return cast_to<typename ReduceResult<Sum, T>::type>(0);
    }
  }
};

// Prod starts from one: T{1, 0} for complex T, otherwise 1 cast to the
// ReduceResult type.
template <typename T>
struct ReduceInit<Prod, T> {
  static constexpr __host__ __device__ auto value() {
    if constexpr (is_complex_v<T>) {
      return T{1, 0};
    } else {
      return cast_to<typename ReduceResult<Prod, T>::type>(1);
    }
  }
};

// Min starts from Limits<T>::max().
template <typename T>
struct ReduceInit<Min, T> {
  static constexpr __host__ __device__ T value() {
    return Limits<T>::max();
  }
};

// Max starts from Limits<T>::min().
// NOTE(review): `Limits` is not defined in this file - confirm that
// Limits<T>::min() is the lowest representable value for floating-point T.
template <typename T>
struct ReduceInit<Max, T> {
  static constexpr __host__ __device__ T value() {
    return Limits<T>::min();
  }
};

} // namespace mlx::core::cu