Fix failing logaddexp test

This commit is contained in:
Cheng
2025-07-10 02:13:24 +00:00
parent f797b1b3e5
commit 3564913327
6 changed files with 205 additions and 56 deletions

View File

@@ -1,10 +1,7 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include "mlx/backend/cuda/device/cucomplex_math.cuh" #include "mlx/backend/cuda/device/unary_ops.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh"
#include <cuComplex.h>
#include <cuda/std/array> #include <cuda/std/array>
namespace mlx::core::cu { namespace mlx::core::cu {
@@ -114,36 +111,38 @@ struct LessEqual {
struct LogAddExp { struct LogAddExp {
template <typename T> template <typename T>
__device__ T operator()(T x, T y) { __device__ T operator()(T x, T y) {
if (isnan(x) || isnan(y)) { if constexpr (cuda::std::is_same_v<T, cuComplex>) {
return cuda::std::numeric_limits<T>::quiet_NaN(); if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
isnan(cuCimagf(y))) {
return {
cuda::std::numeric_limits<float>::quiet_NaN(),
cuda::std::numeric_limits<float>::quiet_NaN()};
}
auto max = cuCrealf(x) > cuCrealf(y) ? x : y;
auto min = cuCrealf(x) < cuCrealf(y) ? x : y;
auto min_real = cuCrealf(min);
auto max_real = cuCrealf(max);
if (!isfinite(min_real) && (min_real == max_real)) {
if (min_real < 0) {
return min;
} else {
return Log{}(Exp{}(min) + Exp{}(max));
}
} else {
return Log1p{}(Exp{}(min - max)) + max;
}
} else {
if (isnan(x) || isnan(y)) {
return cuda::std::numeric_limits<T>::quiet_NaN();
}
T maxval = max(x, y);
T minval = min(x, y);
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
maxval == cuda::std::numeric_limits<T>::infinity())
? maxval
: T(float(maxval) + log1p(expf(minval - maxval)));
} }
T maxval = max(x, y);
T minval = min(x, y);
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
maxval == cuda::std::numeric_limits<T>::infinity())
? maxval
: T(float(maxval) + log1p(expf(minval - maxval)));
}; };
__device__ cuComplex operator()(cuComplex x, cuComplex y) {
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
isnan(cuCimagf(y))) {
return {
cuda::std::numeric_limits<float>::quiet_NaN(),
cuda::std::numeric_limits<float>::quiet_NaN()};
}
float inf = cuda::std::numeric_limits<float>::infinity();
auto maxval = x > y ? x : y;
auto minval = x < y ? x : y;
if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
return maxval;
float m = exp(cuCrealf(minval) - cuCrealf(maxval));
cuComplex dexp{
m * cos(cuCimagf(minval) - cuCimagf(maxval)),
m * sin(cuCimagf(minval) - cuCimagf(maxval)),
};
return maxval + log1p(dexp);
}
}; };
struct Maximum { struct Maximum {

View File

@@ -0,0 +1,138 @@
// Copyright © 2025 Apple Inc.
// Copyright © 2008-2013 NVIDIA Corporation
// Copyright © 2013 Filipe RNC Maia
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Forked from
// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h
// TODO: We should use thrust::exp but the thrust header in old CUDA versions
// can not be used in JIT.
#pragma once
#include <cuComplex.h>
#include <cuda/std/cstdint>
namespace mlx::core::cu::detail {
using ieee_float_shape_type = union {
float value;
uint32_t word;
};
inline __device__ void get_float_word(uint32_t& i, float d) {
ieee_float_shape_type gf_u;
gf_u.value = (d);
(i) = gf_u.word;
}
inline __device__ void get_float_word(int32_t& i, float d) {
ieee_float_shape_type gf_u;
gf_u.value = (d);
(i) = gf_u.word;
}
inline __device__ void set_float_word(float& d, uint32_t i) {
ieee_float_shape_type sf_u;
sf_u.word = (i);
(d) = sf_u.value;
}
inline __device__ float frexp_expf(float x, int* expt) {
const uint32_t k = 235;
const float kln2 = 162.88958740F;
float exp_x;
uint32_t hx;
exp_x = expf(x - kln2);
get_float_word(hx, exp_x);
*expt = (hx >> 23) - (0x7f + 127) + k;
set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23));
return exp_x;
}
inline __device__ cuComplex ldexp_cexpf(cuComplex z, int expt) {
float x, y, exp_x, scale1, scale2;
int ex_expt, half_expt;
x = cuCrealf(z);
y = cuCimagf(z);
exp_x = frexp_expf(x, &ex_expt);
expt += ex_expt;
half_expt = expt / 2;
set_float_word(scale1, (0x7f + half_expt) << 23);
half_expt = expt - half_expt;
set_float_word(scale2, (0x7f + half_expt) << 23);
return cuComplex{
cosf(y) * exp_x * scale1 * scale2, sinf(y) * exp_x * scale1 * scale2};
}
inline __device__ cuComplex cexpf(const cuComplex& z) {
float x, y, exp_x;
uint32_t hx, hy;
const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074;
x = cuCrealf(z);
y = cuCimagf(z);
get_float_word(hy, y);
hy &= 0x7fffffff;
/* cexp(x + I 0) = exp(x) + I 0 */
if (hy == 0) {
return cuComplex{expf(x), y};
}
get_float_word(hx, x);
/* cexp(0 + I y) = cos(y) + I sin(y) */
if ((hx & 0x7fffffff) == 0) {
return cuComplex{cosf(y), sinf(y)};
}
if (hy >= 0x7f800000) {
if ((hx & 0x7fffffff) != 0x7f800000) {
/* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */
return cuComplex{y - y, y - y};
} else if (hx & 0x80000000) {
/* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */
return cuComplex{0.0, 0.0};
} else {
/* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */
return cuComplex{x, y - y};
}
}
if (hx >= exp_ovfl && hx <= cexp_ovfl) {
/*
* x is between 88.7 and 192, so we must scale to avoid
* overflow in expf(x).
*/
return ldexp_cexpf(z, 0);
} else {
/*
* Cases covered here:
* - x < exp_ovfl and exp(x) won't overflow (common case)
* - x > cexp_ovfl, so exp(x) * s overflows for all s > 0
* - x = +-Inf (generated by exp())
* - x = NaN (spurious inexact exception from y)
*/
exp_x = expf(x);
return cuComplex{exp_x * cosf(y), exp_x * sinf(y)};
}
}
} // namespace mlx::core::cu::detail

View File

@@ -2,6 +2,8 @@
#pragma once #pragma once
#include "mlx/backend/cuda/device/cexpf.cuh"
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/device/fp16_math.cuh" #include "mlx/backend/cuda/device/fp16_math.cuh"
#include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/device/utils.cuh"
@@ -150,8 +152,7 @@ struct Exp {
template <typename T> template <typename T>
__device__ T operator()(T x) { __device__ T operator()(T x) {
if constexpr (cuda::std::is_same_v<T, cuComplex>) { if constexpr (cuda::std::is_same_v<T, cuComplex>) {
auto m = exp(cuCrealf(x)); return detail::cexpf(x);
return {m * cos(cuCimagf(x)), m * sinh(cuCimagf(x))};
} else { } else {
return exp(x); return exp(x);
} }
@@ -228,8 +229,25 @@ struct Log10 {
struct Log1p { struct Log1p {
template <typename T> template <typename T>
__device__ T operator()(T x) { __device__ T operator()(T z) {
return log1p(x); if constexpr (cuda::std::is_same_v<T, cuComplex>) {
float x = cuCrealf(z);
float y = cuCimagf(z);
float zabs = cuCrealf(Abs{}(z));
float theta = atan2f(y, x + 1);
if (zabs < 0.5f) {
float r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {0.5f * log1pf(r), theta};
} else {
float z0 = hypotf(x + 1, y);
return {logf(z0), theta};
}
} else {
return log1p(z);
}
} }
}; };
@@ -387,19 +405,19 @@ struct Tanh {
} }
}; };
__device__ cuComplex ArcCos::operator()(cuComplex x) { inline __device__ cuComplex ArcCos::operator()(cuComplex x) {
auto i = cuComplex{0.0, 1.0}; auto i = cuComplex{0.0, 1.0};
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x)); auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
return {cuCimagf(y), -cuCrealf(y)}; return {cuCimagf(y), -cuCrealf(y)};
}; };
__device__ cuComplex ArcSin::operator()(cuComplex x) { inline __device__ cuComplex ArcSin::operator()(cuComplex x) {
auto i = cuComplex{0.0f, 1.0f}; auto i = cuComplex{0.0f, 1.0f};
auto y = Log{}(i * x + Sqrt{}(1.0f - x * x)); auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
return {cuCimagf(y), -cuCrealf(y)}; return {cuCimagf(y), -cuCrealf(y)};
}; };
__device__ cuComplex ArcTan::operator()(cuComplex x) { inline __device__ cuComplex ArcTan::operator()(cuComplex x) {
auto i = cuComplex{0.0f, 1.0f}; auto i = cuComplex{0.0f, 1.0f};
auto ix = i * x; auto ix = i * x;
return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix)); return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));

View File

@@ -359,21 +359,4 @@ struct LoopedElemToLoc<1, false, OffsetT> {
} }
}; };
inline __device__ cuComplex log1p(cuComplex in) {
float x = cuCrealf(in);
float y = cuCimagf(in);
float zabs = sqrt(x * x + y * y);
float theta = atan2f(y, x + 1);
if (zabs < 0.5f) {
float r = x * (2 + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {0.5f * log1pf(r), theta};
} else {
auto z0 = sqrt((x + 1) * (x + 1) + y * y);
return {log(z0), theta};
}
}
} // namespace mlx::core::cu } // namespace mlx::core::cu

View File

@@ -161,6 +161,7 @@ constexpr const char* g_include_names[] = {
INCLUDE_PREFIX "atomic_ops.cuh", INCLUDE_PREFIX "atomic_ops.cuh",
INCLUDE_PREFIX "binary_ops.cuh", INCLUDE_PREFIX "binary_ops.cuh",
INCLUDE_PREFIX "cast_op.cuh", INCLUDE_PREFIX "cast_op.cuh",
INCLUDE_PREFIX "cexpf.cuh",
INCLUDE_PREFIX "config.h", INCLUDE_PREFIX "config.h",
INCLUDE_PREFIX "cucomplex_math.cuh", INCLUDE_PREFIX "cucomplex_math.cuh",
INCLUDE_PREFIX "fp16_math.cuh", INCLUDE_PREFIX "fp16_math.cuh",
@@ -177,6 +178,7 @@ constexpr const char* g_headers[] = {
jit_source_atomic_ops, jit_source_atomic_ops,
jit_source_binary_ops, jit_source_binary_ops,
jit_source_cast_op, jit_source_cast_op,
jit_source_cexpf,
jit_source_config, jit_source_config,
jit_source_cucomplex_math, jit_source_cucomplex_math,
jit_source_fp16_math, jit_source_fp16_math,

View File

@@ -1350,6 +1350,11 @@ TEST_CASE("test arithmetic unary ops") {
x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0]; x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0];
auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1}); auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1});
CHECK(allclose(exp(x), expected).item<bool>()); CHECK(allclose(exp(x), expected).item<bool>());
// Complex of -inf
constexpr float inf = std::numeric_limits<float>::infinity();
x = array(complex64_t{-inf, -inf});
CHECK_EQ(exp(x).item<complex64_t>(), complex64_t{0, 0});
} }
// Test expm1 // Test expm1
@@ -1830,6 +1835,10 @@ TEST_CASE("test arithmetic binary ops") {
x = array(-inf); x = array(-inf);
y = array(inf); y = array(inf);
CHECK_EQ(logaddexp(x, y).item<float>(), inf); CHECK_EQ(logaddexp(x, y).item<float>(), inf);
x = array(complex64_t{1, 1});
y = array(complex64_t{-inf, -inf});
CHECK_EQ(logaddexp(x, y).item<complex64_t>(), complex64_t{1, 1});
} }
TEST_CASE("test broadcast") { TEST_CASE("test broadcast") {