[CUDA] Fix reductions (#2314)

This commit is contained in:
Angelos Katharopoulos
2025-06-27 12:59:20 -07:00
committed by GitHub
parent 2c11d10f8d
commit 772f471ff2
16 changed files with 862 additions and 419 deletions

View File

@@ -3,48 +3,89 @@
#pragma once
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/reduce/reduce_utils.cuh"
namespace mlx::core::cu {
// Reduce ops.
struct And {
__device__ bool operator()(bool a, bool b) {
__device__ __forceinline__ bool operator()(bool a, bool b) {
return a && b;
}
__device__ void atomic_update(bool* x, bool y) {
atomic_reduce<bool, And>(x, y);
}
};
struct Or {
__device__ bool operator()(bool a, bool b) {
__device__ __forceinline__ bool operator()(bool a, bool b) {
return a || b;
}
__device__ void atomic_update(bool* x, bool y) {
atomic_reduce<bool, Or>(x, y);
}
};
struct Sum {
template <typename T>
__device__ T operator()(T a, T b) {
__device__ __forceinline__ T operator()(T a, T b) {
return a + b;
}
template <typename T>
__device__ void atomic_update(T* x, T y) {
atomic_reduce<T, Sum>(x, y);
}
__device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {
atomicAdd(x, y);
}
__device__ void atomic_update(int* x, int y) {
atomicAdd(x, y);
}
__device__ void atomic_update(float* x, float y) {
atomicAdd(x, y);
}
};
struct Prod {
template <typename T>
__device__ T operator()(T a, T b) {
__device__ __forceinline__ T operator()(T a, T b) {
return a * b;
}
template <typename T>
__device__ void atomic_update(T* x, T y) {
atomic_reduce<T, Prod>(x, y);
}
};
struct Min {
template <typename T>
__device__ T operator()(T a, T b) {
__device__ __forceinline__ T operator()(T a, T b) {
return a < b ? a : b;
}
template <typename T>
__device__ void atomic_update(T* x, T y) {
atomic_reduce<T, Min>(x, y);
}
};
struct Max {
template <typename T>
__device__ T operator()(T a, T b) {
__device__ __forceinline__ T operator()(T a, T b) {
return a > b ? a : b;
}
template <typename T>
__device__ void atomic_update(T* x, T y) {
atomic_reduce<T, Max>(x, y);
}
};
// Traits to get the result type of reduce op.
@@ -120,7 +161,7 @@ template <typename T>
struct ReduceInit<Prod, T> {
static constexpr __host__ __device__ auto value() {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return T{1, 1};
return T{1, 0};
} else {
return typename ReduceResult<Prod, T>::type{1};
}