mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
NumberOfElements for shapeless compile and vmap fixes (#802)
This commit is contained in:

committed by
GitHub

parent
29d0c10ee5
commit
76c919b4ec
@@ -653,6 +653,24 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(expected[0], out[0]))
|
||||
self.assertTrue(mx.allclose(expected[1], out[1]))
|
||||
|
||||
def test_shapeless_mean(self):
|
||||
def mean(x):
|
||||
return mx.mean(x, keepdims=True)
|
||||
|
||||
cmean = mx.compile(mean, shapeless=True)
|
||||
|
||||
x = mx.ones(2)
|
||||
out = cmean(x)
|
||||
self.assertTrue(mx.allclose(out, mean(x)))
|
||||
|
||||
x = mx.ones(4)
|
||||
out = cmean(x)
|
||||
self.assertTrue(mx.allclose(out, mean(x)))
|
||||
|
||||
x = mx.ones(7)
|
||||
out = cmean(x)
|
||||
self.assertTrue(mx.allclose(out, mean(x)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@@ -253,6 +253,17 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
expected = mx.array([2, 1])
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
def test_vmap_mean(self):
|
||||
a = mx.arange(8).reshape(2, 4)
|
||||
out = mx.vmap(mx.mean)(a)
|
||||
expected = mx.mean(a, axis=1)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
a = mx.arange(16).reshape(2, 2, 4)
|
||||
out = mx.vmap(mx.vmap(mx.mean))(a)
|
||||
expected = mx.mean(a, axis=2)
|
||||
self.assertTrue(mx.allclose(out, expected))
|
||||
|
||||
def test_mismatch_input_sizes(self):
|
||||
a = mx.ones((10, 1))
|
||||
b = mx.ones((1, 1, 1, 5))
|
||||
|
Reference in New Issue
Block a user