enable complex gemm (#2017)

This commit is contained in:
Awni Hannun
2025-03-28 10:45:13 -07:00
committed by GitHub
parent 5580b47291
commit 98b901ad66
3 changed files with 71 additions and 1 deletions

View File

@@ -1158,6 +1158,55 @@ class TestBlas(mlx_tests.MLXTestCase):
out_gemm = (b @ c)[0]
self.assertTrue(mx.allclose(out_gemv, out_gemm))
def test_complex_gemv(self):
M = 16
N = 50
def rand(shape):
return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)
a = rand((M, N))
b = rand((N, 1))
c = mx.matmul(a, b)
c_np = np.matmul(a, b)
self.assertTrue(np.allclose(c, c_np))
# Transposed
a = rand((N, M))
b = rand((N, 1))
c = mx.matmul(a.T, b)
c_np = np.matmul(np.array(a).T, b)
self.assertTrue(np.allclose(c, c_np))
def test_complex_gemm(self):
M = 16
K = 50
N = 32
def rand(shape):
return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)
a = rand((M, K))
b = rand((K, N))
c = mx.matmul(a, b)
c_np = np.matmul(a, b)
self.assertTrue(np.allclose(c, c_np))
# Test addmm
M = 16
K = 50
N = 32
def rand(shape):
return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)
a = rand((M, K))
b = rand((K, N))
c = rand((M, N))
out = mx.addmm(c, a, b, 2.0, 2.0)
out_np = 2.0 * np.matmul(a, b) + 2.0 * c
self.assertTrue(np.allclose(out, out_np))
if __name__ == "__main__":
unittest.main()