mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Fix reduce edge case (#1389)
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							9592766939
						
					
				
				
					commit
					969337345f
				
			| @@ -10,21 +10,21 @@ import numpy as np | ||||
|  | ||||
| class TestReduce(mlx_tests.MLXTestCase): | ||||
|     def test_axis_permutation_sums(self): | ||||
|         x_npy = np.random.randn(5, 5, 5, 5, 5).astype(np.float32) | ||||
|         x_mlx = mx.array(x_npy) | ||||
|         for t in permutations(range(5)): | ||||
|             with self.subTest(t=t): | ||||
|                 y_npy = np.transpose(x_npy, t) | ||||
|                 y_mlx = mx.transpose(x_mlx, t) | ||||
|                 for n in range(1, 6): | ||||
|                     for a in combinations(range(5), n): | ||||
|                         with self.subTest(a=a): | ||||
|                             z_npy = np.sum(y_npy, axis=a) | ||||
|                             z_mlx = mx.sum(y_mlx, axis=a) | ||||
|                             mx.eval(z_mlx) | ||||
|                             self.assertTrue( | ||||
|                                 np.allclose(z_npy, np.array(z_mlx), atol=1e-4) | ||||
|                             ) | ||||
|         for shape in [(5, 5, 1, 5, 5), (65, 65, 1, 65)]: | ||||
|             with self.subTest(shape=shape): | ||||
|                 x_npy = (np.random.randn(*shape) * 128).astype(np.int32) | ||||
|                 x_mlx = mx.array(x_npy) | ||||
|                 for t in permutations(range(len(shape))): | ||||
|                     with self.subTest(t=t): | ||||
|                         y_npy = np.transpose(x_npy, t) | ||||
|                         y_mlx = mx.transpose(x_mlx, t) | ||||
|                         for n in range(1, len(shape) + 1): | ||||
|                             for a in combinations(range(len(shape)), n): | ||||
|                                 with self.subTest(a=a): | ||||
|                                     z_npy = np.sum(y_npy, axis=a) | ||||
|                                     z_mlx = mx.sum(y_mlx, axis=a) | ||||
|                                     mx.eval(z_mlx) | ||||
|                                     self.assertTrue(np.all(z_npy == z_mlx)) | ||||
|  | ||||
|     def test_expand_sums(self): | ||||
|         x_npy = np.random.randn(5, 1, 5, 1, 5, 1).astype(np.float32) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user