mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
[CUDA] Implement Scan kernel (#2347)
* Contiguous scan * Strided scan * Enable tests * Fix failing logaddexp test * Use cexpf in Metal
This commit is contained in:
@@ -2,6 +2,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mlx/backend/cuda/device/cexpf.cuh"
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
@@ -150,8 +152,7 @@ struct Exp {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
auto m = exp(cuCrealf(x));
|
||||
return {m * cos(cuCimagf(x)), m * sinh(cuCimagf(x))};
|
||||
return detail::cexpf(x);
|
||||
} else {
|
||||
return exp(x);
|
||||
}
|
||||
@@ -228,8 +229,25 @@ struct Log10 {
|
||||
|
||||
struct Log1p {
|
||||
template <typename T>
|
||||
__device__ T operator()(T x) {
|
||||
return log1p(x);
|
||||
__device__ T operator()(T z) {
|
||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||
float x = cuCrealf(z);
|
||||
float y = cuCimagf(z);
|
||||
float zabs = cuCrealf(Abs{}(z));
|
||||
float theta = atan2f(y, x + 1);
|
||||
if (zabs < 0.5f) {
|
||||
float r = x * (2 + x) + y * y;
|
||||
if (r == 0) { // handle underflow
|
||||
return {x, theta};
|
||||
}
|
||||
return {0.5f * log1pf(r), theta};
|
||||
} else {
|
||||
float z0 = hypotf(x + 1, y);
|
||||
return {logf(z0), theta};
|
||||
}
|
||||
} else {
|
||||
return log1p(z);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -387,19 +405,19 @@ struct Tanh {
|
||||
}
|
||||
};
|
||||
|
||||
__device__ cuComplex ArcCos::operator()(cuComplex x) {
|
||||
inline __device__ cuComplex ArcCos::operator()(cuComplex x) {
|
||||
auto i = cuComplex{0.0, 1.0};
|
||||
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
|
||||
return {cuCimagf(y), -cuCrealf(y)};
|
||||
};
|
||||
|
||||
__device__ cuComplex ArcSin::operator()(cuComplex x) {
|
||||
inline __device__ cuComplex ArcSin::operator()(cuComplex x) {
|
||||
auto i = cuComplex{0.0f, 1.0f};
|
||||
auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
|
||||
return {cuCimagf(y), -cuCrealf(y)};
|
||||
};
|
||||
|
||||
__device__ cuComplex ArcTan::operator()(cuComplex x) {
|
||||
inline __device__ cuComplex ArcTan::operator()(cuComplex x) {
|
||||
auto i = cuComplex{0.0f, 1.0f};
|
||||
auto ix = i * x;
|
||||
return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));
|
||||
|
||||
Reference in New Issue
Block a user