Optimizing Complex Matrix Multiplication using Karatsuba’s Algorithm (#2220)

* Implementing Complex Matmul using Karatsuba Algorithm

* Implemented Karatsuba's Algorithm for complex matmul and pre-commit them

* fix

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Suryash Malviya
2025-06-02 18:58:46 -04:00
committed by GitHub
parent cbad6c3093
commit 0408ba0a76
2 changed files with 24 additions and 15 deletions

View File

@@ -1210,13 +1210,6 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(c, c_np))
# Test addmm
M = 16
K = 50
N = 32
def rand(shape):
return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)
a = rand((M, K))
b = rand((K, N))
c = rand((M, N))
@@ -1224,6 +1217,13 @@ class TestBlas(mlx_tests.MLXTestCase):
out_np = 2.0 * np.matmul(a, b) + 2.0 * c
self.assertTrue(np.allclose(out, out_np))
# complex with real
a = rand((M, K)).real
b = rand((K, N))
c = mx.matmul(a, b)
c_np = np.matmul(a, b)
self.assertTrue(np.allclose(out, out_np))
if __name__ == "__main__":
unittest.main()