This commit is contained in:
Awni Hannun
2025-06-02 13:43:22 -07:00
parent f859e75f4f
commit d4aafb9161
2 changed files with 8 additions and 9 deletions

View File

@@ -1210,13 +1210,6 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(c, c_np))
# Test addmm
M = 16
K = 50
N = 32
def rand(shape):
return mx.random.uniform(shape=shape) + 1j * mx.random.uniform(shape=shape)
a = rand((M, K))
b = rand((K, N))
c = rand((M, N))
@@ -1224,6 +1217,13 @@ class TestBlas(mlx_tests.MLXTestCase):
out_np = 2.0 * np.matmul(a, b) + 2.0 * c
self.assertTrue(np.allclose(out, out_np))
# complex with real
a = rand((M, K)).real
b = rand((K, N))
c = mx.matmul(a, b)
c_np = np.matmul(a, b)
self.assertTrue(np.allclose(out, out_np))
if __name__ == "__main__":
unittest.main()