diff --git a/mlx/backend/cuda/device/binary_ops.cuh b/mlx/backend/cuda/device/binary_ops.cuh
index 4779a6f33..20e795dfe 100644
--- a/mlx/backend/cuda/device/binary_ops.cuh
+++ b/mlx/backend/cuda/device/binary_ops.cuh
@@ -194,6 +194,13 @@ struct Power {
       }
       return res;
     } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      if (base.y == 0 && base.x == 0) {
+        if (isnan(exp.x) || isnan(exp.y)) {
+          auto nan = cuda::std::numeric_limits<float>::quiet_NaN();
+          return make_cuFloatComplex(nan, nan);
+        }
+        return make_cuFloatComplex(0.0, 0.0);
+      }
       auto x_theta = atan2f(base.y, base.x);
       auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
       auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
diff --git a/mlx/backend/metal/kernels/binary_ops.h b/mlx/backend/metal/kernels/binary_ops.h
index 4aaf2b4da..f4deb860e 100644
--- a/mlx/backend/metal/kernels/binary_ops.h
+++ b/mlx/backend/metal/kernels/binary_ops.h
@@ -235,6 +235,13 @@ struct Power {
 
   template <>
   complex64_t operator()(complex64_t x, complex64_t y) {
+    if (x.real == 0 && x.imag == 0) {
+      if (metal::isnan(y.real) || metal::isnan(y.imag)) {
+        auto nan = metal::numeric_limits<float>::quiet_NaN();
+        return {nan, nan};
+      }
+      return {0.0, 0.0};
+    }
     auto x_theta = metal::atan2(x.imag, x.real);
     auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
     auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
diff --git a/mlx/utils.cpp b/mlx/utils.cpp
index 0b2e66352..61b9da3a2 100644
--- a/mlx/utils.cpp
+++ b/mlx/utils.cpp
@@ -69,7 +69,12 @@ inline void PrintFormatter::print(std::ostream& os, double val) {
   os << val;
 }
 inline void PrintFormatter::print(std::ostream& os, complex64_t val) {
-  os << val;
+  os << val.real();
+  if (val.imag() >= 0 || std::isnan(val.imag())) {
+    os << "+" << val.imag() << "j";
+  } else {
+    os << "-" << -val.imag() << "j";
+  }
 }
 
 PrintFormatter& get_global_formatter() {
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index f3d48dda3..7c4f3f8e3 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -3078,6 +3078,13 @@ class TestOps(mlx_tests.MLXTestCase):
         )
         self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x)))
 
+    def test_complex_power(self):
+        out = mx.power(mx.array(0j), 2)
+        self.assertEqual(out.item(), 0j)
+
+        out = mx.power(mx.array(0j), float("nan"))
+        self.assertTrue(mx.isnan(out))
+
 
 class TestBroadcast(mlx_tests.MLXTestCase):
     def test_broadcast_shapes(self):