Reduce vmap + some fixes (#601)

This commit is contained in:
Awni Hannun
2024-02-01 11:30:28 -08:00
committed by GitHub
parent 601c6d6aa8
commit e88e474fd1
5 changed files with 161 additions and 33 deletions

View File

@@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import unittest
@@ -220,6 +220,50 @@ class TestVmap(mlx_tests.MLXTestCase):
)
self.assertTrue(mx.array_equal(out, expected))
def test_vmap_reduce(self):
a = mx.ones((5, 5), mx.int32)
out = mx.vmap(lambda x: x.sum())(a)
self.assertTrue(mx.array_equal(out, mx.full((5,), 5)))
out = mx.vmap(lambda x: x.sum(keepdims=True))(a)
self.assertTrue(mx.array_equal(out, mx.full((5, 1), 5)))
out = mx.vmap(lambda x: x.sum(axis=0))(a)
self.assertTrue(mx.array_equal(out, mx.full((5,), 5)))
a = mx.ones((5, 3, 2), mx.int32)
out = mx.vmap(lambda x: x.sum(axis=(0, 1)))(a)
self.assertTrue(mx.array_equal(out, mx.full((5,), 6)))
a = mx.ones((5, 3, 2), mx.int32)
out = mx.vmap(lambda x: x.sum(axis=(0, 1)), in_axes=(1,))(a)
self.assertTrue(mx.array_equal(out, mx.full((3,), 10)))
a = mx.ones((5, 3, 2), mx.int32)
out = mx.vmap(lambda x: x.sum(axis=(0, 1)), in_axes=(2,))(a)
self.assertTrue(mx.array_equal(out, mx.full((2,), 15)))
def test_vmap_argreduce(self):
a = mx.array([[1, 2, 3], [2, 3, 1]])
out = mx.vmap(lambda x: mx.argmin(x))(a)
expected = mx.array([0, 2])
self.assertTrue(mx.array_equal(out, expected))
out = mx.vmap(lambda x: mx.argmax(x))(a)
expected = mx.array([2, 1])
self.assertTrue(mx.array_equal(out, expected))
def test_mismatch_input_sizes(self):
a = mx.ones((10, 1))
b = mx.ones((1, 1, 1, 5))
with self.assertRaises(ValueError):
out = mx.vmap(lambda x, y: x + y)(a, b)
b = mx.ones((10, 5))
with self.assertRaises(ValueError):
out = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))(a, b)
if __name__ == "__main__":
unittest.main()