diff --git a/mlx/backend/cuda/device/binary_ops.cuh b/mlx/backend/cuda/device/binary_ops.cuh
index dc4f8e7bb..644786a92 100644
--- a/mlx/backend/cuda/device/binary_ops.cuh
+++ b/mlx/backend/cuda/device/binary_ops.cuh
@@ -1,10 +1,7 @@
 // Copyright © 2025 Apple Inc.
 
-#include "mlx/backend/cuda/device/cucomplex_math.cuh"
-#include "mlx/backend/cuda/device/fp16_math.cuh"
-#include "mlx/backend/cuda/device/utils.cuh"
+#include "mlx/backend/cuda/device/unary_ops.cuh"
 
-#include
 #include
 
 namespace mlx::core::cu {
@@ -114,36 +111,43 @@ struct LessEqual {
 struct LogAddExp {
   template <typename T>
   __device__ T operator()(T x, T y) {
-    if (isnan(x) || isnan(y)) {
-      return cuda::std::numeric_limits<T>::quiet_NaN();
+    // Computes log(exp(x) + exp(y)) as max + log1p(exp(min - max)).
+    if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
+          isnan(cuCimagf(y))) {
+        return {
+            cuda::std::numeric_limits<float>::quiet_NaN(),
+            cuda::std::numeric_limits<float>::quiet_NaN()};
+      }
+      // Order the operands by real part. `>=` (not `>`) keeps the two picks
+      // complementary: with `>` and `<`, equal real parts made both `max`
+      // and `min` equal to `y`, so `x` never reached the result.
+      auto max = cuCrealf(x) >= cuCrealf(y) ? x : y;
+      auto min = cuCrealf(x) < cuCrealf(y) ? x : y;
+      auto min_real = cuCrealf(min);
+      auto max_real = cuCrealf(max);
+      if (!isfinite(min_real) && (min_real == max_real)) {
+        // Both real parts are the same infinity, so min - max is unusable.
+        if (min_real < 0) {
+          return min;
+        } else {
+          return Log{}(Exp{}(min) + Exp{}(max));
+        }
+      } else {
+        return Log1p{}(Exp{}(min - max)) + max;
+      }
+    } else {
+      if (isnan(x) || isnan(y)) {
+        return cuda::std::numeric_limits<T>::quiet_NaN();
+      }
+      T maxval = max(x, y);
+      T minval = min(x, y);
+      return (minval == -cuda::std::numeric_limits<T>::infinity() ||
+              maxval == cuda::std::numeric_limits<T>::infinity())
+          ? maxval
+          : T(float(maxval) + log1p(expf(minval - maxval)));
     }
-    T maxval = max(x, y);
-    T minval = min(x, y);
-    return (minval == -cuda::std::numeric_limits<T>::infinity() ||
-            maxval == cuda::std::numeric_limits<T>::infinity())
-        ? maxval
-        : T(float(maxval) + log1p(expf(minval - maxval)));
   };
-
-  __device__ cuComplex operator()(cuComplex x, cuComplex y) {
-    if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
-        isnan(cuCimagf(y))) {
-      return {
-          cuda::std::numeric_limits<float>::quiet_NaN(),
-          cuda::std::numeric_limits<float>::quiet_NaN()};
-    }
-    float inf = cuda::std::numeric_limits<float>::infinity();
-    auto maxval = x > y ? x : y;
-    auto minval = x < y ? x : y;
-    if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
-      return maxval;
-    float m = exp(cuCrealf(minval) - cuCrealf(maxval));
-    cuComplex dexp{
-        m * cos(cuCimagf(minval) - cuCimagf(maxval)),
-        m * sin(cuCimagf(minval) - cuCimagf(maxval)),
-    };
-    return maxval + log1p(dexp);
-  }
 };
 
 struct Maximum {
diff --git a/mlx/backend/cuda/device/cexpf.cuh b/mlx/backend/cuda/device/cexpf.cuh
new file mode 100644
index 000000000..61c94c00f
--- /dev/null
+++ b/mlx/backend/cuda/device/cexpf.cuh
@@ -0,0 +1,138 @@
+// Copyright © 2025 Apple Inc.
+// Copyright © 2008-2013 NVIDIA Corporation
+// Copyright © 2013 Filipe RNC Maia
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Forked from
+// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h
+
+// TODO: We should use thrust::exp but the thrust header in old CUDA versions
+// can not be used in JIT.
+
+#pragma once
+
+#include
+#include
+
+namespace mlx::core::cu::detail {
+
+using ieee_float_shape_type = union { // overlays a float with a 32-bit word
+  float value;
+  uint32_t word;
+};
+
+inline __device__ void get_float_word(uint32_t& i, float d) { // i = bits of d
+  ieee_float_shape_type gf_u;
+  gf_u.value = (d);
+  (i) = gf_u.word;
+}
+
+inline __device__ void get_float_word(int32_t& i, float d) { // signed-destination overload
+  ieee_float_shape_type gf_u;
+  gf_u.value = (d);
+  (i) = gf_u.word;
+}
+
+inline __device__ void set_float_word(float& d, uint32_t i) { // d = float with bit pattern i
+  ieee_float_shape_type sf_u;
+  sf_u.word = (i);
+  (d) = sf_u.value;
+}
+
+inline __device__ float frexp_expf(float x, int* expt) { // returns a rescaled expf(x - kln2); exponent goes to *expt
+  const uint32_t k = 235; // offset added back into *expt below
+  const float kln2 = 162.88958740F; // ~= k * ln(2)
+
+  float exp_x;
+  uint32_t hx;
+
+  exp_x = expf(x - kln2);
+  get_float_word(hx, exp_x);
+  *expt = (hx >> 23) - (0x7f + 127) + k; // bits 23 and up of exp_x, rebased by (0x7f + 127) and k
+  set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23)); // keep the low 23 bits, force the exponent field to 0x7f + 127
+  return exp_x;
+}
+
+inline __device__ cuComplex ldexp_cexpf(cuComplex z, int expt) { // used by cexpf for the large-x range
+  float x, y, exp_x, scale1, scale2;
+  int ex_expt, half_expt;
+
+  x = cuCrealf(z);
+  y = cuCimagf(z);
+  exp_x = frexp_expf(x, &ex_expt);
+  expt += ex_expt;
+
+  half_expt = expt / 2; // the exponent is applied in two pieces: scale1, then scale2
+  set_float_word(scale1, (0x7f + half_expt) << 23); // exponent field 0x7f + half_expt, low 23 bits zero
+  half_expt = expt - half_expt; // remainder of the exponent
+  set_float_word(scale2, (0x7f + half_expt) << 23);
+
+  return cuComplex{
+      cosf(y) * exp_x * scale1 * scale2, sinf(y) * exp_x * scale1 * scale2};
+}
+
+inline __device__ cuComplex cexpf(const cuComplex& z) { // complex exponential of z; special cases are listed inline
+  float x, y, exp_x;
+  uint32_t hx, hy;
+
+  const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074; // bit-pattern bounds for the scaled branch below
+
+  x = cuCrealf(z);
+  y = cuCimagf(z);
+
+  get_float_word(hy, y);
+  hy &= 0x7fffffff; // clear the sign bit of y
+
+  /* cexp(x + I 0) = exp(x) + I 0 */
+  if (hy == 0) {
+    return cuComplex{expf(x), y};
+  }
+  get_float_word(hx, x);
+  /* cexp(0 + I y) = cos(y) + I sin(y) */
+  if ((hx & 0x7fffffff) == 0) {
+    return cuComplex{cosf(y), sinf(y)};
+  }
+  if (hy >= 0x7f800000) { // y is Inf or NaN
+    if ((hx & 0x7fffffff) != 0x7f800000) {
+      /* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */
+      return cuComplex{y - y, y - y};
+    } else if (hx & 0x80000000) {
+      /* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */
+      return cuComplex{0.0, 0.0};
+    } else {
+      /* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */
+      return cuComplex{x, y - y};
+    }
+  }
+
+  if (hx >= exp_ovfl && hx <= cexp_ovfl) {
+    /*
+     * x is between 88.7 and 192, so we must scale to avoid
+     * overflow in expf(x).
+     */
+    return ldexp_cexpf(z, 0);
+  } else {
+    /*
+     * Cases covered here:
+     *  - x < exp_ovfl and exp(x) won't overflow (common case)
+     *  - x > cexp_ovfl, so exp(x) * s overflows for all s > 0
+     *  - x = +-Inf (generated by exp())
+     *  - x = NaN (spurious inexact exception from y)
+     */
+    exp_x = expf(x);
+    return cuComplex{exp_x * cosf(y), exp_x * sinf(y)};
+  }
+}
+
+} // namespace mlx::core::cu::detail
diff --git a/mlx/backend/cuda/device/unary_ops.cuh b/mlx/backend/cuda/device/unary_ops.cuh
index 18d769c2a..8716d3a8c 100644
--- a/mlx/backend/cuda/device/unary_ops.cuh
+++ b/mlx/backend/cuda/device/unary_ops.cuh
@@ -2,6 +2,8 @@
 
 #pragma once
 
+#include "mlx/backend/cuda/device/cexpf.cuh"
+#include "mlx/backend/cuda/device/cucomplex_math.cuh"
 #include "mlx/backend/cuda/device/fp16_math.cuh"
 #include "mlx/backend/cuda/device/utils.cuh"
 
@@ -150,8 +152,7 @@ struct Exp {
   template
   __device__ T operator()(T x) {
     if constexpr (cuda::std::is_same_v) {
-      auto m = exp(cuCrealf(x));
-      return {m * cos(cuCimagf(x)), m * sinh(cuCimagf(x))};
+      return detail::cexpf(x); // complex exp now goes through detail::cexpf (cexpf.cuh)
     } else {
       return exp(x);
     }
@@ -228,8 +229,25 @@ struct Log10 {
 
 struct Log1p {
   template
-  __device__ T operator()(T x) {
-    return log1p(x);
+  __device__ T operator()(T z) { // complex branch replaces the log1p(cuComplex) helper removed from utils.cuh
+    if constexpr (cuda::std::is_same_v) {
+      float x = cuCrealf(z);
+      float y = cuCimagf(z);
+      float zabs = cuCrealf(Abs{}(z)); // real part of Abs{}(z); NOTE(review): presumably |z| - confirm against Abs
+      float theta = atan2f(y, x + 1); // angle of (x + 1, y), i.e. of 1 + z
+      if (zabs < 0.5f) {
+        float r = x * (2 + x) + y * y; // equals (x + 1)^2 + y^2 - 1
+        if (r == 0) { // handle underflow
+          return {x, theta};
+        }
+        return {0.5f * log1pf(r), theta};
+      } else {
+        float z0 = hypotf(x + 1, y); // magnitude of 1 + z
+        return {logf(z0), theta};
+      }
+    } else {
+      return log1p(z); // non-complex types use the scalar log1p
+    }
   }
 };
 
@@ -387,19 +405,19 @@ struct Tanh {
   }
 };
 
-__device__ cuComplex ArcCos::operator()(cuComplex x) {
+inline __device__ cuComplex ArcCos::operator()(cuComplex x) { // out-of-class definition in a header, now marked inline
   auto i = cuComplex{0.0, 1.0};
   auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
   return {cuCimagf(y), -cuCrealf(y)};
 };
 
-__device__ cuComplex ArcSin::operator()(cuComplex x) {
+inline __device__ cuComplex ArcSin::operator()(cuComplex x) { // same inline marking as ArcCos
   auto i = cuComplex{0.0f, 1.0f};
   auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
   return {cuCimagf(y), -cuCrealf(y)};
 };
 
-__device__ cuComplex ArcTan::operator()(cuComplex x) {
+inline __device__ cuComplex ArcTan::operator()(cuComplex x) { // same inline marking as ArcCos
   auto i = cuComplex{0.0f, 1.0f};
   auto ix = i * x;
   return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));
diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh
index 83e149165..af022c141 100644
--- a/mlx/backend/cuda/device/utils.cuh
+++ b/mlx/backend/cuda/device/utils.cuh
@@ -359,21 +359,4 @@ struct LoopedElemToLoc<1, false, OffsetT> {
   }
 };
 
-inline __device__ cuComplex log1p(cuComplex in) {
-  float x = cuCrealf(in);
-  float y = cuCimagf(in);
-  float zabs = sqrt(x * x + y * y);
-  float theta = atan2f(y, x + 1);
-  if (zabs < 0.5f) {
-    float r = x * (2 + x) + y * y;
-    if (r == 0) { // handle underflow
-      return {x, theta};
-    }
-    return {0.5f * log1pf(r), theta};
-  } else {
-    auto z0 = sqrt((x + 1) * (x + 1) + y * y);
-    return {log(z0), theta};
-  }
-}
-
 } // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp
index e6dbd35da..834e4a3d1 100644
--- a/mlx/backend/cuda/jit_module.cpp
+++ b/mlx/backend/cuda/jit_module.cpp
@@ -161,6 +161,7 @@ constexpr const char* g_include_names[] = {
     INCLUDE_PREFIX "atomic_ops.cuh",
     INCLUDE_PREFIX "binary_ops.cuh",
     INCLUDE_PREFIX "cast_op.cuh",
+    INCLUDE_PREFIX "cexpf.cuh", // new header; its source entry is added to g_headers in the next hunk
     INCLUDE_PREFIX "config.h",
     INCLUDE_PREFIX "cucomplex_math.cuh",
     INCLUDE_PREFIX "fp16_math.cuh",
@@ -177,6 +178,7 @@ constexpr const char* g_headers[] = {
     jit_source_atomic_ops,
     jit_source_binary_ops,
     jit_source_cast_op,
+    jit_source_cexpf, // NOTE(review): placed at the same position as "cexpf.cuh" in g_include_names - presumably the arrays are index-matched; confirm
     jit_source_config,
     jit_source_cucomplex_math,
     jit_source_fp16_math,
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index 1a9781c7c..969bc2ba7 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -1350,6 +1350,11 @@ TEST_CASE("test arithmetic unary ops") {
     x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0];
     auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1});
     CHECK(allclose(exp(x), expected).item());
+
+    // Complex of -inf
+    constexpr float inf = std::numeric_limits::infinity();
+    x = array(complex64_t{-inf, -inf});
+    CHECK_EQ(exp(x).item(), complex64_t{0, 0}); // exp(-inf - inf*i) is expected to be exactly 0 + 0i
   }
 
   // Test expm1
@@ -1830,6 +1835,10 @@ TEST_CASE("test arithmetic binary ops") {
   x = array(-inf);
   y = array(inf);
   CHECK_EQ(logaddexp(x, y).item(), inf);
+
+  x = array(complex64_t{1, 1});
+  y = array(complex64_t{-inf, -inf});
+  CHECK_EQ(logaddexp(x, y).item(), complex64_t{1, 1}); // a -inf operand must leave the other operand unchanged
 }
 
 TEST_CASE("test broadcast") {