mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Check edge case handling in row reduce med kernel (#858)
This commit is contained in:
		@@ -590,6 +590,18 @@ class TestOps(mlx_tests.MLXTestCase):
 | 
			
		||||
            self.assertListEqual(list(sum_npy.shape), list(sum_mlx.shape))
 | 
			
		||||
            self.assertTrue(np.all(sum_npy == sum_mlx))
 | 
			
		||||
 | 
			
		||||
        x_npy = (
 | 
			
		||||
            np.arange(3 * 2 * 3 * 3 * 3 * 3)
 | 
			
		||||
            .reshape(3, 2, 3, 3, 3, 3)
 | 
			
		||||
            .astype(np.float32)
 | 
			
		||||
        )
 | 
			
		||||
        x_mlx = mx.array(x_npy)
 | 
			
		||||
 | 
			
		||||
        y_mlx = x_mlx.sum(axis=(0, 1, 3, 4, 5))
 | 
			
		||||
        y_npy = x_npy.sum(axis=(0, 1, 3, 4, 5))
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(np.array_equal(y_mlx, y_npy))
 | 
			
		||||
 | 
			
		||||
    def test_prod(self):
 | 
			
		||||
        x = mx.array(
 | 
			
		||||
            [
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user