mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
spelling: broadcast
Signed-off-by: Josh Soref <2119212+jsoref@users.noreply.github.com>
This commit is contained in:
parent
94f71e5832
commit
16046082f6
@ -191,7 +191,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
self.assertListEqual(list(c_npy.shape), list(c_mlx.shape))
|
||||||
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
self.assertTrue(np.allclose(c_mlx, c_npy, atol=1e-6))
|
||||||
|
|
||||||
# Batched matmul with simple broadast
|
# Batched matmul with simple broadcast
|
||||||
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||||
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
|
b_npy = np.random.normal(0.0, 1.0 / 128, (16, 16)).astype(np.float32)
|
||||||
c_npy = a_npy @ b_npy
|
c_npy = a_npy @ b_npy
|
||||||
@ -213,7 +213,7 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
self.assertListEqual(list(e_npy.shape), list(e_mlx.shape))
|
self.assertListEqual(list(e_npy.shape), list(e_mlx.shape))
|
||||||
self.assertTrue(np.allclose(e_mlx, e_npy, atol=1e-6))
|
self.assertTrue(np.allclose(e_mlx, e_npy, atol=1e-6))
|
||||||
|
|
||||||
# Batched and transposed matmul with simple broadast
|
# Batched and transposed matmul with simple broadcast
|
||||||
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
a_npy = np.random.normal(0.0, 1.0 / 128, (32, 128, 16)).astype(np.float32)
|
||||||
b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
|
b_npy = np.random.normal(0.0, 1.0 / 128, (128, 16)).astype(np.float32)
|
||||||
a_mlx = mx.array(a_npy)
|
a_mlx = mx.array(a_npy)
|
||||||
|
Loading…
Reference in New Issue
Block a user