mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Fix segfault from buffer protocol and tests (#383)
* Fix segfault from buffer protocol and tests * Fix tf test
This commit is contained in:
		 Angelos Katharopoulos
					Angelos Katharopoulos
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							1331fa19f6
						
					
				
				
					commit
					4c48f6460d
				
			| @@ -52,10 +52,21 @@ class MLXTestCase(unittest.TestCase): | ||||
|         atol=1e-2, | ||||
|         rtol=1e-2, | ||||
|     ): | ||||
|         assert tuple(mx_res.shape) == tuple( | ||||
|             expected.shape | ||||
|         ), f"shape mismatch expected={expected.shape} got={mx_res.shape}" | ||||
|         assert ( | ||||
|             mx_res.dtype == expected.dtype | ||||
|         ), f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}" | ||||
|         np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol) | ||||
|         self.assertEqual( | ||||
|             tuple(mx_res.shape), | ||||
|             tuple(expected.shape), | ||||
|             msg=f"shape mismatch expected={expected.shape} got={mx_res.shape}", | ||||
|         ) | ||||
|         self.assertEqual( | ||||
|             mx_res.dtype, | ||||
|             expected.dtype, | ||||
|             msg=f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}", | ||||
|         ) | ||||
|         if not isinstance(mx_res, mx.array) and not isinstance(expected, mx.array): | ||||
|             np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol) | ||||
|         elif not isinstance(mx_res, mx.array): | ||||
|             mx_res = mx.array(mx_res) | ||||
|             self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol)) | ||||
|         elif not isinstance(expected, mx.array): | ||||
|             expected = mx.array(expected) | ||||
|             self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol)) | ||||
|   | ||||
| @@ -1170,7 +1170,6 @@ class TestArray(mlx_tests.MLXTestCase): | ||||
|                     f(a_np), | ||||
|                     atol=0, | ||||
|                     rtol=0, | ||||
|                     msg=f"{mlx_dtype}{np_dtype}", | ||||
|                 ) | ||||
|  | ||||
|         # extra test for bfloat16, which is not numpy convertible | ||||
| @@ -1178,7 +1177,7 @@ class TestArray(mlx_tests.MLXTestCase): | ||||
|         mv_mx = memoryview(a_mx) | ||||
|         self.assertEqual(mv_mx.strides, (8, 2)) | ||||
|         self.assertEqual(mv_mx.shape, (3, 4)) | ||||
|         self.assertEqual(mv_mx.format, "") | ||||
|         self.assertEqual(mv_mx.format, "B") | ||||
|         with self.assertRaises(RuntimeError) as cm: | ||||
|             np.array(a_mx) | ||||
|         e = cm.exception | ||||
| @@ -1265,7 +1264,6 @@ class TestArray(mlx_tests.MLXTestCase): | ||||
|                     f(a_tf), | ||||
|                     atol=0, | ||||
|                     rtol=0, | ||||
|                     msg=f"{mlx_dtype}{tf_dtype}", | ||||
|                 ) | ||||
|  | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user