[CUDA] Fix back-end bugs and enable corresponding tests (#2296)

* Fix some cuda back-end bugs and enable corresponding tests

* more fixes

* enable more tests

* format
This commit is contained in:
Awni Hannun
2025-06-16 08:45:40 -07:00
committed by GitHub
parent 4fda5fbdf9
commit c552ff2451
16 changed files with 115 additions and 98 deletions

View File

@@ -45,6 +45,18 @@ struct CastOp<
}
};
template <typename SrcT, typename DstT>
struct CastOp<
SrcT,
DstT,
cuda::std::enable_if_t<cuda::std::is_same_v<SrcT, DstT>>> {
static constexpr bool is_castable = true;
__device__ SrcT operator()(SrcT x) {
return x;
}
};
// Return an iterator that cast the value to DstT using CastOp.
template <typename DstT, typename Iterator>
__host__ __device__ auto make_cast_iterator(Iterator it) {

View File

@@ -5,6 +5,8 @@
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
#include <math_constants.h>
namespace mlx::core::cu {
struct Abs {
@@ -183,21 +185,38 @@ struct Imag {
struct Log {
template <typename T>
__device__ T operator()(T x) {
return log(x);
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto r = log(cuCrealf(Abs{}(x)));
auto i = atan2f(cuCimagf(x), cuCrealf(x));
return {r, i};
} else {
return log(x);
}
}
};
struct Log2 {
template <typename T>
__device__ T operator()(T x) {
return log2(x);
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto y = Log{}(x);
return {cuCrealf(y) / CUDART_LN2_F, cuCimagf(y) / CUDART_LN2_F};
} else {
return log2(x);
}
}
};
struct Log10 {
template <typename T>
__device__ T operator()(T x) {
return log10(x);
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto y = Log{}(x);
return {cuCrealf(y) / CUDART_LNT_F, cuCimagf(y) / CUDART_LNT_F};
return y;
} else {
return log10(x);
}
}
};