mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	| @@ -2177,6 +2177,38 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|                     f"mx and np don't aggree on {a}, {b}", | ||||
|                 ) | ||||
|  | ||||
|     def test_bitwise_ops(self): | ||||
|         types = [ | ||||
|             mx.uint8, | ||||
|             mx.uint16, | ||||
|             mx.uint32, | ||||
|             mx.uint64, | ||||
|             mx.int8, | ||||
|             mx.int16, | ||||
|             mx.int32, | ||||
|             mx.int64, | ||||
|         ] | ||||
|         a = mx.random.randint(0, 4096, (1000,)) | ||||
|         b = mx.random.randint(0, 4096, (1000,)) | ||||
|         for op in ["bitwise_and", "bitwise_or", "bitwise_xor"]: | ||||
|             for t in types: | ||||
|                 a_mlx = a.astype(t) | ||||
|                 b_mlx = b.astype(t) | ||||
|                 a_np = np.array(a_mlx) | ||||
|                 b_np = np.array(b_mlx) | ||||
|                 out_mlx = getattr(mx, op)(a_mlx, b_mlx) | ||||
|                 out_np = getattr(np, op)(a_np, b_np) | ||||
|                 self.assertTrue(np.array_equal(np.array(out_mlx), out_np)) | ||||
|         for op in ["left_shift", "right_shift"]: | ||||
|             for t in types: | ||||
|                 a_mlx = a.astype(t) | ||||
|                 b_mlx = mx.random.randint(0, t.size, (1000,)).astype(t) | ||||
|                 a_np = np.array(a_mlx) | ||||
|                 b_np = np.array(b_mlx) | ||||
|                 out_mlx = getattr(mx, op)(a_mlx, b_mlx) | ||||
|                 out_np = getattr(np, op)(a_np, b_np) | ||||
|                 self.assertTrue(np.array_equal(np.array(out_mlx), out_np)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun