mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-05 07:34:42 +08:00
Reduce vmap + some fixes (#601)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
@@ -220,6 +220,50 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
)
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
def test_vmap_reduce(self):
|
||||
a = mx.ones((5, 5), mx.int32)
|
||||
out = mx.vmap(lambda x: x.sum())(a)
|
||||
self.assertTrue(mx.array_equal(out, mx.full((5,), 5)))
|
||||
|
||||
out = mx.vmap(lambda x: x.sum(keepdims=True))(a)
|
||||
self.assertTrue(mx.array_equal(out, mx.full((5, 1), 5)))
|
||||
|
||||
out = mx.vmap(lambda x: x.sum(axis=0))(a)
|
||||
self.assertTrue(mx.array_equal(out, mx.full((5,), 5)))
|
||||
|
||||
a = mx.ones((5, 3, 2), mx.int32)
|
||||
out = mx.vmap(lambda x: x.sum(axis=(0, 1)))(a)
|
||||
self.assertTrue(mx.array_equal(out, mx.full((5,), 6)))
|
||||
|
||||
a = mx.ones((5, 3, 2), mx.int32)
|
||||
out = mx.vmap(lambda x: x.sum(axis=(0, 1)), in_axes=(1,))(a)
|
||||
self.assertTrue(mx.array_equal(out, mx.full((3,), 10)))
|
||||
|
||||
a = mx.ones((5, 3, 2), mx.int32)
|
||||
out = mx.vmap(lambda x: x.sum(axis=(0, 1)), in_axes=(2,))(a)
|
||||
self.assertTrue(mx.array_equal(out, mx.full((2,), 15)))
|
||||
|
||||
def test_vmap_argreduce(self):
|
||||
a = mx.array([[1, 2, 3], [2, 3, 1]])
|
||||
out = mx.vmap(lambda x: mx.argmin(x))(a)
|
||||
expected = mx.array([0, 2])
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
out = mx.vmap(lambda x: mx.argmax(x))(a)
|
||||
expected = mx.array([2, 1])
|
||||
self.assertTrue(mx.array_equal(out, expected))
|
||||
|
||||
def test_mismatch_input_sizes(self):
|
||||
a = mx.ones((10, 1))
|
||||
b = mx.ones((1, 1, 1, 5))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
out = mx.vmap(lambda x, y: x + y)(a, b)
|
||||
|
||||
b = mx.ones((10, 5))
|
||||
with self.assertRaises(ValueError):
|
||||
out = mx.vmap(lambda x, y: x + y, in_axes=(0, 1))(a, b)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user