mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	simple numpy helper for tests (#352)
This commit is contained in:
		| @@ -2,6 +2,7 @@ | ||||
|  | ||||
| import os | ||||
| import unittest | ||||
| from typing import Any, Callable, List, Tuple | ||||
|  | ||||
| import mlx.core as mx | ||||
| import numpy as np | ||||
| @@ -18,13 +19,33 @@ class MLXTestCase(unittest.TestCase): | ||||
|     def tearDown(self): | ||||
|         mx.set_default_device(self.default) | ||||
|  | ||||
|     def assertCmpNumpy( | ||||
|         self, | ||||
|         shape: List[Tuple[int] | Any], | ||||
|         mx_fn: Callable[..., mx.array], | ||||
|         np_fn: Callable[..., np.array], | ||||
|         atol=1e-2, | ||||
|         rtol=1e-2, | ||||
|         dtype=mx.float32, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         assert dtype != mx.bfloat16, "numpy does not support bfloat16" | ||||
|         args = [ | ||||
|             mx.random.normal(s, dtype=dtype) if isinstance(s, Tuple) else s | ||||
|             for s in shape | ||||
|         ] | ||||
|         mx_res = mx_fn(*args, **kwargs) | ||||
|         np_res = np_fn( | ||||
|             *[np.array(a) if isinstance(a, mx.array) else a for a in args], **kwargs | ||||
|         ) | ||||
|         return self.assertEqualArray(mx_res, mx.array(np_res), atol=atol, rtol=rtol) | ||||
|  | ||||
|     def assertEqualArray( | ||||
|         self, | ||||
|         mx_res: mx.array, | ||||
|         expected: mx.array, | ||||
|         atol=1e-2, | ||||
|         rtol=1e-2, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         assert tuple(mx_res.shape) == tuple( | ||||
|             expected.shape | ||||
|   | ||||
| @@ -324,24 +324,18 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|     def test_tri(self): | ||||
|         for shape in [[4], [4, 4], [2, 10]]: | ||||
|             for diag in [-1, 0, 1, -2]: | ||||
|                 self.assertEqualArray( | ||||
|                     mx.tri(*shape, k=diag), mx.array(np.tri(*shape, k=diag)) | ||||
|                 ) | ||||
|                 self.assertCmpNumpy(shape, mx.tri, np.tri, k=diag) | ||||
|  | ||||
|     def test_tril(self): | ||||
|         mt = mx.random.normal((10, 10)) | ||||
|         nt = np.array(mt) | ||||
|         for diag in [-1, 0, 1, -2]: | ||||
|             self.assertEqualArray(mx.tril(mt, diag), mx.array(np.tril(nt, diag))) | ||||
|             self.assertCmpNumpy([(10, 10)], mx.tril, np.tril, k=diag) | ||||
|  | ||||
|         with self.assertRaises(Exception): | ||||
|             mx.tril(mx.zeros((1))) | ||||
|  | ||||
|     def test_triu(self): | ||||
|         mt = mx.random.normal((10, 10)) | ||||
|         nt = np.array(mt) | ||||
|         for diag in [-1, 0, 1, -2]: | ||||
|             self.assertEqualArray(mx.triu(mt, diag), mx.array(np.triu(nt, diag))) | ||||
|             self.assertCmpNumpy([(10, 10)], mx.triu, np.triu, k=diag) | ||||
|         with self.assertRaises(Exception): | ||||
|             mx.triu(mx.zeros((1))) | ||||
|  | ||||
| @@ -1260,20 +1254,17 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|         self.assertTrue(mx.allclose(a_bwd[4:-2, 2:-4], df[0]).item()) | ||||
|  | ||||
|     def test_where(self): | ||||
|         a = mx.array([[1, 2], [3, 4]]) | ||||
|         out = mx.where(True, a, 1) | ||||
|         out_np = np.where(True, a, 1) | ||||
|         self.assertTrue(np.array_equal(out, out_np)) | ||||
|  | ||||
|         out = mx.where(True, 1, a) | ||||
|         out_np = np.where(True, 1, a) | ||||
|         self.assertTrue(np.array_equal(out, out_np)) | ||||
|  | ||||
|         condition = mx.array([[True, False], [False, True]]) | ||||
|         b = mx.array([5, 6]) | ||||
|         out = mx.where(condition, a, b) | ||||
|         out_np = np.where(condition, a, b) | ||||
|         self.assertTrue(np.array_equal(out, out_np)) | ||||
|         self.assertCmpNumpy([True, mx.array([[1, 2], [3, 4]]), 1], mx.where, np.where) | ||||
|         self.assertCmpNumpy([True, 1, mx.array([[1, 2], [3, 4]])], mx.where, np.where) | ||||
|         self.assertCmpNumpy( | ||||
|             [ | ||||
|                 mx.array([[True, False], [False, True]]), | ||||
|                 mx.array([[1, 2], [3, 4]]), | ||||
|                 mx.array([5, 6]), | ||||
|             ], | ||||
|             mx.where, | ||||
|             np.where, | ||||
|         ) | ||||
|  | ||||
|     def test_as_strided(self): | ||||
|         x_npy = np.random.randn(128).astype(np.float32) | ||||
| @@ -1408,24 +1399,13 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|         self.assertEqual((a + b)[0, 0].item(), 2) | ||||
|  | ||||
|     def test_eye(self): | ||||
|         eye_matrix = mx.eye(3) | ||||
|         np_eye_matrix = np.eye(3) | ||||
|         self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix)) | ||||
|  | ||||
|         self.assertCmpNumpy([3], mx.eye, np.eye) | ||||
|         # Test for non-square matrix | ||||
|         eye_matrix = mx.eye(3, 4) | ||||
|         np_eye_matrix = np.eye(3, 4) | ||||
|         self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix)) | ||||
|  | ||||
|         self.assertCmpNumpy([3, 4], mx.eye, np.eye) | ||||
|         # Test with positive k parameter | ||||
|         eye_matrix = mx.eye(3, 4, k=1) | ||||
|         np_eye_matrix = np.eye(3, 4, k=1) | ||||
|         self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix)) | ||||
|  | ||||
|         self.assertCmpNumpy([3, 4], mx.eye, np.eye, k=1) | ||||
|         # Test with negative k parameter | ||||
|         eye_matrix = mx.eye(5, 6, k=-2) | ||||
|         np_eye_matrix = np.eye(5, 6, k=-2) | ||||
|         self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix)) | ||||
|         self.assertCmpNumpy([5, 6], mx.eye, np.eye, k=-2) | ||||
|  | ||||
|     def test_stack(self): | ||||
|         a = mx.ones((2,)) | ||||
| @@ -1518,50 +1498,47 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|  | ||||
|     def test_repeat(self): | ||||
|         # Setup data for the tests | ||||
|         data = np.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]]) | ||||
|         data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]]) | ||||
|         # Test repeat along axis 0 | ||||
|         repeat_axis_0 = mx.repeat(mx.array(data), 2, axis=0) | ||||
|         expected_axis_0 = np.repeat(data, 2, axis=0) | ||||
|  | ||||
|         self.assertEqualArray(repeat_axis_0, mx.array(expected_axis_0)) | ||||
|  | ||||
|         self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=0) | ||||
|         # Test repeat along axis 1 | ||||
|         repeat_axis_1 = mx.repeat(mx.array(data), 2, axis=1) | ||||
|         expected_axis_1 = np.repeat(data, 2, axis=1) | ||||
|         self.assertEqualArray(repeat_axis_1, mx.array(expected_axis_1)) | ||||
|  | ||||
|         self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=1) | ||||
|         # Test repeat along the last axis (default) | ||||
|         repeat_axis_2 = mx.repeat(mx.array(data), 2) | ||||
|         expected_axis_2 = np.repeat(data, 2) | ||||
|         self.assertEqualArray(repeat_axis_2, mx.array(expected_axis_2)) | ||||
|  | ||||
|         self.assertCmpNumpy([data, 2], mx.repeat, np.repeat) | ||||
|         # Test repeat with a 1D array along axis 0 | ||||
|         data_2 = mx.array([1, 3, 2]) | ||||
|         repeat_2 = mx.repeat(mx.array(data_2), 3, axis=0) | ||||
|         expected_2 = np.repeat(data_2, 3) | ||||
|         self.assertEqualArray(repeat_2, mx.array(expected_2)) | ||||
|  | ||||
|         self.assertCmpNumpy([mx.array([1, 3, 2]), 3], mx.repeat, np.repeat, axis=0) | ||||
|         # Test repeat with a 2D array along axis 0 | ||||
|         data_3 = mx.array([[1, 2, 3], [4, 5, 4], [0, 1, 2]]) | ||||
|         repeat_3 = mx.repeat(mx.array(data_3), 2, axis=0) | ||||
|         expected_3 = np.repeat(data_3, 2, axis=0) | ||||
|         self.assertEqualArray(repeat_3, mx.array(expected_3)) | ||||
|         self.assertCmpNumpy( | ||||
|             [mx.array([[1, 2, 3], [4, 5, 4], [0, 1, 2]]), 2], | ||||
|             mx.repeat, | ||||
|             np.repeat, | ||||
|             axis=0, | ||||
|         ) | ||||
|  | ||||
|     def test_tensordot(self): | ||||
|         x = mx.arange(60.0).reshape(3, 4, 5) | ||||
|         y = mx.arange(24.0).reshape(4, 3, 2) | ||||
|         z = mx.tensordot(x, y, dims=([1, 0], [0, 1])) | ||||
|         self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=([1, 0], [0, 1])))) | ||||
|         x = mx.random.normal((3, 4, 5)) | ||||
|         y = mx.random.normal((4, 5, 6)) | ||||
|         z = mx.tensordot(x, y, dims=2) | ||||
|         self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=2))) | ||||
|         x = mx.random.normal((3, 5, 4, 6)) | ||||
|         y = mx.random.normal((6, 4, 5, 3)) | ||||
|         z = mx.tensordot(x, y, dims=([2, 1, 3], [1, 2, 0])) | ||||
|         self.assertEqualArray( | ||||
|             z, mx.array(np.tensordot(x, y, axes=([2, 1, 3], [1, 2, 0]))) | ||||
|         ) | ||||
|         for dtype in [mx.float16, mx.float32]: | ||||
|             with self.subTest(dtype=dtype): | ||||
|                 self.assertCmpNumpy( | ||||
|                     [(3, 4, 5), (4, 3, 2)], | ||||
|                     mx.tensordot, | ||||
|                     lambda x, y, dims: np.tensordot(x, y, axes=dims), | ||||
|                     dtype=dtype, | ||||
|                     dims=([1, 0], [0, 1]), | ||||
|                 ) | ||||
|                 self.assertCmpNumpy( | ||||
|                     [(3, 4, 5), (4, 5, 6)], | ||||
|                     mx.tensordot, | ||||
|                     lambda x, y, dims: np.tensordot(x, y, axes=dims), | ||||
|                     dtype=dtype, | ||||
|                     dims=2, | ||||
|                 ) | ||||
|                 self.assertCmpNumpy( | ||||
|                     [(3, 5, 4, 6), (6, 4, 5, 3)], | ||||
|                     mx.tensordot, | ||||
|                     lambda x, y, dims: np.tensordot(x, y, axes=dims), | ||||
|                     dtype=dtype, | ||||
|                     dims=([2, 1, 3], [1, 2, 0]), | ||||
|                 ) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Diogo
					Diogo