mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 18:48:15 +08:00 
			
		
		
		
	Working 64-bit scans (#1506)
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							32972a5924
						
					
				
				
					commit
					c9b41d460f
				
			@@ -1758,6 +1758,18 @@ class TestOps(mlx_tests.MLXTestCase):
 | 
			
		||||
                c_mlx = mxop(a_mlx, axis=axis)
 | 
			
		||||
                self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-3, atol=1e-3))
 | 
			
		||||
 | 
			
		||||
        a_mlx = mx.random.randint(shape=(32, 32, 32), low=-100, high=100)
 | 
			
		||||
        for dt in [mx.int32, mx.int64]:
 | 
			
		||||
            mxx = a_mlx.astype(dt)
 | 
			
		||||
            npx = np.array(mxx)
 | 
			
		||||
            for op in ["cumsum", "cumprod"]:
 | 
			
		||||
                npop = getattr(np, op)
 | 
			
		||||
                mxop = getattr(mx, op)
 | 
			
		||||
                for axis in (None, 0, 1, 2):
 | 
			
		||||
                    c_npy = npop(npx, axis=axis, dtype=npx.dtype)
 | 
			
		||||
                    c_mlx = mxop(mxx, axis=axis)
 | 
			
		||||
                    self.assertTrue(np.array_equal(c_npy, c_mlx))
 | 
			
		||||
 | 
			
		||||
        a_mlx = mx.random.randint(shape=(32, 32, 32), low=-100, high=100)
 | 
			
		||||
        for op in ["cumsum", "cumprod", "cummax", "cummin"]:
 | 
			
		||||
            mxop = getattr(mx, op)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user