diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py index b2a762681..ffbbbb0a7 100644 --- a/python/tests/test_blas.py +++ b/python/tests/test_blas.py @@ -191,7 +191,7 @@ class TestBlas(mlx_tests.MLXTestCase): self.assertListEqual(list(c_npy.shape), list(c_mlx.shape)) self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6)) - # Batched matmul with simple broadast + # Batched matmul with simple broadcast a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32) b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32) c_npy = a_npy @ b_npy @@ -213,7 +213,7 @@ class TestBlas(mlx_tests.MLXTestCase): self.assertListEqual(list(e_npy.shape), list(e_mlx.shape)) self.assertTrue(np.allclose(e_mlx, e_npy, atol=1e-6)) - # Batched and transposed matmul with simple broadast + # Batched and transposed matmul with simple broadcast a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32) b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32) a_mlx = mx.array(a_npy)