mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:29:35 +08:00
Fix segfault from buffer protocol and tests (#383)
* Fix segfault from buffer protocol and tests * Fix tf test
This commit is contained in:

committed by
GitHub

parent
1331fa19f6
commit
4c48f6460d
@@ -52,10 +52,21 @@ class MLXTestCase(unittest.TestCase):
|
||||
atol=1e-2,
|
||||
rtol=1e-2,
|
||||
):
|
||||
assert tuple(mx_res.shape) == tuple(
|
||||
expected.shape
|
||||
), f"shape mismatch expected={expected.shape} got={mx_res.shape}"
|
||||
assert (
|
||||
mx_res.dtype == expected.dtype
|
||||
), f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}"
|
||||
np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
|
||||
self.assertEqual(
|
||||
tuple(mx_res.shape),
|
||||
tuple(expected.shape),
|
||||
msg=f"shape mismatch expected={expected.shape} got={mx_res.shape}",
|
||||
)
|
||||
self.assertEqual(
|
||||
mx_res.dtype,
|
||||
expected.dtype,
|
||||
msg=f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}",
|
||||
)
|
||||
if not isinstance(mx_res, mx.array) and not isinstance(expected, mx.array):
|
||||
np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
|
||||
elif not isinstance(mx_res, mx.array):
|
||||
mx_res = mx.array(mx_res)
|
||||
self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))
|
||||
elif not isinstance(expected, mx.array):
|
||||
expected = mx.array(expected)
|
||||
self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))
|
||||
|
Reference in New Issue
Block a user