Fix segfault from buffer protocol and tests (#383)

* Fix segfault from buffer protocol and tests

* Fix tf test
This commit is contained in:
Angelos Katharopoulos
2024-01-05 18:17:44 -08:00
committed by GitHub
parent 1331fa19f6
commit 4c48f6460d
3 changed files with 26 additions and 12 deletions

View File

@@ -1170,7 +1170,6 @@ class TestArray(mlx_tests.MLXTestCase):
f(a_np),
atol=0,
rtol=0,
msg=f"{mlx_dtype}{np_dtype}",
)
# extra test for bfloat16, which is not numpy convertible
@@ -1178,7 +1177,7 @@ class TestArray(mlx_tests.MLXTestCase):
mv_mx = memoryview(a_mx)
self.assertEqual(mv_mx.strides, (8, 2))
self.assertEqual(mv_mx.shape, (3, 4))
self.assertEqual(mv_mx.format, "")
self.assertEqual(mv_mx.format, "B")
with self.assertRaises(RuntimeError) as cm:
np.array(a_mx)
e = cm.exception
@@ -1265,7 +1264,6 @@ class TestArray(mlx_tests.MLXTestCase):
f(a_tf),
atol=0,
rtol=0,
msg=f"{mlx_dtype}{tf_dtype}",
)