From 98b901ad665f7b3c57217fbc2cb4bae6bd548277 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 28 Mar 2025 10:45:13 -0700 Subject: [PATCH] enable complex gemm (#2017) --- mlx/backend/metal/kernels/gemv.metal | 2 +- mlx/ops.cpp | 21 ++++++++++++ python/tests/test_blas.py | 49 ++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 1 deletion(-) diff --git a/mlx/backend/metal/kernels/gemv.metal b/mlx/backend/metal/kernels/gemv.metal index f21c35d97..baaf84f2d 100644 --- a/mlx/backend/metal/kernels/gemv.metal +++ b/mlx/backend/metal/kernels/gemv.metal @@ -341,7 +341,7 @@ struct GEMVTKernel { MLX_MTL_PRAGMA_UNROLL for (int tm = 0; tm < TM; tm++) { - auto vc = float(v_coeff[tm]); + auto vc = static_cast(v_coeff[tm]); for (int tn = 0; tn < TN; tn++) { inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; } diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 5a64a7852..fe1852e0a 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2826,6 +2826,19 @@ array matmul( } // Type promotion auto out_type = promote_types(a.dtype(), b.dtype()); + // Complex matmul in terms of real matmuls + if (out_type == complex64) { + auto a_real = real(a, s); + auto b_real = real(b, s); + auto a_imag = imag(a, s); + auto b_imag = imag(b, s); + auto c_real = + subtract(matmul(a_real, b_real, s), matmul(a_imag, b_imag, s), s); + auto c_imag = add(matmul(a_real, b_imag, s), matmul(a_imag, b_real, s), s); + return add( + c_real, multiply(array(complex64_t{0, 1}, complex64), c_imag, s), s); + } + if (!issubdtype(out_type, floating)) { std::ostringstream msg; msg << "[matmul] Only real floating point types are supported but " @@ -4150,6 +4163,14 @@ array addmm( // Type promotion auto out_type = result_type(a, b, c); + + if (out_type == complex64) { + return add( + multiply(matmul(a, b, s), array(alpha), s), + multiply(array(beta), c, s), + s); + } + if (!issubdtype(out_type, floating)) { std::ostringstream msg; msg << "[addmm] Only real floating point types are supported but " diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index 985ca5ffb..c8627dbbd 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -1158,6 +1158,55 @@ class TestBlas(mlx_tests.MLXTestCase): out_gemm = (b @ c)[0] self.assertTrue(mx.allclose(out_gemv, out_gemm)) + def test_complex_gemv(self): + M = 16 + N = 50 + + def rand(shape): + return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape) + + a = rand((M, N)) + b = rand((N, 1)) + c = mx.matmul(a, b) + c_np = np.matmul(a, b) + self.assertTrue(np.allclose(c, c_np)) + + # Transposed + a = rand((N, M)) + b = rand((N, 1)) + c = mx.matmul(a.T, b) + c_np = np.matmul(np.array(a).T, b) + self.assertTrue(np.allclose(c, c_np)) + + def test_complex_gemm(self): + M = 16 + K = 50 + N = 32 + + def rand(shape): + return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape) + + a = rand((M, K)) + b = rand((K, N)) + c = mx.matmul(a, b) + c_np = np.matmul(a, b) + self.assertTrue(np.allclose(c, c_np)) + + # Test addmm + M = 16 + K = 50 + N = 32 + + def rand(shape): + return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape) + + a = rand((M, K)) + b = rand((K, N)) + c = rand((M, N)) + out = mx.addmm(c, a, b, 2.0, 2.0) + out_np = 2.0 * np.matmul(a, b) + 2.0 * c + self.assertTrue(np.allclose(out, out_np)) + if __name__ == "__main__": unittest.main()