// Copyright © 2025 Apple Inc. #include #include "mlx/backend/common/reduce.h" #include "mlx/backend/cuda/device/cucomplex_math.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/reduce/reduce_ops.cuh" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" namespace mlx::core { template void dispatch_reduce_ndim(int ndim, F&& f) { if (ndim == 1) { f(std::integral_constant{}); } else if (ndim == 2) { f(std::integral_constant{}); } else { f(std::integral_constant{}); } } template void dispatch_reduce_ops(Reduce::ReduceType reduce_type, F&& f) { if (reduce_type == Reduce::ReduceType::And) { f(type_identity{}); } else if (reduce_type == Reduce::ReduceType::Or) { f(type_identity{}); } else if (reduce_type == Reduce::ReduceType::Sum) { f(type_identity{}); } else if (reduce_type == Reduce::ReduceType::Prod) { f(type_identity{}); } else if (reduce_type == Reduce::ReduceType::Max) { f(type_identity{}); } else if (reduce_type == Reduce::ReduceType::Min) { f(type_identity{}); } else { throw std::invalid_argument("Unknown reduce type."); } } void all_reduce( cu::CommandEncoder& encoder, const array& in, array& out, Reduce::ReduceType reduce_type); void row_reduce( cu::CommandEncoder& encoder, const array& in, array& out, Reduce::ReduceType reduce_type, const std::vector& axes, const ReductionPlan& plan); void col_reduce( cu::CommandEncoder& encoder, const array& in, array& out, Reduce::ReduceType reduce_type, const std::vector& axes, const ReductionPlan& plan); void init_reduce( cu::CommandEncoder& encoder, const array& in, array& out, Reduce::ReduceType reduce_type); } // namespace mlx::core