Fix segfault from buffer protocol and tests (#383)

* Fix segfault from buffer protocol and tests

* Fix tf test
This commit is contained in:
Angelos Katharopoulos 2024-01-05 18:17:44 -08:00 committed by GitHub
parent 1331fa19f6
commit 4c48f6460d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 12 deletions

View File

@ -208,6 +208,7 @@ using array_init_type = std::variant<
std::complex<float>, std::complex<float>,
py::list, py::list,
py::tuple, py::tuple,
array,
py::array, py::array,
py::buffer, py::buffer,
py::object>; py::object>;
@ -410,8 +411,9 @@ std::optional<std::string> buffer_format(const array& a) {
} }
case bfloat16: case bfloat16:
// not supported by python buffer protocol or numpy. // not supported by python buffer protocol or numpy.
// musst be null according to // must be null according to
// https://docs.python.org/3.10/c-api/buffer.html#c.PyBUF_FORMAT // https://docs.python.org/3.10/c-api/buffer.html#c.PyBUF_FORMAT
// which implies 'B'.
return {}; return {};
case complex64: case complex64:
return pybind11::format_descriptor<std::complex<float>>::format(); return pybind11::format_descriptor<std::complex<float>>::format();
@ -449,6 +451,8 @@ array create_array(array_init_type v, std::optional<Dtype> t) {
return array_from_list(*pv, t); return array_from_list(*pv, t);
} else if (auto pv = std::get_if<py::tuple>(&v); pv) { } else if (auto pv = std::get_if<py::tuple>(&v); pv) {
return array_from_list(*pv, t); return array_from_list(*pv, t);
} else if (auto pv = std::get_if<array>(&v); pv) {
return astype(*pv, t.value_or((*pv).dtype()));
} else if (auto pv = std::get_if<py::array>(&v); pv) { } else if (auto pv = std::get_if<py::array>(&v); pv) {
return np_array_to_mlx(*pv, t); return np_array_to_mlx(*pv, t);
} else if (auto pv = std::get_if<py::buffer>(&v); pv) { } else if (auto pv = std::get_if<py::buffer>(&v); pv) {
@ -528,7 +532,8 @@ void init_array(py::module_& m) {
return pybind11::buffer_info( return pybind11::buffer_info(
a.data<void>(), a.data<void>(),
a.itemsize(), a.itemsize(),
buffer_format(a).value_or(nullptr), buffer_format(a).value_or("B"), // we use "B" because pybind uses a
// std::string which can't be null
a.ndim(), a.ndim(),
a.shape(), a.shape(),
buffer_strides(a)); buffer_strides(a));

View File

@ -52,10 +52,21 @@ class MLXTestCase(unittest.TestCase):
atol=1e-2, atol=1e-2,
rtol=1e-2, rtol=1e-2,
): ):
assert tuple(mx_res.shape) == tuple( self.assertEqual(
expected.shape tuple(mx_res.shape),
), f"shape mismatch expected={expected.shape} got={mx_res.shape}" tuple(expected.shape),
assert ( msg=f"shape mismatch expected={expected.shape} got={mx_res.shape}",
mx_res.dtype == expected.dtype )
), f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}" self.assertEqual(
mx_res.dtype,
expected.dtype,
msg=f"dtype mismatch expected={expected.dtype} got={mx_res.dtype}",
)
if not isinstance(mx_res, mx.array) and not isinstance(expected, mx.array):
np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol) np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
elif not isinstance(mx_res, mx.array):
mx_res = mx.array(mx_res)
self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))
elif not isinstance(expected, mx.array):
expected = mx.array(expected)
self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))

View File

@ -1170,7 +1170,6 @@ class TestArray(mlx_tests.MLXTestCase):
f(a_np), f(a_np),
atol=0, atol=0,
rtol=0, rtol=0,
msg=f"{mlx_dtype}{np_dtype}",
) )
# extra test for bfloat16, which is not numpy convertible # extra test for bfloat16, which is not numpy convertible
@ -1178,7 +1177,7 @@ class TestArray(mlx_tests.MLXTestCase):
mv_mx = memoryview(a_mx) mv_mx = memoryview(a_mx)
self.assertEqual(mv_mx.strides, (8, 2)) self.assertEqual(mv_mx.strides, (8, 2))
self.assertEqual(mv_mx.shape, (3, 4)) self.assertEqual(mv_mx.shape, (3, 4))
self.assertEqual(mv_mx.format, "") self.assertEqual(mv_mx.format, "B")
with self.assertRaises(RuntimeError) as cm: with self.assertRaises(RuntimeError) as cm:
np.array(a_mx) np.array(a_mx)
e = cm.exception e = cm.exception
@ -1265,7 +1264,6 @@ class TestArray(mlx_tests.MLXTestCase):
f(a_tf), f(a_tf),
atol=0, atol=0,
rtol=0, rtol=0,
msg=f"{mlx_dtype}{tf_dtype}",
) )