mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-12 07:14:34 +08:00
NumberOfElements for shapeless compile and vmap fixes (#802)
This commit is contained in:

committed by
GitHub

parent
29d0c10ee5
commit
76c919b4ec
@@ -253,6 +253,17 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
expected = mx.array([2, 1])
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
def test_vmap_mean(self):
|
||||
a = mx.arange(8).reshape(2, 4)
|
||||
out = mx.vmap(mx.mean)(a)
|
||||
expected = mx.mean(a, axis=1)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
a = mx.arange(16).reshape(2, 2, 4)
|
||||
out = mx.vmap(mx.vmap(mx.mean))(a)
|
||||
expected = mx.mean(a, axis=2)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
def test_mismatch_input_sizes(self):
|
||||
a = mx.ones((10, 1))
|
||||
b = mx.ones((1, 1, 1, 5))
|
||||
|
Reference in New Issue
Block a user