mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	added atleast *args input support (#710)
* added atleast list(array) input support * function overloading implemented * Refactoring * fixed formatting * removed pos_only
This commit is contained in:
		 Hinrik Snær Guðmundsson
					Hinrik Snær Guðmundsson
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							3b661b7394
						
					
				
				
					commit
					08226ab491
				
			| @@ -1932,12 +1932,16 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|             [[[[1]], [[2]], [[3]]]], | ||||
|         ] | ||||
|  | ||||
|         for array in arrays: | ||||
|         mx_arrays = [mx.atleast_1d(mx.array(x)) for x in arrays] | ||||
|         atleast_arrays = mx.atleast_1d(*mx_arrays) | ||||
|  | ||||
|         for i, array in enumerate(arrays): | ||||
|             mx_res = mx.atleast_1d(mx.array(array)) | ||||
|             np_res = np.atleast_1d(np.array(array)) | ||||
|             self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) | ||||
|             self.assertEqual(mx_res.shape, np_res.shape) | ||||
|             self.assertEqual(mx_res.ndim, np_res.ndim) | ||||
|             self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) | ||||
|  | ||||
|     def test_atleast_2d(self): | ||||
|         def compare_nested_lists(x, y): | ||||
| @@ -1962,12 +1966,16 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|             [[[[1]], [[2]], [[3]]]], | ||||
|         ] | ||||
|  | ||||
|         for array in arrays: | ||||
|         mx_arrays = [mx.atleast_2d(mx.array(x)) for x in arrays] | ||||
|         atleast_arrays = mx.atleast_2d(*mx_arrays) | ||||
|  | ||||
|         for i, array in enumerate(arrays): | ||||
|             mx_res = mx.atleast_2d(mx.array(array)) | ||||
|             np_res = np.atleast_2d(np.array(array)) | ||||
|             self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) | ||||
|             self.assertEqual(mx_res.shape, np_res.shape) | ||||
|             self.assertEqual(mx_res.ndim, np_res.ndim) | ||||
|             self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) | ||||
|  | ||||
|     def test_atleast_3d(self): | ||||
|         def compare_nested_lists(x, y): | ||||
| @@ -1992,12 +2000,16 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|             [[[[1]], [[2]], [[3]]]], | ||||
|         ] | ||||
|  | ||||
|         for array in arrays: | ||||
|         mx_arrays = [mx.atleast_3d(mx.array(x)) for x in arrays] | ||||
|         atleast_arrays = mx.atleast_3d(*mx_arrays) | ||||
|  | ||||
|         for i, array in enumerate(arrays): | ||||
|             mx_res = mx.atleast_3d(mx.array(array)) | ||||
|             np_res = np.atleast_3d(np.array(array)) | ||||
|             self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) | ||||
|             self.assertEqual(mx_res.shape, np_res.shape) | ||||
|             self.assertEqual(mx_res.ndim, np_res.ndim) | ||||
|             self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|   | ||||
		Reference in New Issue
	
	Block a user