[CUDA] Fix reductions (#2314)

This commit is contained in:
Angelos Katharopoulos
2025-06-27 12:59:20 -07:00
committed by GitHub
parent 2c11d10f8d
commit 772f471ff2
16 changed files with 862 additions and 419 deletions

View File

@@ -47,13 +47,11 @@ namespace mlx::core {
throw std::invalid_argument("Unknown reduce type."); \
}
void segmented_reduce(
void all_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type,
const std::vector<int>& axes,
const ReductionPlan& plan);
Reduce::ReduceType reduce_type);
void row_reduce(
cu::CommandEncoder& encoder,
@@ -71,4 +69,10 @@ void col_reduce(
const std::vector<int>& axes,
const ReductionPlan& plan);
void init_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type);
} // namespace mlx::core