Fix segfault from buffer protocol and tests (#383)

* Fix segfault from buffer protocol and tests

* Fix tf test
This commit is contained in:
Angelos Katharopoulos
2024-01-05 18:17:44 -08:00
committed by GitHub
parent 1331fa19f6
commit 4c48f6460d
3 changed files with 26 additions and 12 deletions

View File

@@ -52,10 +52,21 @@ class MLXTestCase(unittest.TestCase):
atol=1e-2,
rtol=1e-2,
):
assert tuple(mx_res.shape) == tuple(
expected.shape
), f"shape mismatch expected={expected.shape} got={mx_res.shape}"
assert (
mx_res.dtype == expected.dtype
), f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}"
np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
self.assertEqual(
tuple(mx_res.shape),
tuple(expected.shape),
msg=f"shape mismatch expected={expected.shape} got={mx_res.shape}",
)
self.assertEqual(
mx_res.dtype,
expected.dtype,
msg=f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}",
)
if not isinstance(mx_res, mx.array) and not isinstance(expected, mx.array):
np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
elif not isinstance(mx_res, mx.array):
mx_res = mx.array(mx_res)
self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))
elif not isinstance(expected, mx.array):
expected = mx.array(expected)
self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))