From 628e36f7d926b08d0f5d58d5da7302617c934477 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 13 Jun 2025 07:42:12 -0700 Subject: [PATCH] fix complex matmul shape --- mlx/backend/cuda/device/binary_ops.cuh | 2 +- mlx/ops.cpp | 51 +++++++++++++------------- python/tests/test_blas.py | 10 +++++ 3 files changed, 37 insertions(+), 26 deletions(-) diff --git a/mlx/backend/cuda/device/binary_ops.cuh b/mlx/backend/cuda/device/binary_ops.cuh index 20e795dfe..b96a7f9cc 100644 --- a/mlx/backend/cuda/device/binary_ops.cuh +++ b/mlx/backend/cuda/device/binary_ops.cuh @@ -196,7 +196,7 @@ struct Power { } else if constexpr (cuda::std::is_same_v) { if (base.y == 0 && base.x == 0) { if (isnan(exp.x) || isnan(exp.y)) { - auto nan = cuda::std::numeric_limits::quiet_NaN(); + auto nan = cuda::std::numeric_limits::quiet_NaN(); return make_cuFloatComplex(nan, nan); } return make_cuFloatComplex(0.0, 0.0); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 9602f667a..2b861428f 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2847,21 +2847,6 @@ array matmul( "[matmul] Got 0 dimension input. Inputs must " "have at least one dimension."); } - if (a.ndim() == 1) { - // Insert a singleton dim in the beginning - a = expand_dims(a, 0, s); - } - if (b.ndim() == 1) { - // Insert a singleton dim at the end - b = expand_dims(b, 1, s); - } - if (a.shape(-1) != b.shape(-2)) { - std::ostringstream msg; - msg << "[matmul] Last dimension of first input with shape " << a.shape() - << " must match second to last dimension of" - << " second input with shape " << b.shape() << "."; - throw std::invalid_argument(msg.str()); - } // complex matmul using Karatsuba's Algorithm if (a.dtype() == complex64 || b.dtype() == complex64) { @@ -2883,6 +2868,22 @@ array matmul( c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s); } + if (a.ndim() == 1) { + // Insert a singleton dim in the beginning + a = expand_dims(a, 0, s); + } + if (b.ndim() == 1) { + // Insert a singleton dim at the end + b = expand_dims(b, 1, s); + } + if (a.shape(-1) != b.shape(-2)) { + std::ostringstream msg; + msg << "[matmul] Last dimension of first input with shape " << a.shape() + << " must match second to last dimension of" + << " second input with shape " << b.shape() << "."; + throw std::invalid_argument(msg.str()); + } + // Type promotion auto out_type = promote_types(a.dtype(), b.dtype()); @@ -4240,6 +4241,16 @@ array addmm( "have at least one dimension."); } + // Type promotion + auto out_type = result_type(a, b, c); + + if (out_type == complex64) { + return add( + multiply(matmul(a, b, s), array(alpha), s), + multiply(array(beta), c, s), + s); + } + if (a.ndim() == 1) { // Insert a singleton dim in the beginning a = expand_dims(a, 0, s); @@ -4257,16 +4268,6 @@ array addmm( throw std::invalid_argument(msg.str()); } - // Type promotion - auto out_type = result_type(a, b, c); - - if (out_type == complex64) { - return add( - multiply(matmul(a, b, s), array(alpha), s), - multiply(array(beta), c, s), - s); - } - if (!issubdtype(out_type, floating)) { std::ostringstream msg; msg << "[addmm] Only real floating point types are supported but " diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 8c7a97ba8..2762df8f8 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -1195,6 +1195,16 @@ class TestBlas(mlx_tests.MLXTestCase): c_np = np.matmul(np.array(a).T, b) self.assertTrue(np.allclose(c, c_np)) + # Check shapes + a = mx.random.normal((2, 3)).astype(mx.complex64) + b = mx.random.normal((3,)) + self.assertEqual((a @ b).shape, (2,)) + + a = mx.random.normal((2, 3)).astype(mx.complex64) + b = mx.random.normal((3,)) + c = mx.random.normal((2,)) + self.assertEqual(mx.addmm(c, a, b).shape, (2,)) + def test_complex_gemm(self): M = 16 K = 50