Fix segfault from buffer protocol and tests (#383)

* Fix segfault from buffer protocol and tests

* Fix tf test
This commit is contained in:
Angelos Katharopoulos
2024-01-05 18:17:44 -08:00
committed by GitHub
parent 1331fa19f6
commit 4c48f6460d
3 changed files with 26 additions and 12 deletions

View File

@@ -52,10 +52,21 @@ class MLXTestCase(unittest.TestCase):
atol=1e-2,
rtol=1e-2,
):
assert tuple(mx_res.shape) == tuple(
expected.shape
), f"shape mismatch expected={expected.shape} got={mx_res.shape}"
assert (
mx_res.dtype == expected.dtype
), f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}"
np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
self.assertEqual(
tuple(mx_res.shape),
tuple(expected.shape),
msg=f"shape mismatch expected={expected.shape} got={mx_res.shape}",
)
self.assertEqual(
mx_res.dtype,
expected.dtype,
msg=f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}",
)
if not isinstance(mx_res, mx.array) and not isinstance(expected, mx.array):
np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
elif not isinstance(mx_res, mx.array):
mx_res = mx.array(mx_res)
self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))
elif not isinstance(expected, mx.array):
expected = mx.array(expected)
self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))

View File

@@ -1170,7 +1170,6 @@ class TestArray(mlx_tests.MLXTestCase):
f(a_np),
atol=0,
rtol=0,
msg=f"{mlx_dtype}{np_dtype}",
)
# extra test for bfloat16, which is not numpy convertible
@@ -1178,7 +1177,7 @@ class TestArray(mlx_tests.MLXTestCase):
mv_mx = memoryview(a_mx)
self.assertEqual(mv_mx.strides, (8, 2))
self.assertEqual(mv_mx.shape, (3, 4))
self.assertEqual(mv_mx.format, "")
self.assertEqual(mv_mx.format, "B")
with self.assertRaises(RuntimeError) as cm:
np.array(a_mx)
e = cm.exception
@@ -1265,7 +1264,6 @@ class TestArray(mlx_tests.MLXTestCase):
f(a_tf),
atol=0,
rtol=0,
msg=f"{mlx_dtype}{tf_dtype}",
)