mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	awni's commit files
This commit is contained in:
		
							
								
								
									
										1041
									
								
								python/tests/test_array.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1041
									
								
								python/tests/test_array.py
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										263
									
								
								python/tests/test_autograd.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										263
									
								
								python/tests/test_autograd.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,263 @@ | ||||
| import unittest | ||||
|  | ||||
| import mlx.core as mx | ||||
|  | ||||
| import mlx_tests | ||||
|  | ||||
|  | ||||
| class TestAutograd(mlx_tests.MLXTestCase): | ||||
|     def test_jvp(self): | ||||
|         fun = lambda x: 2 * x | ||||
|         out, dout = mx.jvp(fun, [mx.array(1.0)], [mx.array(2.0)]) | ||||
|         self.assertEqual(out[0].item(), 2.0) | ||||
|         self.assertEqual(dout[0].item(), 4.0) | ||||
|  | ||||
|         fun = lambda x, y: x * y | ||||
|         _, out = mx.jvp( | ||||
|             fun, [mx.array(4.0), mx.array(2.0)], [mx.array(3.0), mx.array(2.0)] | ||||
|         ) | ||||
|         self.assertEqual(out[0].item(), 4.0 * 2.0 + 2.0 * 3.0) | ||||
|  | ||||
|         fun = lambda x, y, z: (x * y, y * z) | ||||
|         _, out = mx.jvp( | ||||
|             fun, | ||||
|             [mx.array(2.0), mx.array(4.0), mx.array(6.0)], | ||||
|             [mx.array(1.0), mx.array(3.0), mx.array(1.0)], | ||||
|         ) | ||||
|         self.assertEqual(len(out), 2) | ||||
|         self.assertEqual(out[0].item(), 4.0 * 1.0 + 2.0 * 3.0) | ||||
|         self.assertEqual(out[1].item(), 4.0 * 1.0 + 6.0 * 3.0) | ||||
|  | ||||
|     def test_vjp(self): | ||||
|         fun = lambda x: 2 * x | ||||
|         out, dout = mx.vjp(fun, [mx.array(1.0)], [mx.array(2.0)]) | ||||
|         self.assertEqual(out[0].item(), 2.0) | ||||
|         self.assertEqual(dout[0].item(), 4.0) | ||||
|  | ||||
|         fun = lambda x, y: x * y | ||||
|         _, dout = mx.vjp(fun, [mx.array(4.0), mx.array(2.0)], [mx.array(3.0)]) | ||||
|         self.assertEqual(dout[0].item(), 6.0) | ||||
|         self.assertEqual(dout[1].item(), 12.0) | ||||
|  | ||||
|         fun = lambda x, y, z: (x * y, y * z) | ||||
|         _, out = mx.vjp( | ||||
|             fun, | ||||
|             [mx.array(2.0), mx.array(4.0), mx.array(6.0)], | ||||
|             [mx.array(1.0), mx.array(3.0)], | ||||
|         ) | ||||
|         self.assertEqual(len(out), 3) | ||||
|         self.assertEqual(out[0].item(), 4.0 * 1.0) | ||||
|         self.assertEqual(out[1].item(), 2.0 * 1.0 + 6.0 * 3.0) | ||||
|         self.assertEqual(out[2].item(), 4.0 * 3.0) | ||||
|  | ||||
|     def test_grad(self): | ||||
|         fun = lambda x: x * x | ||||
|  | ||||
|         value, dfdx = mx.value_and_grad(fun)(mx.array(0.5)) | ||||
|         self.assertEqual(value.item(), 0.25) | ||||
|         self.assertEqual(dfdx.item(), 1.0) | ||||
|  | ||||
|         dfdx = mx.grad(fun)(mx.array(0.5)) | ||||
|         self.assertEqual(dfdx.item(), 1.0) | ||||
|  | ||||
|         df2dx2 = mx.grad(mx.grad(fun))(mx.array(0.5)) | ||||
|         self.assertEqual(df2dx2.item(), 2.0) | ||||
|         df3dx3 = mx.grad(mx.grad(mx.grad(fun)))(mx.array(0.5)) | ||||
|         self.assertEqual(df3dx3.item(), 0.0) | ||||
|  | ||||
|         fun = lambda x, y: x * y | ||||
|         x = mx.array(2.0) | ||||
|         y = mx.array(3.0) | ||||
|         dfdx = mx.grad(fun, argnums=0)(x, y) | ||||
|         self.assertEqual(dfdx.item(), 3.0) | ||||
|         dfdx = mx.grad(fun, argnums=1)(x, y) | ||||
|         self.assertEqual(dfdx.item(), 2.0) | ||||
|  | ||||
|         # Pass non array args to functions works | ||||
|         fun = lambda x, y: x | ||||
|         value, dfdx = mx.value_and_grad(fun)(mx.array(2.0), "hello") | ||||
|         self.assertEqual(value.item(), 2.0) | ||||
|         self.assertEqual(dfdx.item(), 1.0) | ||||
|  | ||||
|         dfdx = mx.grad(fun)(mx.array(2.0), "hello") | ||||
|         self.assertEqual(dfdx.item(), 1.0) | ||||
|  | ||||
|         # Raises when function does not return array | ||||
|         fun = lambda x: "hello" | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.grad(fun)(mx.array(2.0)) | ||||
|  | ||||
|         # Raises for invalid argument number or argument type | ||||
|         fun = lambda x: x | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.grad(fun, argnums=2)(mx.array(2.0)) | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.grad(fun, argnums=-2)(mx.array(2.0)) | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.grad(fun)("hello") | ||||
|  | ||||
|         # Raises when output is not a scalar array | ||||
|         fun = lambda x: mx.sum(x, keepdims=True) | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.grad(fun)(mx.ones((2, 2))) | ||||
|  | ||||
|     def test_grad_trees(self): | ||||
|         fun = lambda x, y: x * y | ||||
|         value, dfdx = mx.value_and_grad(fun, (0, 1))(mx.array(0.5), mx.array(2.0)) | ||||
|         self.assertEqual(value.item(), 1.0) | ||||
|         self.assertTrue(isinstance(dfdx, tuple)) | ||||
|         self.assertEqual(dfdx[0].item(), 2.0) | ||||
|         self.assertEqual(dfdx[1].item(), 0.5) | ||||
|  | ||||
|         fun = lambda x, y: x * y | ||||
|         value, dfdx = mx.value_and_grad(fun, 1)(mx.array(0.5), mx.array(2.0)) | ||||
|         self.assertEqual(value.item(), 1.0) | ||||
|         self.assertEqual(dfdx.item(), 0.5) | ||||
|  | ||||
|         fun = lambda p: p["x"] * p["y"] | ||||
|         value, dfdx = mx.value_and_grad(fun)({"x": mx.array(0.5), "y": mx.array(2.0)}) | ||||
|         self.assertEqual(value.item(), 1.0) | ||||
|         self.assertEqual(dfdx["x"].item(), 2.0) | ||||
|         self.assertEqual(dfdx["y"].item(), 0.5) | ||||
|  | ||||
|         fun = lambda p: p["x"] * p["y"] | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.value_and_grad(fun)({"x": 0.5, "y": mx.array(2.0)}) | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.value_and_grad(fun, (0, 1))({"x": mx.array(0.5), "y": mx.array(2.0)}) | ||||
|  | ||||
|         fun = lambda p, b: mx.square(p[0]["foo"][2]) * b | ||||
|         value, dfdx = mx.value_and_grad(fun)( | ||||
|             [{"foo": [[], [], mx.array(2.0)]}], mx.array(0.5) | ||||
|         ) | ||||
|         self.assertEqual(value.item(), 2.0) | ||||
|         self.assertEqual(dfdx[0]["foo"][2].item(), 2.0) | ||||
|  | ||||
|         fun = lambda x: x | ||||
|         with self.assertRaises(TypeError): | ||||
|             mx.value_and_grad(fun, (None, None)) | ||||
|         with self.assertRaises(ValueError): | ||||
|             mx.value_and_grad(fun, tuple()) | ||||
|  | ||||
|     def test_auxiliary_values(self): | ||||
|         def fun(x, y): | ||||
|             l = (x * y).sum() | ||||
|             extra = {"loss": l, "foo": y.square() + x.square(), "bar": [1, 2, 3, y, x]} | ||||
|             return l, extra | ||||
|  | ||||
|         fun_value_grad = mx.value_and_grad(fun) | ||||
|         fun_grad = mx.grad(fun) | ||||
|  | ||||
|         (loss, a), b = fun_value_grad(mx.ones((2, 2)), mx.ones((2, 2))) | ||||
|         self.assertEqual(a["loss"].item(), 4) | ||||
|         self.assertTrue(mx.array_equal(b, mx.ones((2, 2)))) | ||||
|         self.assertTrue(mx.array_equal(a["foo"], 2 * mx.ones((2, 2)))) | ||||
|         self.assertEqual(a["bar"][:3], [1, 2, 3]) | ||||
|         self.assertTrue(mx.array_equal(a["bar"][3], mx.ones((2, 2)))) | ||||
|         self.assertTrue(mx.array_equal(a["bar"][4], mx.ones((2, 2)))) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             _ = fun_grad(mx.ones((2, 2)), mx.ones((2, 2))) | ||||
|  | ||||
|     def test_grad_kwargs(self): | ||||
|         fun = lambda x, y: x * y | ||||
|         a, b = mx.array(0.5), mx.array(2.0) | ||||
|         dfdx = mx.grad(fun) | ||||
|         self.assertEqual(dfdx(a, b).item(), 2.0) | ||||
|         self.assertEqual(dfdx(a, y=b).item(), 2.0) | ||||
|         with self.assertRaises(ValueError): | ||||
|             dfdx(x=a, y=b).item() | ||||
|  | ||||
|         dfdy = mx.grad(fun, argnums=[], argnames=["y"]) | ||||
|         with self.assertRaises(ValueError): | ||||
|             dfdy(a, b) | ||||
|         grads = dfdy(a, y=b) | ||||
|         self.assertTrue(isinstance(grads, tuple)) | ||||
|         self.assertTrue(grads[0] is None) | ||||
|         self.assertTrue(isinstance(grads[1], dict)) | ||||
|         self.assertEqual(grads[1]["y"].item(), 0.5) | ||||
|         grads = dfdy(x=a, y=b) | ||||
|         self.assertEqual(grads[1]["y"].item(), 0.5) | ||||
|         self.assertEqual(len(grads[1]), 1) | ||||
|  | ||||
|         dfdxy = mx.grad(fun, argnums=[0], argnames=["y"]) | ||||
|         with self.assertRaises(ValueError): | ||||
|             dfdxy(a, b) | ||||
|         with self.assertRaises(ValueError): | ||||
|             dfdxy(x=a, y=b) | ||||
|         grads = dfdxy(a, y=b) | ||||
|         self.assertTrue(isinstance(grads, tuple)) | ||||
|         self.assertEqual(grads[0].item(), 2.0) | ||||
|         self.assertTrue(isinstance(grads[1], dict)) | ||||
|         self.assertEqual(grads[1]["y"].item(), 0.5) | ||||
|  | ||||
|         fun = lambda x, y, z: x * y * z | ||||
|         dfdxyz = mx.grad(fun, argnums=[0, 1], argnames=["z"]) | ||||
|         c = mx.array(4.0) | ||||
|         grads = dfdxyz(a, b, z=c) | ||||
|         self.assertTrue(isinstance(grads, tuple)) | ||||
|         self.assertTrue(isinstance(grads[0], tuple)) | ||||
|         self.assertEqual(grads[0][0].item(), 8.0) | ||||
|         self.assertEqual(grads[0][1].item(), 2.0) | ||||
|         self.assertTrue(isinstance(grads[1], dict)) | ||||
|         self.assertEqual(grads[1]["z"].item(), 1.0) | ||||
|  | ||||
|         fun = lambda x, y: x * y | ||||
|         dfdy = mx.grad(fun, argnames=["y"]) | ||||
|         grads = dfdy(a, y=b) | ||||
|         self.assertTrue(isinstance(grads, tuple)) | ||||
|         self.assertTrue(grads[0] is None) | ||||
|         self.assertTrue(isinstance(grads[1], dict)) | ||||
|         self.assertEqual(grads[1]["y"].item(), 0.5) | ||||
|  | ||||
|     def test_captured(self): | ||||
|         a = mx.array(5.0) | ||||
|         f = lambda x: a + x | ||||
|         g = lambda x: a + a | ||||
|         h = lambda x: x + x | ||||
|  | ||||
|         dfdx = mx.grad(f) | ||||
|         self.assertEqual(dfdx(a).item(), 1.0) | ||||
|  | ||||
|         dgdx = mx.grad(g) | ||||
|         self.assertEqual(dgdx(a).item(), 0.0) | ||||
|  | ||||
|         dhdx = mx.grad(h) | ||||
|         self.assertEqual(dhdx(a).item(), 2.0) | ||||
|  | ||||
|         d2fdx2 = mx.grad(dfdx) | ||||
|         self.assertEqual(d2fdx2(a).item(), 0.0) | ||||
|  | ||||
|         d2gdx2 = mx.grad(dgdx) | ||||
|         self.assertEqual(d2gdx2(a).item(), 0.0) | ||||
|  | ||||
|         d2hdx2 = mx.grad(dhdx) | ||||
|         self.assertEqual(d2hdx2(a).item(), 0.0) | ||||
|  | ||||
|     def test_stop_gradient(self): | ||||
|         shape_in = (4, 4) | ||||
|         w_in = mx.ones(shape_in) | ||||
|         x_in = mx.ones(shape_in) | ||||
|         cotan = mx.ones(shape_in) | ||||
|  | ||||
|         def h(w, x): | ||||
|             x1 = 2 * x | ||||
|             y = mx.stop_gradient(x1) | ||||
|             y1 = 3 * y | ||||
|             return w @ y1 | ||||
|  | ||||
|         vals, vjps = mx.vjp(h, [w_in, x_in], [cotan]) | ||||
|         mx.eval(vjps) | ||||
|  | ||||
|         self.assertTrue(mx.allclose(vjps[0], 24.0 * mx.ones(shape_in))) | ||||
|         self.assertTrue(mx.allclose(vjps[1], mx.zeros(shape_in))) | ||||
|  | ||||
|         g = lambda x: h(w_in, x) | ||||
|         vals, vjps = mx.vjp(g, [x_in], [cotan]) | ||||
|         mx.eval(vjps) | ||||
|  | ||||
|         self.assertTrue(mx.allclose(vjps[0], mx.zeros(shape_in))) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
							
								
								
									
										105
									
								
								python/tests/test_device.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										105
									
								
								python/tests/test_device.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,105 @@ | ||||
| import unittest | ||||
|  | ||||
| import mlx.core as mx | ||||
|  | ||||
| import mlx_tests | ||||
|  | ||||
|  | ||||
| # Don't inherit from MLXTestCase to avoid call to setUp | ||||
| class TestDefaultDevice(unittest.TestCase): | ||||
|     def test_mlx_default_device(self): | ||||
|         device = mx.default_device() | ||||
|         if mx.metal.is_available(): | ||||
|             self.assertEqual(device, mx.Device(mx.gpu)) | ||||
|             self.assertEqual(str(device), "Device(gpu, 0)") | ||||
|             self.assertEqual(device, mx.gpu) | ||||
|             self.assertEqual(mx.gpu, device) | ||||
|         else: | ||||
|             self.assertEqual(device.type, mx.Device(mx.cpu)) | ||||
|             with self.assertRaises(ValueError): | ||||
|                 mx.set_default_device(mx.gpu) | ||||
|  | ||||
|  | ||||
| class TestDevice(mlx_tests.MLXTestCase): | ||||
|     def test_device(self): | ||||
|         device = mx.default_device() | ||||
|  | ||||
|         cpu = mx.Device(mx.cpu) | ||||
|         mx.set_default_device(cpu) | ||||
|         self.assertEqual(mx.default_device(), cpu) | ||||
|         self.assertEqual(str(cpu), "Device(cpu, 0)") | ||||
|  | ||||
|         mx.set_default_device(mx.cpu) | ||||
|         self.assertEqual(mx.default_device(), mx.cpu) | ||||
|         self.assertEqual(cpu, mx.cpu) | ||||
|         self.assertEqual(mx.cpu, cpu) | ||||
|  | ||||
|         # Restore device | ||||
|         mx.set_default_device(device) | ||||
|  | ||||
|     def test_op_on_device(self): | ||||
|         x = mx.array(1.0) | ||||
|         y = mx.array(1.0) | ||||
|  | ||||
|         a = mx.add(x, y, stream=None) | ||||
|         b = mx.add(x, y, stream=mx.default_device()) | ||||
|         self.assertEqual(a.item(), b.item()) | ||||
|         b = mx.add(x, y, stream=mx.cpu) | ||||
|         self.assertEqual(a.item(), b.item()) | ||||
|  | ||||
|         if mx.metal.is_available(): | ||||
|             b = mx.add(x, y, stream=mx.gpu) | ||||
|             self.assertEqual(a.item(), b.item()) | ||||
|  | ||||
|  | ||||
| class TestStream(mlx_tests.MLXTestCase): | ||||
|     def test_stream(self): | ||||
|         s1 = mx.default_stream(mx.default_device()) | ||||
|         self.assertEqual(s1.device, mx.default_device()) | ||||
|  | ||||
|         s2 = mx.new_stream(mx.default_device()) | ||||
|         self.assertEqual(s2.device, mx.default_device()) | ||||
|         self.assertNotEqual(s1, s2) | ||||
|  | ||||
|         if mx.metal.is_available(): | ||||
|             s_gpu = mx.default_stream(mx.gpu) | ||||
|             self.assertEqual(s_gpu.device, mx.gpu) | ||||
|         else: | ||||
|             with self.assertRaises(ValueError): | ||||
|                 mx.default_stream(mx.gpu) | ||||
|  | ||||
|         s_cpu = mx.default_stream(mx.cpu) | ||||
|         self.assertEqual(s_cpu.device, mx.cpu) | ||||
|  | ||||
|         s_cpu = mx.new_stream(mx.cpu) | ||||
|         self.assertEqual(s_cpu.device, mx.cpu) | ||||
|  | ||||
|         if mx.metal.is_available(): | ||||
|             s_gpu = mx.new_stream(mx.gpu) | ||||
|             self.assertEqual(s_gpu.device, mx.gpu) | ||||
|         else: | ||||
|             with self.assertRaises(ValueError): | ||||
|                 mx.new_stream(mx.gpu) | ||||
|  | ||||
|     def test_op_on_stream(self): | ||||
|         x = mx.array(1.0) | ||||
|         y = mx.array(1.0) | ||||
|  | ||||
|         a = mx.add(x, y, stream=mx.default_stream(mx.default_device())) | ||||
|  | ||||
|         if mx.metal.is_available(): | ||||
|             b = mx.add(x, y, stream=mx.default_stream(mx.gpu)) | ||||
|             self.assertEqual(a.item(), b.item()) | ||||
|             s_gpu = mx.new_stream(mx.gpu) | ||||
|             b = mx.add(x, y, stream=s_gpu) | ||||
|             self.assertEqual(a.item(), b.item()) | ||||
|  | ||||
|         b = mx.add(x, y, stream=mx.default_stream(mx.cpu)) | ||||
|         self.assertEqual(a.item(), b.item()) | ||||
|         s_cpu = mx.new_stream(mx.cpu) | ||||
|         b = mx.add(x, y, stream=s_cpu) | ||||
|         self.assertEqual(a.item(), b.item()) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
							
								
								
									
										34
									
								
								python/tests/test_eval.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								python/tests/test_eval.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,34 @@ | ||||
| from functools import partial | ||||
|  | ||||
| import unittest | ||||
|  | ||||
| import mlx.core as mx | ||||
|  | ||||
| import mlx_tests | ||||
|  | ||||
|  | ||||
| class TestEval(mlx_tests.MLXTestCase): | ||||
|     def test_eval(self): | ||||
|         arrs = [mx.ones((2, 2)) for _ in range(4)] | ||||
|         mx.eval(*arrs) | ||||
|         for x in arrs: | ||||
|             self.assertEqual(x.tolist(), [[1, 1], [1, 1]]) | ||||
|  | ||||
|     def test_retain_graph(self): | ||||
|         def fun(x, retain_graph): | ||||
|             y = 3 * x | ||||
|             mx.eval(y, retain_graph=retain_graph) | ||||
|             return 2 * y | ||||
|  | ||||
|         dfun_dx_1 = mx.grad(partial(fun, retain_graph=False)) | ||||
|         dfun_dx_2 = mx.grad(partial(fun, retain_graph=True)) | ||||
|  | ||||
|         with self.assertRaises(ValueError): | ||||
|             dfun_dx_1(mx.array(1.0)) | ||||
|  | ||||
|         y = dfun_dx_2(mx.array(1.0)) | ||||
|         self.assertEqual(y.item(), 6.0) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
							
								
								
									
										90
									
								
								python/tests/test_fft.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										90
									
								
								python/tests/test_fft.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,90 @@ | ||||
| import unittest | ||||
|  | ||||
| import itertools | ||||
| import mlx.core as mx | ||||
| import numpy as np | ||||
|  | ||||
| import mlx_tests | ||||
|  | ||||
|  | ||||
| class TestFFT(mlx_tests.MLXTestCase): | ||||
|     def check_mx_np(self, op, a_np, axes, s): | ||||
|         with self.subTest(op=op, axes=axes, s=s): | ||||
|             op_np = getattr(np.fft, op) | ||||
|             op_mx = getattr(mx.fft, op) | ||||
|             out_np = op_np(a_np, s=s, axes=axes) | ||||
|             a_mx = mx.array(a_np) | ||||
|             out_mx = op_mx(a_mx, s=s, axes=axes) | ||||
|             self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) | ||||
|  | ||||
|     def test_fft(self): | ||||
|         default = mx.default_device() | ||||
|         mx.set_default_device(mx.cpu) | ||||
|  | ||||
|         def check_mx_np(op_mx, op_np, a_np, **kwargs): | ||||
|             out_np = op_np(a_np, **kwargs) | ||||
|             a_mx = mx.array(a_np) | ||||
|             out_mx = op_mx(a_mx, **kwargs) | ||||
|             self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) | ||||
|  | ||||
|         r = np.random.rand(100).astype(np.float32) | ||||
|         i = np.random.rand(100).astype(np.float32) | ||||
|         a_np = r + 1j * i | ||||
|         check_mx_np(mx.fft.fft, np.fft.fft, a_np) | ||||
|  | ||||
|         # Check with slicing and padding | ||||
|         r = np.random.rand(100).astype(np.float32) | ||||
|         i = np.random.rand(100).astype(np.float32) | ||||
|         a_np = r + 1j * i | ||||
|         check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80) | ||||
|         check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120) | ||||
|  | ||||
|         # Check different axes | ||||
|         r = np.random.rand(100, 100).astype(np.float32) | ||||
|         i = np.random.rand(100, 100).astype(np.float32) | ||||
|         a_np = r + 1j * i | ||||
|         check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0) | ||||
|         check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1) | ||||
|  | ||||
|         # Check real fft | ||||
|         a_np = np.random.rand(100).astype(np.float32) | ||||
|         check_mx_np(mx.fft.rfft, np.fft.rfft, a_np) | ||||
|         check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80) | ||||
|         check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120) | ||||
|  | ||||
|         # Check real inverse | ||||
|         r = np.random.rand(100, 100).astype(np.float32) | ||||
|         i = np.random.rand(100, 100).astype(np.float32) | ||||
|         a_np = r + 1j * i | ||||
|         check_mx_np(mx.fft.ifft, np.fft.ifft, a_np) | ||||
|         check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80) | ||||
|         check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120) | ||||
|         check_mx_np(mx.fft.irfft, np.fft.irfft, a_np) | ||||
|         check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80) | ||||
|         check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120) | ||||
|  | ||||
|         mx.set_default_device(default) | ||||
|  | ||||
|     def test_fftn(self): | ||||
|         default = mx.default_device() | ||||
|         mx.set_default_device(mx.cpu) | ||||
|  | ||||
|         r = np.random.randn(8, 8, 8).astype(np.float32) | ||||
|         i = np.random.randn(8, 8, 8).astype(np.float32) | ||||
|         a = r + 1j * i | ||||
|  | ||||
|         axes = [None, (1, 2), (2, 1), (0, 2)] | ||||
|         shapes = [None, (10, 5), (5, 10)] | ||||
|         ops = ["fft2", "ifft2", "rfft2", "irfft2", "fftn", "ifftn", "rfftn", "irfftn"] | ||||
|  | ||||
|         for op, ax, s in itertools.product(ops, axes, shapes): | ||||
|             x = a | ||||
|             if op in ["rfft2", "rfftn"]: | ||||
|                 x = r | ||||
|             self.check_mx_np(op, x, axes=ax, s=s) | ||||
|  | ||||
|         mx.set_default_device(default) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
							
								
								
									
										1283
									
								
								python/tests/test_ops.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1283
									
								
								python/tests/test_ops.py
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										118
									
								
								python/tests/test_reduce.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										118
									
								
								python/tests/test_reduce.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,118 @@ | ||||
| import unittest | ||||
| from itertools import permutations, combinations | ||||
|  | ||||
| import mlx.core as mx | ||||
| import numpy as np | ||||
|  | ||||
| import mlx_tests | ||||
|  | ||||
|  | ||||
| class TestReduce(mlx_tests.MLXTestCase): | ||||
|     def test_axis_permutation_sums(self): | ||||
|         x_npy = np.random.randn(5, 5, 5, 5, 5).astype(np.float32) | ||||
|         x_mlx = mx.array(x_npy) | ||||
|         for t in permutations(range(5)): | ||||
|             with self.subTest(t=t): | ||||
|                 y_npy = np.transpose(x_npy, t) | ||||
|                 y_mlx = mx.transpose(x_mlx, t) | ||||
|                 for n in range(1, 6): | ||||
|                     for a in combinations(range(5), n): | ||||
|                         with self.subTest(a=a): | ||||
|                             z_npy = np.sum(y_npy, axis=a) | ||||
|                             z_mlx = mx.sum(y_mlx, axis=a) | ||||
|                             mx.eval(z_mlx) | ||||
|                             self.assertTrue( | ||||
|                                 np.allclose(z_npy, np.array(z_mlx), atol=1e-4) | ||||
|                             ) | ||||
|  | ||||
|     def test_expand_sums(self): | ||||
|         x_npy = np.random.randn(5, 1, 5, 1, 5, 1).astype(np.float32) | ||||
|         x_mlx = mx.array(x_npy) | ||||
|         for m in range(1, 4): | ||||
|             for ax in combinations([1, 3, 5], m): | ||||
|                 shape = np.array([5, 1, 5, 1, 5, 1]) | ||||
|                 shape[list(ax)] = 5 | ||||
|                 shape = shape.tolist() | ||||
|                 with self.subTest(shape=shape): | ||||
|                     y_npy = np.broadcast_to(x_npy, shape) | ||||
|                     y_mlx = mx.broadcast_to(x_mlx, shape) | ||||
|                     for n in range(1, 7): | ||||
|                         for a in combinations(range(6), n): | ||||
|                             with self.subTest(a=a): | ||||
|                                 z_npy = np.sum(y_npy, axis=a) / 1000 | ||||
|                                 z_mlx = mx.sum(y_mlx, axis=a) / 1000 | ||||
|                                 mx.eval(z_mlx) | ||||
|                                 self.assertTrue( | ||||
|                                     np.allclose(z_npy, np.array(z_mlx), atol=1e-4) | ||||
|                                 ) | ||||
|  | ||||
|     def test_dtypes(self): | ||||
|         int_dtypes = [ | ||||
|             "int8", | ||||
|             "int16", | ||||
|             "int32", | ||||
|             "uint8", | ||||
|             "uint16", | ||||
|             "uint32", | ||||
|         ] | ||||
|         float_dtypes = ["float32"] | ||||
|  | ||||
|         for dtype in int_dtypes + float_dtypes: | ||||
|             with self.subTest(dtype=dtype): | ||||
|                 x = np.random.uniform(0, 2, size=(3, 3, 3)).astype(getattr(np, dtype)) | ||||
|                 y = mx.array(x) | ||||
|  | ||||
|                 for op in ("sum", "prod", "min", "max"): | ||||
|                     with self.subTest(op=op): | ||||
|  | ||||
|                         np_op = getattr(np, op) | ||||
|                         mlx_op = getattr(mx, op) | ||||
|  | ||||
|                         for axes in (None, 0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)): | ||||
|                             with self.subTest(axes=axes): | ||||
|                                 if op in ("sum", "prod"): | ||||
|                                     r_np = np_op( | ||||
|                                         x, axis=axes, dtype=(getattr(np, dtype)) | ||||
|                                     ) | ||||
|                                 else: | ||||
|                                     r_np = np_op(x, axis=axes) | ||||
|                                 r_mlx = mlx_op(y, axis=axes) | ||||
|                                 mx.eval(r_mlx) | ||||
|                                 self.assertTrue(np.allclose(r_np, r_mlx, atol=1e-4)) | ||||
|  | ||||
|     def test_arg_reduce(self): | ||||
|         dtypes = [ | ||||
|             "uint8", | ||||
|             "uint16", | ||||
|             "uint32", | ||||
|             "uint64", | ||||
|             "int8", | ||||
|             "int16", | ||||
|             "int32", | ||||
|             "int64", | ||||
|             "float16", | ||||
|             "float32", | ||||
|         ] | ||||
|         for dtype in dtypes: | ||||
|             with self.subTest(dtype=dtype): | ||||
|  | ||||
|                 data = np.random.rand(10, 12, 13).astype(getattr(np, dtype)) | ||||
|                 x = mx.array(data) | ||||
|                 for op in ["argmin", "argmax"]: | ||||
|                     for axis in range(3): | ||||
|                         for kd in [True, False]: | ||||
|                             a = getattr(mx, op)(x, axis, kd) | ||||
|                             b = getattr(np, op)(data, axis, keepdims=kd) | ||||
|                             self.assertEqual(a.tolist(), b.tolist()) | ||||
|  | ||||
|                 for op in ["argmin", "argmax"]: | ||||
|                     a = getattr(mx, op)(x, keepdims=True) | ||||
|                     b = getattr(np, op)(data, keepdims=True) | ||||
|                     self.assertEqual(a.tolist(), b.tolist()) | ||||
|                     a = getattr(mx, op)(x) | ||||
|                     b = getattr(np, op)(data) | ||||
|                     self.assertEqual(a.item(), b) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main(failfast=True) | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun