mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix failing logaddexp test
This commit is contained in:
@@ -1,10 +1,7 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
#include "mlx/backend/cuda/device/unary_ops.cuh"
|
||||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
|
||||||
#include "mlx/backend/cuda/device/utils.cuh"
|
|
||||||
|
|
||||||
#include <cuComplex.h>
|
|
||||||
#include <cuda/std/array>
|
#include <cuda/std/array>
|
||||||
|
|
||||||
namespace mlx::core::cu {
|
namespace mlx::core::cu {
|
||||||
@@ -114,36 +111,38 @@ struct LessEqual {
|
|||||||
struct LogAddExp {
|
struct LogAddExp {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x, T y) {
|
__device__ T operator()(T x, T y) {
|
||||||
if (isnan(x) || isnan(y)) {
|
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||||
return cuda::std::numeric_limits<T>::quiet_NaN();
|
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
|
||||||
|
isnan(cuCimagf(y))) {
|
||||||
|
return {
|
||||||
|
cuda::std::numeric_limits<float>::quiet_NaN(),
|
||||||
|
cuda::std::numeric_limits<float>::quiet_NaN()};
|
||||||
|
}
|
||||||
|
auto max = cuCrealf(x) > cuCrealf(y) ? x : y;
|
||||||
|
auto min = cuCrealf(x) < cuCrealf(y) ? x : y;
|
||||||
|
auto min_real = cuCrealf(min);
|
||||||
|
auto max_real = cuCrealf(max);
|
||||||
|
if (!isfinite(min_real) && (min_real == max_real)) {
|
||||||
|
if (min_real < 0) {
|
||||||
|
return min;
|
||||||
|
} else {
|
||||||
|
return Log{}(Exp{}(min) + Exp{}(max));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Log1p{}(Exp{}(min - max)) + max;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (isnan(x) || isnan(y)) {
|
||||||
|
return cuda::std::numeric_limits<T>::quiet_NaN();
|
||||||
|
}
|
||||||
|
T maxval = max(x, y);
|
||||||
|
T minval = min(x, y);
|
||||||
|
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
|
||||||
|
maxval == cuda::std::numeric_limits<T>::infinity())
|
||||||
|
? maxval
|
||||||
|
: T(float(maxval) + log1p(expf(minval - maxval)));
|
||||||
}
|
}
|
||||||
T maxval = max(x, y);
|
|
||||||
T minval = min(x, y);
|
|
||||||
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
|
|
||||||
maxval == cuda::std::numeric_limits<T>::infinity())
|
|
||||||
? maxval
|
|
||||||
: T(float(maxval) + log1p(expf(minval - maxval)));
|
|
||||||
};
|
};
|
||||||
|
|
||||||
__device__ cuComplex operator()(cuComplex x, cuComplex y) {
|
|
||||||
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
|
|
||||||
isnan(cuCimagf(y))) {
|
|
||||||
return {
|
|
||||||
cuda::std::numeric_limits<float>::quiet_NaN(),
|
|
||||||
cuda::std::numeric_limits<float>::quiet_NaN()};
|
|
||||||
}
|
|
||||||
float inf = cuda::std::numeric_limits<float>::infinity();
|
|
||||||
auto maxval = x > y ? x : y;
|
|
||||||
auto minval = x < y ? x : y;
|
|
||||||
if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
|
|
||||||
return maxval;
|
|
||||||
float m = exp(cuCrealf(minval) - cuCrealf(maxval));
|
|
||||||
cuComplex dexp{
|
|
||||||
m * cos(cuCimagf(minval) - cuCimagf(maxval)),
|
|
||||||
m * sin(cuCimagf(minval) - cuCimagf(maxval)),
|
|
||||||
};
|
|
||||||
return maxval + log1p(dexp);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Maximum {
|
struct Maximum {
|
||||||
|
|||||||
138
mlx/backend/cuda/device/cexpf.cuh
Normal file
138
mlx/backend/cuda/device/cexpf.cuh
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
// Copyright © 2008-2013 NVIDIA Corporation
|
||||||
|
// Copyright © 2013 Filipe RNC Maia
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
// Forked from
|
||||||
|
// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h
|
||||||
|
|
||||||
|
// TODO: We should use thrust::exp but the thrust header in old CUDA versions
|
||||||
|
// can not be used in JIT.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cuComplex.h>
|
||||||
|
#include <cuda/std/cstdint>
|
||||||
|
|
||||||
|
namespace mlx::core::cu::detail {
|
||||||
|
|
||||||
|
using ieee_float_shape_type = union {
|
||||||
|
float value;
|
||||||
|
uint32_t word;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline __device__ void get_float_word(uint32_t& i, float d) {
|
||||||
|
ieee_float_shape_type gf_u;
|
||||||
|
gf_u.value = (d);
|
||||||
|
(i) = gf_u.word;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline __device__ void get_float_word(int32_t& i, float d) {
|
||||||
|
ieee_float_shape_type gf_u;
|
||||||
|
gf_u.value = (d);
|
||||||
|
(i) = gf_u.word;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline __device__ void set_float_word(float& d, uint32_t i) {
|
||||||
|
ieee_float_shape_type sf_u;
|
||||||
|
sf_u.word = (i);
|
||||||
|
(d) = sf_u.value;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline __device__ float frexp_expf(float x, int* expt) {
|
||||||
|
const uint32_t k = 235;
|
||||||
|
const float kln2 = 162.88958740F;
|
||||||
|
|
||||||
|
float exp_x;
|
||||||
|
uint32_t hx;
|
||||||
|
|
||||||
|
exp_x = expf(x - kln2);
|
||||||
|
get_float_word(hx, exp_x);
|
||||||
|
*expt = (hx >> 23) - (0x7f + 127) + k;
|
||||||
|
set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23));
|
||||||
|
return exp_x;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline __device__ cuComplex ldexp_cexpf(cuComplex z, int expt) {
|
||||||
|
float x, y, exp_x, scale1, scale2;
|
||||||
|
int ex_expt, half_expt;
|
||||||
|
|
||||||
|
x = cuCrealf(z);
|
||||||
|
y = cuCimagf(z);
|
||||||
|
exp_x = frexp_expf(x, &ex_expt);
|
||||||
|
expt += ex_expt;
|
||||||
|
|
||||||
|
half_expt = expt / 2;
|
||||||
|
set_float_word(scale1, (0x7f + half_expt) << 23);
|
||||||
|
half_expt = expt - half_expt;
|
||||||
|
set_float_word(scale2, (0x7f + half_expt) << 23);
|
||||||
|
|
||||||
|
return cuComplex{
|
||||||
|
cosf(y) * exp_x * scale1 * scale2, sinf(y) * exp_x * scale1 * scale2};
|
||||||
|
}
|
||||||
|
|
||||||
|
inline __device__ cuComplex cexpf(const cuComplex& z) {
|
||||||
|
float x, y, exp_x;
|
||||||
|
uint32_t hx, hy;
|
||||||
|
|
||||||
|
const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074;
|
||||||
|
|
||||||
|
x = cuCrealf(z);
|
||||||
|
y = cuCimagf(z);
|
||||||
|
|
||||||
|
get_float_word(hy, y);
|
||||||
|
hy &= 0x7fffffff;
|
||||||
|
|
||||||
|
/* cexp(x + I 0) = exp(x) + I 0 */
|
||||||
|
if (hy == 0) {
|
||||||
|
return cuComplex{expf(x), y};
|
||||||
|
}
|
||||||
|
get_float_word(hx, x);
|
||||||
|
/* cexp(0 + I y) = cos(y) + I sin(y) */
|
||||||
|
if ((hx & 0x7fffffff) == 0) {
|
||||||
|
return cuComplex{cosf(y), sinf(y)};
|
||||||
|
}
|
||||||
|
if (hy >= 0x7f800000) {
|
||||||
|
if ((hx & 0x7fffffff) != 0x7f800000) {
|
||||||
|
/* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */
|
||||||
|
return cuComplex{y - y, y - y};
|
||||||
|
} else if (hx & 0x80000000) {
|
||||||
|
/* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */
|
||||||
|
return cuComplex{0.0, 0.0};
|
||||||
|
} else {
|
||||||
|
/* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */
|
||||||
|
return cuComplex{x, y - y};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hx >= exp_ovfl && hx <= cexp_ovfl) {
|
||||||
|
/*
|
||||||
|
* x is between 88.7 and 192, so we must scale to avoid
|
||||||
|
* overflow in expf(x).
|
||||||
|
*/
|
||||||
|
return ldexp_cexpf(z, 0);
|
||||||
|
} else {
|
||||||
|
/*
|
||||||
|
* Cases covered here:
|
||||||
|
* - x < exp_ovfl and exp(x) won't overflow (common case)
|
||||||
|
* - x > cexp_ovfl, so exp(x) * s overflows for all s > 0
|
||||||
|
* - x = +-Inf (generated by exp())
|
||||||
|
* - x = NaN (spurious inexact exception from y)
|
||||||
|
*/
|
||||||
|
exp_x = expf(x);
|
||||||
|
return cuComplex{exp_x * cosf(y), exp_x * sinf(y)};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::cu::detail
|
||||||
@@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device/cexpf.cuh"
|
||||||
|
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||||
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
#include "mlx/backend/cuda/device/fp16_math.cuh"
|
||||||
#include "mlx/backend/cuda/device/utils.cuh"
|
#include "mlx/backend/cuda/device/utils.cuh"
|
||||||
|
|
||||||
@@ -150,8 +152,7 @@ struct Exp {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T x) {
|
||||||
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||||
auto m = exp(cuCrealf(x));
|
return detail::cexpf(x);
|
||||||
return {m * cos(cuCimagf(x)), m * sinh(cuCimagf(x))};
|
|
||||||
} else {
|
} else {
|
||||||
return exp(x);
|
return exp(x);
|
||||||
}
|
}
|
||||||
@@ -228,8 +229,25 @@ struct Log10 {
|
|||||||
|
|
||||||
struct Log1p {
|
struct Log1p {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ T operator()(T x) {
|
__device__ T operator()(T z) {
|
||||||
return log1p(x);
|
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
|
||||||
|
float x = cuCrealf(z);
|
||||||
|
float y = cuCimagf(z);
|
||||||
|
float zabs = cuCrealf(Abs{}(z));
|
||||||
|
float theta = atan2f(y, x + 1);
|
||||||
|
if (zabs < 0.5f) {
|
||||||
|
float r = x * (2 + x) + y * y;
|
||||||
|
if (r == 0) { // handle underflow
|
||||||
|
return {x, theta};
|
||||||
|
}
|
||||||
|
return {0.5f * log1pf(r), theta};
|
||||||
|
} else {
|
||||||
|
float z0 = hypotf(x + 1, y);
|
||||||
|
return {logf(z0), theta};
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return log1p(z);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -387,19 +405,19 @@ struct Tanh {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
__device__ cuComplex ArcCos::operator()(cuComplex x) {
|
inline __device__ cuComplex ArcCos::operator()(cuComplex x) {
|
||||||
auto i = cuComplex{0.0, 1.0};
|
auto i = cuComplex{0.0, 1.0};
|
||||||
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
|
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
|
||||||
return {cuCimagf(y), -cuCrealf(y)};
|
return {cuCimagf(y), -cuCrealf(y)};
|
||||||
};
|
};
|
||||||
|
|
||||||
__device__ cuComplex ArcSin::operator()(cuComplex x) {
|
inline __device__ cuComplex ArcSin::operator()(cuComplex x) {
|
||||||
auto i = cuComplex{0.0f, 1.0f};
|
auto i = cuComplex{0.0f, 1.0f};
|
||||||
auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
|
auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
|
||||||
return {cuCimagf(y), -cuCrealf(y)};
|
return {cuCimagf(y), -cuCrealf(y)};
|
||||||
};
|
};
|
||||||
|
|
||||||
__device__ cuComplex ArcTan::operator()(cuComplex x) {
|
inline __device__ cuComplex ArcTan::operator()(cuComplex x) {
|
||||||
auto i = cuComplex{0.0f, 1.0f};
|
auto i = cuComplex{0.0f, 1.0f};
|
||||||
auto ix = i * x;
|
auto ix = i * x;
|
||||||
return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));
|
return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));
|
||||||
|
|||||||
@@ -359,21 +359,4 @@ struct LoopedElemToLoc<1, false, OffsetT> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
inline __device__ cuComplex log1p(cuComplex in) {
|
|
||||||
float x = cuCrealf(in);
|
|
||||||
float y = cuCimagf(in);
|
|
||||||
float zabs = sqrt(x * x + y * y);
|
|
||||||
float theta = atan2f(y, x + 1);
|
|
||||||
if (zabs < 0.5f) {
|
|
||||||
float r = x * (2 + x) + y * y;
|
|
||||||
if (r == 0) { // handle underflow
|
|
||||||
return {x, theta};
|
|
||||||
}
|
|
||||||
return {0.5f * log1pf(r), theta};
|
|
||||||
} else {
|
|
||||||
auto z0 = sqrt((x + 1) * (x + 1) + y * y);
|
|
||||||
return {log(z0), theta};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core::cu
|
} // namespace mlx::core::cu
|
||||||
|
|||||||
@@ -161,6 +161,7 @@ constexpr const char* g_include_names[] = {
|
|||||||
INCLUDE_PREFIX "atomic_ops.cuh",
|
INCLUDE_PREFIX "atomic_ops.cuh",
|
||||||
INCLUDE_PREFIX "binary_ops.cuh",
|
INCLUDE_PREFIX "binary_ops.cuh",
|
||||||
INCLUDE_PREFIX "cast_op.cuh",
|
INCLUDE_PREFIX "cast_op.cuh",
|
||||||
|
INCLUDE_PREFIX "cexpf.cuh",
|
||||||
INCLUDE_PREFIX "config.h",
|
INCLUDE_PREFIX "config.h",
|
||||||
INCLUDE_PREFIX "cucomplex_math.cuh",
|
INCLUDE_PREFIX "cucomplex_math.cuh",
|
||||||
INCLUDE_PREFIX "fp16_math.cuh",
|
INCLUDE_PREFIX "fp16_math.cuh",
|
||||||
@@ -177,6 +178,7 @@ constexpr const char* g_headers[] = {
|
|||||||
jit_source_atomic_ops,
|
jit_source_atomic_ops,
|
||||||
jit_source_binary_ops,
|
jit_source_binary_ops,
|
||||||
jit_source_cast_op,
|
jit_source_cast_op,
|
||||||
|
jit_source_cexpf,
|
||||||
jit_source_config,
|
jit_source_config,
|
||||||
jit_source_cucomplex_math,
|
jit_source_cucomplex_math,
|
||||||
jit_source_fp16_math,
|
jit_source_fp16_math,
|
||||||
|
|||||||
@@ -1350,6 +1350,11 @@ TEST_CASE("test arithmetic unary ops") {
|
|||||||
x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0];
|
x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0];
|
||||||
auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1});
|
auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1});
|
||||||
CHECK(allclose(exp(x), expected).item<bool>());
|
CHECK(allclose(exp(x), expected).item<bool>());
|
||||||
|
|
||||||
|
// Complex of -inf
|
||||||
|
constexpr float inf = std::numeric_limits<float>::infinity();
|
||||||
|
x = array(complex64_t{-inf, -inf});
|
||||||
|
CHECK_EQ(exp(x).item<complex64_t>(), complex64_t{0, 0});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test expm1
|
// Test expm1
|
||||||
@@ -1830,6 +1835,10 @@ TEST_CASE("test arithmetic binary ops") {
|
|||||||
x = array(-inf);
|
x = array(-inf);
|
||||||
y = array(inf);
|
y = array(inf);
|
||||||
CHECK_EQ(logaddexp(x, y).item<float>(), inf);
|
CHECK_EQ(logaddexp(x, y).item<float>(), inf);
|
||||||
|
|
||||||
|
x = array(complex64_t{1, 1});
|
||||||
|
y = array(complex64_t{-inf, -inf});
|
||||||
|
CHECK_EQ(logaddexp(x, y).item<complex64_t>(), complex64_t{1, 1});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test broadcast") {
|
TEST_CASE("test broadcast") {
|
||||||
|
|||||||
Reference in New Issue
Block a user