mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	| @@ -2333,6 +2333,30 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|         out_np = a.conj() | ||||
|         self.assertTrue(np.array_equal(np.array(out_mlx), out_np)) | ||||
|  | ||||
|     def test_view(self): | ||||
|         a = mx.random.randint(shape=(4, 2, 4), low=-100, high=100) | ||||
|         a_np = np.array(a) | ||||
|  | ||||
|         for t in ["bool_", "int16", "float32", "int64"]: | ||||
|             out = a.view(getattr(mx, t)) | ||||
|             expected = a_np.view(getattr(np, t)) | ||||
|             self.assertTrue(np.array_equal(out, expected, equal_nan=True)) | ||||
|  | ||||
|         # Irregular strides | ||||
|         a = mx.random.randint(shape=(2, 4), low=-100, high=100) | ||||
|         a = mx.broadcast_to(a, shape=(4, 2, 4)) | ||||
|  | ||||
|         for t in ["bool_", "int16", "float32", "int64"]: | ||||
|             out = a.view(getattr(mx, t)) | ||||
|             a_out = out.view(mx.int32) | ||||
|             self.assertTrue(mx.array_equal(a_out, a, equal_nan=True)) | ||||
|  | ||||
|         a = mx.random.randint(shape=(4, 4), low=-100, high=100).T | ||||
|         for t in ["bool_", "int16", "float32", "int64"]: | ||||
|             out = a.view(getattr(mx, t)) | ||||
|             a_out = out.view(mx.int32) | ||||
|             self.assertTrue(mx.array_equal(a_out, a, equal_nan=True)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun