mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			169 lines
		
	
	
		
			5.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			169 lines
		
	
	
		
			5.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright © 2023 Apple Inc.
 | |
| 
 | |
| import unittest
 | |
| 
 | |
| import mlx.core as mx
 | |
| import mlx_tests
 | |
| 
 | |
| 
 | |
class TestVmap(mlx_tests.MLXTestCase):
    """Tests for ``mx.vmap``: argument validation, unary/binary ops, and pytrees."""

    def test_basics(self):
        """Invalid inputs and axis specifications raise ``ValueError``."""
        # Can't vmap over scalars
        with self.assertRaises(ValueError):
            mx.vmap(mx.exp)(mx.array(1.0))

        # Invalid input
        with self.assertRaises(ValueError):
            mx.vmap(mx.exp)("hello")

        # Invalid axes: a non-integer spec, and axis 2 on a 1-D input/output.
        with self.assertRaises(ValueError):
            mx.vmap(mx.exp, in_axes="hello")(mx.array([0, 1]))

        with self.assertRaises(ValueError):
            mx.vmap(mx.exp, in_axes=2)(mx.array([0, 1]))

        with self.assertRaises(ValueError):
            mx.vmap(mx.exp, out_axes="hello")(mx.array([0, 1]))

        with self.assertRaises(ValueError):
            mx.vmap(mx.exp, out_axes=2)(mx.array([0, 1]))

    def test_unary(self):
        """Vmapped unary ops match the direct (un-vmapped) op."""
        ops = [
            "abs",
            "cos",
            "erf",
            "erfinv",
            "exp",
            "log",
            "log1p",
            "log2",
            "log10",
            "logical_not",
            "negative",
            "reciprocal",
            "rsqrt",
            "sigmoid",
            "sign",
            "sin",
            "sqrt",
            "square",
        ]
        # NOTE: a leftover debug line used to reassign ``ops = ["erfinv"]`` here,
        # silently restricting this test to a single op; every op above is tested.
        for opname in ops:
            with self.subTest(op=opname):
                op = getattr(mx, opname)
                # equal_nan=True: several ops yield NaN on these inputs
                # (e.g. erfinv outside [-1, 1]), and NaN != NaN otherwise.
                x = mx.arange(5)
                y = mx.vmap(op)(x)
                self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))

                x = mx.arange(8).reshape(2, 4)
                y = mx.vmap(op)(x)
                self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))

                # Mapping over axis 1 and writing back to axis 1 is elementwise-equal.
                y = mx.vmap(op, in_axes=1, out_axes=1)(x)
                self.assertTrue(mx.array_equal(y, op(x), equal_nan=True))

    def test_binary(self):
        """Vmapped binary ops match the direct op, including mixed in/out axes."""
        ops = [
            "add",
            "divide",
            "equal",
            "greater",
            "greater_equal",
            "less",
            "less_equal",
            "logaddexp",
            "maximum",
            "minimum",
            "multiply",
            "power",
            "subtract",
        ]
        for opname in ops:
            with self.subTest(op=opname):
                op = getattr(mx, opname)
                x = mx.random.uniform(shape=(5,))
                y = mx.random.uniform(shape=(5,))
                out = mx.vmap(op)(x, y)
                self.assertTrue(mx.array_equal(out, op(x, y)))

                x = mx.random.uniform(shape=(2, 4))
                y = mx.random.uniform(shape=(2, 4))
                out = mx.vmap(op)(x, y)
                self.assertTrue(mx.array_equal(out, op(x, y)))

                out = mx.vmap(op, in_axes=(0, 0), out_axes=0)(x, y)
                self.assertTrue(mx.array_equal(out, op(x, y)))

                # y is (4, 2): mapping y over axis 1 pairs x[i] with y[:, i],
                # which is the same as operating against y.T row-wise.
                y = mx.random.uniform(shape=(4, 2))
                out = mx.vmap(op, in_axes=(0, 1), out_axes=0)(x, y)
                self.assertTrue(mx.array_equal(out, op(x, y.T)))

                # Stacking results along axis 1 transposes the (2, 4) result.
                out = mx.vmap(op, in_axes=(0, 1), out_axes=1)(x, y)
                self.assertTrue(mx.array_equal(out, op(x, y.T).T))

    def test_tree(self):
        """vmap over pytree inputs/outputs with per-leaf in_axes/out_axes."""

        def my_fun(tree):
            return (tree["a"] + tree["b"][0]) * tree["b"][1]

        tree = {
            "a": mx.random.uniform(shape=(2, 4)),
            "b": (
                mx.random.uniform(shape=(2, 4)),
                mx.random.uniform(shape=(2, 4)),
            ),
        }
        out = mx.vmap(my_fun)(tree)
        expected = my_fun(tree)
        self.assertTrue(mx.array_equal(out, expected))

        # in_axes not wrapped in a per-positional-argument tuple, or whose
        # structure does not match the argument tree, is rejected.
        with self.assertRaises(ValueError):
            mx.vmap(my_fun, in_axes={"a": 0, "b": 0}, out_axes=0)(tree)

        with self.assertRaises(ValueError):
            mx.vmap(my_fun, in_axes={"a": 0, "b": ((0, 0), 0)}, out_axes=0)(tree)

        # A single int axis may stand in for a whole subtree ("b": 0) ...
        out = mx.vmap(my_fun, in_axes=({"a": 0, "b": 0},), out_axes=0)(tree)
        self.assertTrue(mx.array_equal(out, my_fun(tree)))

        # ... or axes may be given leaf by leaf ("b": (0, 0)).
        out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (0, 0)},), out_axes=0)(tree)
        self.assertTrue(mx.array_equal(out, my_fun(tree)))

        # Leaves of "b" are transposed, so they are mapped over axis 1.
        tree = {
            "a": mx.random.uniform(shape=(2, 4)),
            "b": (
                mx.random.uniform(shape=(4, 2)),
                mx.random.uniform(shape=(4, 2)),
            ),
        }
        out = mx.vmap(my_fun, in_axes=({"a": 0, "b": (1, 1)},), out_axes=0)(tree)
        expected = (tree["a"] + tree["b"][0].T) * tree["b"][1].T
        self.assertTrue(mx.array_equal(out, expected))

        # Pytree (dict) outputs.
        def my_fun(x, y):
            return {"a": x + y, "b": x * y}

        x = mx.random.uniform(shape=(2, 4))
        y = mx.random.uniform(shape=(2, 4))
        out = mx.vmap(my_fun, in_axes=0, out_axes=0)(x, y)
        expected = my_fun(x, y)
        self.assertTrue(mx.array_equal(out["a"], expected["a"]))
        self.assertTrue(mx.array_equal(out["b"], expected["b"]))

        # out_axes whose structure/keys don't match the output dict are rejected.
        with self.assertRaises(ValueError):
            mx.vmap(my_fun, in_axes=0, out_axes=(0, 1))(x, y)

        with self.assertRaises(ValueError):
            mx.vmap(my_fun, in_axes=0, out_axes={"a": 0, "c": 1})(x, y)

        # Per-key out_axes: "a" stacked along axis 1 comes back transposed.
        out = mx.vmap(my_fun, in_axes=0, out_axes={"a": 1, "b": 0})(x, y)
        expected = my_fun(x, y)
        self.assertTrue(mx.array_equal(out["a"].T, expected["a"]))
        self.assertTrue(mx.array_equal(out["b"], expected["b"]))
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Discover and run every TestCase in this module when executed as a script.
    unittest.main(module="__main__")
 | 
