[CUDA] Implement Scan kernel (#2347)

* Contiguous scan

* Strided scan

* Enable tests

* Fix failing logaddexp test

* Use cexpf in Metal
This commit is contained in:
Cheng
2025-07-11 08:54:12 +09:00
committed by GitHub
parent b6eec20260
commit 8347575ba1
13 changed files with 815 additions and 64 deletions

View File

@@ -1,10 +1,7 @@
// Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
#include "mlx/backend/cuda/device/unary_ops.cuh"
#include <cuComplex.h>
#include <cuda/std/array>
namespace mlx::core::cu {
@@ -114,36 +111,38 @@ struct LessEqual {
struct LogAddExp {
template <typename T>
__device__ T operator()(T x, T y) {
if (isnan(x) || isnan(y)) {
return cuda::std::numeric_limits<T>::quiet_NaN();
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
isnan(cuCimagf(y))) {
return {
cuda::std::numeric_limits<float>::quiet_NaN(),
cuda::std::numeric_limits<float>::quiet_NaN()};
}
auto max = cuCrealf(x) > cuCrealf(y) ? x : y;
auto min = cuCrealf(x) < cuCrealf(y) ? x : y;
auto min_real = cuCrealf(min);
auto max_real = cuCrealf(max);
if (!isfinite(min_real) && (min_real == max_real)) {
if (min_real < 0) {
return min;
} else {
return Log{}(Exp{}(min) + Exp{}(max));
}
} else {
return Log1p{}(Exp{}(min - max)) + max;
}
} else {
if (isnan(x) || isnan(y)) {
return cuda::std::numeric_limits<T>::quiet_NaN();
}
T maxval = max(x, y);
T minval = min(x, y);
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
maxval == cuda::std::numeric_limits<T>::infinity())
? maxval
: T(float(maxval) + log1p(expf(minval - maxval)));
}
T maxval = max(x, y);
T minval = min(x, y);
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
maxval == cuda::std::numeric_limits<T>::infinity())
? maxval
: T(float(maxval) + log1p(expf(minval - maxval)));
};
__device__ cuComplex operator()(cuComplex x, cuComplex y) {
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
isnan(cuCimagf(y))) {
return {
cuda::std::numeric_limits<float>::quiet_NaN(),
cuda::std::numeric_limits<float>::quiet_NaN()};
}
float inf = cuda::std::numeric_limits<float>::infinity();
auto maxval = x > y ? x : y;
auto minval = x < y ? x : y;
if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
return maxval;
float m = exp(cuCrealf(minval) - cuCrealf(maxval));
cuComplex dexp{
m * cos(cuCimagf(minval) - cuCimagf(maxval)),
m * sin(cuCimagf(minval) - cuCimagf(maxval)),
};
return maxval + log1p(dexp);
}
};
struct Maximum {

View File

@@ -0,0 +1,138 @@
// Copyright © 2025 Apple Inc.
// Copyright © 2008-2013 NVIDIA Corporation
// Copyright © 2013 Filipe RNC Maia
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Forked from
// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h
// TODO: We should use thrust::exp but the thrust header in old CUDA versions
// can not be used in JIT.
#pragma once
#include <cuComplex.h>
#include <cuda/std/cstdint>
namespace mlx::core::cu::detail {
using ieee_float_shape_type = union {
float value;
uint32_t word;
};
inline __device__ void get_float_word(uint32_t& i, float d) {
ieee_float_shape_type gf_u;
gf_u.value = (d);
(i) = gf_u.word;
}
inline __device__ void get_float_word(int32_t& i, float d) {
ieee_float_shape_type gf_u;
gf_u.value = (d);
(i) = gf_u.word;
}
inline __device__ void set_float_word(float& d, uint32_t i) {
ieee_float_shape_type sf_u;
sf_u.word = (i);
(d) = sf_u.value;
}
inline __device__ float frexp_expf(float x, int* expt) {
const uint32_t k = 235;
const float kln2 = 162.88958740F;
float exp_x;
uint32_t hx;
exp_x = expf(x - kln2);
get_float_word(hx, exp_x);
*expt = (hx >> 23) - (0x7f + 127) + k;
set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23));
return exp_x;
}
inline __device__ cuComplex ldexp_cexpf(cuComplex z, int expt) {
float x, y, exp_x, scale1, scale2;
int ex_expt, half_expt;
x = cuCrealf(z);
y = cuCimagf(z);
exp_x = frexp_expf(x, &ex_expt);
expt += ex_expt;
half_expt = expt / 2;
set_float_word(scale1, (0x7f + half_expt) << 23);
half_expt = expt - half_expt;
set_float_word(scale2, (0x7f + half_expt) << 23);
return cuComplex{
cosf(y) * exp_x * scale1 * scale2, sinf(y) * exp_x * scale1 * scale2};
}
inline __device__ cuComplex cexpf(const cuComplex& z) {
float x, y, exp_x;
uint32_t hx, hy;
const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074;
x = cuCrealf(z);
y = cuCimagf(z);
get_float_word(hy, y);
hy &= 0x7fffffff;
/* cexp(x + I 0) = exp(x) + I 0 */
if (hy == 0) {
return cuComplex{expf(x), y};
}
get_float_word(hx, x);
/* cexp(0 + I y) = cos(y) + I sin(y) */
if ((hx & 0x7fffffff) == 0) {
return cuComplex{cosf(y), sinf(y)};
}
if (hy >= 0x7f800000) {
if ((hx & 0x7fffffff) != 0x7f800000) {
/* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */
return cuComplex{y - y, y - y};
} else if (hx & 0x80000000) {
/* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */
return cuComplex{0.0, 0.0};
} else {
/* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */
return cuComplex{x, y - y};
}
}
if (hx >= exp_ovfl && hx <= cexp_ovfl) {
/*
* x is between 88.7 and 192, so we must scale to avoid
* overflow in expf(x).
*/
return ldexp_cexpf(z, 0);
} else {
/*
* Cases covered here:
* - x < exp_ovfl and exp(x) won't overflow (common case)
* - x > cexp_ovfl, so exp(x) * s overflows for all s > 0
* - x = +-Inf (generated by exp())
* - x = NaN (spurious inexact exception from y)
*/
exp_x = expf(x);
return cuComplex{exp_x * cosf(y), exp_x * sinf(y)};
}
}
} // namespace mlx::core::cu::detail

View File

@@ -2,6 +2,8 @@
#pragma once
#include "mlx/backend/cuda/device/cexpf.cuh"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
@@ -150,8 +152,7 @@ struct Exp {
template <typename T>
__device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto m = exp(cuCrealf(x));
return {m * cos(cuCimagf(x)), m * sinh(cuCimagf(x))};
return detail::cexpf(x);
} else {
return exp(x);
}
@@ -228,8 +229,25 @@ struct Log10 {
struct Log1p {
template <typename T>
__device__ T operator()(T x) {
return log1p(x);
__device__ T operator()(T z) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
float x = cuCrealf(z);
float y = cuCimagf(z);
float zabs = cuCrealf(Abs{}(z));
float theta = atan2f(y, x + 1);
if (zabs < 0.5f) {
float r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {0.5f * log1pf(r), theta};
} else {
float z0 = hypotf(x + 1, y);
return {logf(z0), theta};
}
} else {
return log1p(z);
}
}
};
@@ -387,19 +405,19 @@ struct Tanh {
}
};
__device__ cuComplex ArcCos::operator()(cuComplex x) {
inline __device__ cuComplex ArcCos::operator()(cuComplex x) {
auto i = cuComplex{0.0, 1.0};
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
return {cuCimagf(y), -cuCrealf(y)};
};
__device__ cuComplex ArcSin::operator()(cuComplex x) {
inline __device__ cuComplex ArcSin::operator()(cuComplex x) {
auto i = cuComplex{0.0f, 1.0f};
auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
return {cuCimagf(y), -cuCrealf(y)};
};
__device__ cuComplex ArcTan::operator()(cuComplex x) {
inline __device__ cuComplex ArcTan::operator()(cuComplex x) {
auto i = cuComplex{0.0f, 1.0f};
auto ix = i * x;
return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));

View File

@@ -359,21 +359,4 @@ struct LoopedElemToLoc<1, false, OffsetT> {
}
};
inline __device__ cuComplex log1p(cuComplex in) {
float x = cuCrealf(in);
float y = cuCimagf(in);
float zabs = sqrt(x * x + y * y);
float theta = atan2f(y, x + 1);
if (zabs < 0.5f) {
float r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {0.5f * log1pf(r), theta};
} else {
auto z0 = sqrt((x + 1) * (x + 1) + y * y);
return {log(z0), theta};
}
}
} // namespace mlx::core::cu