mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
Fix segfault from buffer protocol and tests (#383)
* Fix segfault from buffer protocol and tests * Fix tf test
This commit is contained in:

committed by
GitHub

parent
1331fa19f6
commit
4c48f6460d
@@ -208,6 +208,7 @@ using array_init_type = std::variant<
|
||||
std::complex<float>,
|
||||
py::list,
|
||||
py::tuple,
|
||||
array,
|
||||
py::array,
|
||||
py::buffer,
|
||||
py::object>;
|
||||
@@ -410,8 +411,9 @@ std::optional<std::string> buffer_format(const array& a) {
|
||||
}
|
||||
case bfloat16:
|
||||
// not supported by python buffer protocol or numpy.
|
||||
// musst be null according to
|
||||
// must be null according to
|
||||
// https://docs.python.org/3.10/c-api/buffer.html#c.PyBUF_FORMAT
|
||||
// which implies 'B'.
|
||||
return {};
|
||||
case complex64:
|
||||
return pybind11::format_descriptor<std::complex<float>>::format();
|
||||
@@ -449,6 +451,8 @@ array create_array(array_init_type v, std::optional<Dtype> t) {
|
||||
return array_from_list(*pv, t);
|
||||
} else if (auto pv = std::get_if<py::tuple>(&v); pv) {
|
||||
return array_from_list(*pv, t);
|
||||
} else if (auto pv = std::get_if<array>(&v); pv) {
|
||||
return astype(*pv, t.value_or((*pv).dtype()));
|
||||
} else if (auto pv = std::get_if<py::array>(&v); pv) {
|
||||
return np_array_to_mlx(*pv, t);
|
||||
} else if (auto pv = std::get_if<py::buffer>(&v); pv) {
|
||||
@@ -528,7 +532,8 @@ void init_array(py::module_& m) {
|
||||
return pybind11::buffer_info(
|
||||
a.data<void>(),
|
||||
a.itemsize(),
|
||||
buffer_format(a).value_or(nullptr),
|
||||
buffer_format(a).value_or("B"), // we use "B" because pybind uses a
|
||||
// std::string which can't be null
|
||||
a.ndim(),
|
||||
a.shape(),
|
||||
buffer_strides(a));
|
||||
|
Reference in New Issue
Block a user