[CUDA] Implement Scan kernel (#2347)

* Contiguous scan

* Strided scan

* Enable tests

* Fix failing logaddexp test

* Use cexpf in Metal
This commit is contained in:
Cheng
2025-07-11 08:54:12 +09:00
committed by GitHub
parent b6eec20260
commit 8347575ba1
13 changed files with 815 additions and 64 deletions

View File

@@ -2,6 +2,8 @@
#pragma once
#include "mlx/backend/cuda/device/cexpf.cuh"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
@@ -150,8 +152,7 @@ struct Exp {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto m = exp(cuCrealf(x));
return {m * cos(cuCimagf(x)), m * sinh(cuCimagf(x))};
return detail::cexpf(x);
} else {
return exp(x);
}
@@ -228,8 +229,25 @@ struct Log10 {
struct Log1p {
template <typename T>
__device__ T operator()(T x) {
return log1p(x);
__device__ T operator()(T z) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
float x = cuCrealf(z);
float y = cuCimagf(z);
float zabs = cuCrealf(Abs{}(z));
float theta = atan2f(y, x + 1);
if (zabs < 0.5f) {
float r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {0.5f * log1pf(r), theta};
} else {
float z0 = hypotf(x + 1, y);
return {logf(z0), theta};
}
} else {
return log1p(z);
}
}
};
@@ -387,19 +405,19 @@ struct Tanh {
}
};
__device__ cuComplex ArcCos::operator()(cuComplex x) {
inline __device__ cuComplex ArcCos::operator()(cuComplex x) {
auto i = cuComplex{0.0, 1.0};
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
return {cuCimagf(y), -cuCrealf(y)};
};
__device__ cuComplex ArcSin::operator()(cuComplex x) {
inline __device__ cuComplex ArcSin::operator()(cuComplex x) {
auto i = cuComplex{0.0f, 1.0f};
auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
return {cuCimagf(y), -cuCrealf(y)};
};
__device__ cuComplex ArcTan::operator()(cuComplex x) {
inline __device__ cuComplex ArcTan::operator()(cuComplex x) {
auto i = cuComplex{0.0f, 1.0f};
auto ix = i * x;
return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));