mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-05 03:18:12 +08:00
NumberOfElements for shapeless compile and vmap fixes (#802)
This commit is contained in:
committed by
GitHub
parent
29d0c10ee5
commit
76c919b4ec
@@ -653,6 +653,24 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(expected[0], out[0]))
|
||||
self.assertTrue(mx.allclose(expected[1], out[1]))
|
||||
|
||||
def test_shapeless_mean(self):
|
||||
def mean(x):
|
||||
return mx.mean(x, keepdims=True)
|
||||
|
||||
cmean = mx.compile(mean, shapeless=True)
|
||||
|
||||
x = mx.ones(2)
|
||||
out = cmean(x)
|
||||
self.assertTrue(mx.allclose(out, mean(x)))
|
||||
|
||||
x = mx.ones(4)
|
||||
out = cmean(x)
|
||||
self.assertTrue(mx.allclose(out, mean(x)))
|
||||
|
||||
x = mx.ones(7)
|
||||
out = cmean(x)
|
||||
self.assertTrue(mx.allclose(out, mean(x)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user