From 8402a2acf4325a7213211dd7fcb4f397981ca695 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 13 Jun 2025 11:13:00 -0700
Subject: [PATCH] Fix complex power and print (#2286)

* fix complex power and print

* fix complex matmul shape
---
 mlx/backend/cuda/device/binary_ops.cuh |  7 ++++
 mlx/backend/metal/kernels/binary_ops.h |  7 ++++
 mlx/ops.cpp                            | 51 +++++++++++++-------------
 mlx/utils.cpp                          |  7 +++-
 python/tests/test_blas.py              | 10 +++++
 python/tests/test_ops.py               |  7 ++++
 6 files changed, 63 insertions(+), 26 deletions(-)

diff --git a/mlx/backend/cuda/device/binary_ops.cuh b/mlx/backend/cuda/device/binary_ops.cuh
index 4779a6f33..b96a7f9cc 100644
--- a/mlx/backend/cuda/device/binary_ops.cuh
+++ b/mlx/backend/cuda/device/binary_ops.cuh
@@ -194,6 +194,13 @@ struct Power {
       }
       return res;
     } else if constexpr (cuda::std::is_same_v<T, cuComplex>) {
+      if (base.y == 0 && base.x == 0) {
+        if (isnan(exp.x) || isnan(exp.y)) {
+          auto nan = cuda::std::numeric_limits<float>::quiet_NaN();
+          return make_cuFloatComplex(nan, nan);
+        }
+        return make_cuFloatComplex(0.0, 0.0);
+      }
       auto x_theta = atan2f(base.y, base.x);
       auto x_ln_r = 0.5 * logf(base.x * base.x + base.y * base.y);
       auto mag = expf(exp.x * x_ln_r - exp.y * x_theta);
diff --git a/mlx/backend/metal/kernels/binary_ops.h b/mlx/backend/metal/kernels/binary_ops.h
index 4aaf2b4da..f4deb860e 100644
--- a/mlx/backend/metal/kernels/binary_ops.h
+++ b/mlx/backend/metal/kernels/binary_ops.h
@@ -235,6 +235,13 @@ struct Power {
 
   template <>
   complex64_t operator()(complex64_t x, complex64_t y) {
+    if (x.real == 0 && x.imag == 0) {
+      if (metal::isnan(y.real) || metal::isnan(y.imag)) {
+        auto nan = metal::numeric_limits<float>::quiet_NaN();
+        return {nan, nan};
+      }
+      return {0.0, 0.0};
+    }
     auto x_theta = metal::atan2(x.imag, x.real);
     auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
     auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 9602f667a..2b861428f 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -2847,21 +2847,6 @@ array matmul(
         "[matmul] Got 0 dimension input. Inputs must "
         "have at least one dimension.");
   }
-  if (a.ndim() == 1) {
-    // Insert a singleton dim in the beginning
-    a = expand_dims(a, 0, s);
-  }
-  if (b.ndim() == 1) {
-    // Insert a singleton dim at the end
-    b = expand_dims(b, 1, s);
-  }
-  if (a.shape(-1) != b.shape(-2)) {
-    std::ostringstream msg;
-    msg << "[matmul] Last dimension of first input with shape " << a.shape()
-        << " must match second to last dimension of"
-        << " second input with shape " << b.shape() << ".";
-    throw std::invalid_argument(msg.str());
-  }
 
   // complex matmul using Karatsuba's Algorithm
   if (a.dtype() == complex64 || b.dtype() == complex64) {
@@ -2883,6 +2868,22 @@ array matmul(
         c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s);
   }
 
+  if (a.ndim() == 1) {
+    // Insert a singleton dim in the beginning
+    a = expand_dims(a, 0, s);
+  }
+  if (b.ndim() == 1) {
+    // Insert a singleton dim at the end
+    b = expand_dims(b, 1, s);
+  }
+  if (a.shape(-1) != b.shape(-2)) {
+    std::ostringstream msg;
+    msg << "[matmul] Last dimension of first input with shape " << a.shape()
+        << " must match second to last dimension of"
+        << " second input with shape " << b.shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
   // Type promotion
   auto out_type = promote_types(a.dtype(), b.dtype());
 
@@ -4240,6 +4241,16 @@ array addmm(
         "have at least one dimension.");
   }
 
+  // Type promotion
+  auto out_type = result_type(a, b, c);
+
+  if (out_type == complex64) {
+    return add(
+        multiply(matmul(a, b, s), array(alpha), s),
+        multiply(array(beta), c, s),
+        s);
+  }
+
   if (a.ndim() == 1) {
     // Insert a singleton dim in the beginning
     a = expand_dims(a, 0, s);
@@ -4257,16 +4268,6 @@ array addmm(
     throw std::invalid_argument(msg.str());
   }
 
-  // Type promotion
-  auto out_type = result_type(a, b, c);
-
-  if (out_type == complex64) {
-    return add(
-        multiply(matmul(a, b, s), array(alpha), s),
-        multiply(array(beta), c, s),
-        s);
-  }
-
   if (!issubdtype(out_type, floating)) {
     std::ostringstream msg;
     msg << "[addmm] Only real floating point types are supported but "
diff --git a/mlx/utils.cpp b/mlx/utils.cpp
index 0b2e66352..61b9da3a2 100644
--- a/mlx/utils.cpp
+++ b/mlx/utils.cpp
@@ -69,7 +69,12 @@ inline void PrintFormatter::print(std::ostream& os, double val) {
   os << val;
 }
 inline void PrintFormatter::print(std::ostream& os, complex64_t val) {
-  os << val;
+  os << val.real();
+  if (val.imag() >= 0 || std::isnan(val.imag())) {
+    os << "+" << val.imag() << "j";
+  } else {
+    os << "-" << -val.imag() << "j";
+  }
 }
 
 PrintFormatter& get_global_formatter() {
diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py
index 8c7a97ba8..2762df8f8 100644
--- a/python/tests/test_blas.py
+++ b/python/tests/test_blas.py
@@ -1195,6 +1195,16 @@ class TestBlas(mlx_tests.MLXTestCase):
         c_np = np.matmul(np.array(a).T, b)
         self.assertTrue(np.allclose(c, c_np))
 
+        # Check shapes
+        a = mx.random.normal((2, 3)).astype(mx.complex64)
+        b = mx.random.normal((3,))
+        self.assertEqual((a @ b).shape, (2,))
+
+        a = mx.random.normal((2, 3)).astype(mx.complex64)
+        b = mx.random.normal((3,))
+        c = mx.random.normal((2,))
+        self.assertEqual(mx.addmm(c, a, b).shape, (2,))
+
     def test_complex_gemm(self):
         M = 16
         K = 50
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index f3d48dda3..7c4f3f8e3 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -3078,6 +3078,13 @@ class TestOps(mlx_tests.MLXTestCase):
         )
         self.assertTrue(np.allclose(mx.rsqrt(x), 1.0 / np.sqrt(x)))
 
+    def test_complex_power(self):
+        out = mx.power(mx.array(0j), 2)
+        self.assertEqual(out.item(), 0j)
+
+        out = mx.power(mx.array(0j), float("nan"))
+        self.assertTrue(mx.isnan(out))
+
 
 class TestBroadcast(mlx_tests.MLXTestCase):
     def test_broadcast_shapes(self):
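
Note (illustration, not part of the patch): a minimal Python sketch of the behavior the change targets, assuming an MLX build that includes these fixes. It only uses APIs that already appear in the patch's own tests (mx.power, mx.addmm, mx.random.normal); the printed values shown in comments follow from the new formatter and kernels above.

    import mlx.core as mx

    # Complex power: raising a zero base to a NaN exponent now propagates NaN
    # instead of producing a spurious finite value.
    print(mx.power(mx.array(0j), float("nan")))  # nan+nanj

    # Complex printing: the imaginary part gets an explicit sign and a
    # trailing "j" (e.g. 1-2j) rather than the raw std::complex "(1,-2)" form.
    print(mx.array(1 - 2j))

    # Complex matmul/addmm with 1-D operands keep the reduced output shape.
    a = mx.random.normal((2, 3)).astype(mx.complex64)
    b = mx.random.normal((3,))
    c = mx.random.normal((2,))
    print((a @ b).shape)            # (2,)
    print(mx.addmm(c, a, b).shape)  # (2,)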