mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
Fix segfault from buffer protocol and tests (#383)
* Fix segfault from buffer protocol and tests * Fix tf test
This commit is contained in:
parent
1331fa19f6
commit
4c48f6460d
@ -208,6 +208,7 @@ using array_init_type = std::variant<
|
|||||||
std::complex<float>,
|
std::complex<float>,
|
||||||
py::list,
|
py::list,
|
||||||
py::tuple,
|
py::tuple,
|
||||||
|
array,
|
||||||
py::array,
|
py::array,
|
||||||
py::buffer,
|
py::buffer,
|
||||||
py::object>;
|
py::object>;
|
||||||
@ -410,8 +411,9 @@ std::optional<std::string> buffer_format(const array& a) {
|
|||||||
}
|
}
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
// not supported by python buffer protocol or numpy.
|
// not supported by python buffer protocol or numpy.
|
||||||
// musst be null according to
|
// must be null according to
|
||||||
// https://docs.python.org/3.10/c-api/buffer.html#c.PyBUF_FORMAT
|
// https://docs.python.org/3.10/c-api/buffer.html#c.PyBUF_FORMAT
|
||||||
|
// which implies 'B'.
|
||||||
return {};
|
return {};
|
||||||
case complex64:
|
case complex64:
|
||||||
return pybind11::format_descriptor<std::complex<float>>::format();
|
return pybind11::format_descriptor<std::complex<float>>::format();
|
||||||
@ -449,6 +451,8 @@ array create_array(array_init_type v, std::optional<Dtype> t) {
|
|||||||
return array_from_list(*pv, t);
|
return array_from_list(*pv, t);
|
||||||
} else if (auto pv = std::get_if<py::tuple>(&v); pv) {
|
} else if (auto pv = std::get_if<py::tuple>(&v); pv) {
|
||||||
return array_from_list(*pv, t);
|
return array_from_list(*pv, t);
|
||||||
|
} else if (auto pv = std::get_if<array>(&v); pv) {
|
||||||
|
return astype(*pv, t.value_or((*pv).dtype()));
|
||||||
} else if (auto pv = std::get_if<py::array>(&v); pv) {
|
} else if (auto pv = std::get_if<py::array>(&v); pv) {
|
||||||
return np_array_to_mlx(*pv, t);
|
return np_array_to_mlx(*pv, t);
|
||||||
} else if (auto pv = std::get_if<py::buffer>(&v); pv) {
|
} else if (auto pv = std::get_if<py::buffer>(&v); pv) {
|
||||||
@ -528,7 +532,8 @@ void init_array(py::module_& m) {
|
|||||||
return pybind11::buffer_info(
|
return pybind11::buffer_info(
|
||||||
a.data<void>(),
|
a.data<void>(),
|
||||||
a.itemsize(),
|
a.itemsize(),
|
||||||
buffer_format(a).value_or(nullptr),
|
buffer_format(a).value_or("B"), // we use "B" because pybind uses a
|
||||||
|
// std::string which can't be null
|
||||||
a.ndim(),
|
a.ndim(),
|
||||||
a.shape(),
|
a.shape(),
|
||||||
buffer_strides(a));
|
buffer_strides(a));
|
||||||
|
@ -52,10 +52,21 @@ class MLXTestCase(unittest.TestCase):
|
|||||||
atol=1e-2,
|
atol=1e-2,
|
||||||
rtol=1e-2,
|
rtol=1e-2,
|
||||||
):
|
):
|
||||||
assert tuple(mx_res.shape) == tuple(
|
self.assertEqual(
|
||||||
expected.shape
|
tuple(mx_res.shape),
|
||||||
), f"shape mismatch expected={expected.shape} got={mx_res.shape}"
|
tuple(expected.shape),
|
||||||
assert (
|
msg=f"shape mismatch expected={expected.shape} got={mx_res.shape}",
|
||||||
mx_res.dtype == expected.dtype
|
)
|
||||||
), f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}"
|
self.assertEqual(
|
||||||
np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
|
mx_res.dtype,
|
||||||
|
expected.dtype,
|
||||||
|
msg=f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}",
|
||||||
|
)
|
||||||
|
if not isinstance(mx_res, mx.array) and not isinstance(expected, mx.array):
|
||||||
|
np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
|
||||||
|
elif not isinstance(mx_res, mx.array):
|
||||||
|
mx_res = mx.array(mx_res)
|
||||||
|
self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))
|
||||||
|
elif not isinstance(expected, mx.array):
|
||||||
|
expected = mx.array(expected)
|
||||||
|
self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))
|
||||||
|
@ -1170,7 +1170,6 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
f(a_np),
|
f(a_np),
|
||||||
atol=0,
|
atol=0,
|
||||||
rtol=0,
|
rtol=0,
|
||||||
msg=f"{mlx_dtype}{np_dtype}",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# extra test for bfloat16, which is not numpy convertible
|
# extra test for bfloat16, which is not numpy convertible
|
||||||
@ -1178,7 +1177,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
mv_mx = memoryview(a_mx)
|
mv_mx = memoryview(a_mx)
|
||||||
self.assertEqual(mv_mx.strides, (8, 2))
|
self.assertEqual(mv_mx.strides, (8, 2))
|
||||||
self.assertEqual(mv_mx.shape, (3, 4))
|
self.assertEqual(mv_mx.shape, (3, 4))
|
||||||
self.assertEqual(mv_mx.format, "")
|
self.assertEqual(mv_mx.format, "B")
|
||||||
with self.assertRaises(RuntimeError) as cm:
|
with self.assertRaises(RuntimeError) as cm:
|
||||||
np.array(a_mx)
|
np.array(a_mx)
|
||||||
e = cm.exception
|
e = cm.exception
|
||||||
@ -1265,7 +1264,6 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
f(a_tf),
|
f(a_tf),
|
||||||
atol=0,
|
atol=0,
|
||||||
rtol=0,
|
rtol=0,
|
||||||
msg=f"{mlx_dtype}{tf_dtype}",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user