mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
enable complex gemm (#2017)
This commit is contained in:
@@ -1158,6 +1158,55 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
out_gemm = (b @ c)[0]
|
||||
self.assertTrue(mx.allclose(out_gemv, out_gemm))
|
||||
|
||||
def test_complex_gemv(self):
|
||||
M = 16
|
||||
N = 50
|
||||
|
||||
def rand(shape):
|
||||
return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)
|
||||
|
||||
a = rand((M, N))
|
||||
b = rand((N, 1))
|
||||
c = mx.matmul(a, b)
|
||||
c_np = np.matmul(a, b)
|
||||
self.assertTrue(np.allclose(c, c_np))
|
||||
|
||||
# Transposed
|
||||
a = rand((N, M))
|
||||
b = rand((N, 1))
|
||||
c = mx.matmul(a.T, b)
|
||||
c_np = np.matmul(np.array(a).T, b)
|
||||
self.assertTrue(np.allclose(c, c_np))
|
||||
|
||||
def test_complex_gemm(self):
|
||||
M = 16
|
||||
K = 50
|
||||
N = 32
|
||||
|
||||
def rand(shape):
|
||||
return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)
|
||||
|
||||
a = rand((M, K))
|
||||
b = rand((K, N))
|
||||
c = mx.matmul(a, b)
|
||||
c_np = np.matmul(a, b)
|
||||
self.assertTrue(np.allclose(c, c_np))
|
||||
|
||||
# Test addmm
|
||||
M = 16
|
||||
K = 50
|
||||
N = 32
|
||||
|
||||
def rand(shape):
|
||||
return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)
|
||||
|
||||
a = rand((M, K))
|
||||
b = rand((K, N))
|
||||
c = rand((M, N))
|
||||
out = mx.addmm(c, a, b, 2.0, 2.0)
|
||||
out_np = 2.0 * np.matmul(a, b) + 2.0 * c
|
||||
self.assertTrue(np.allclose(out, out_np))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user