mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Option to JIT steel gemm / conv (#1139)
This commit is contained in:
@@ -556,8 +556,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
||||
|
||||
# # Batched matmul with simple broadcast
|
||||
# Batched matmul with simple broadcast
|
||||
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
|
||||
|
||||
@@ -573,7 +572,6 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
||||
|
||||
# Matmul with vector
|
||||
a_npy = np.random.normal(0.0, 1.0 / 128, (16,)).astype(np.float32)
|
||||
b_npy = np.random.normal(0.0, 1.0 / 128, (32, 16, 128)).astype(np.float32)
|
||||
|
Reference in New Issue
Block a user