Option to JIT steel gemm / conv (#1139)

This commit is contained in:
Awni Hannun
2024-05-23 18:07:34 -07:00
committed by GitHub
parent eab2685c67
commit 7e26fd8032
31 changed files with 2504 additions and 1540 deletions

View File

@@ -556,8 +556,7 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# # Batched matmul with simple broadcast
# Batched matmul with simple broadcast
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
@@ -573,7 +572,6 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Matmul with vector
a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)