mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:38:09 +08:00
Make array conform to the Python Buffer Protocol (#323)
This commit is contained in:
@@ -278,108 +278,6 @@ array array_from_list(T pl, std::optional<Dtype> dtype) {
|
||||
return stack(arrays);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// MLX -> Numpy
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
size_t elem_to_loc(
|
||||
int elem,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<size_t>& strides) {
|
||||
size_t loc = 0;
|
||||
for (int i = shape.size() - 1; i >= 0; --i) {
|
||||
auto q_and_r = ldiv(elem, shape[i]);
|
||||
loc += q_and_r.rem * strides[i];
|
||||
elem = q_and_r.quot;
|
||||
}
|
||||
return loc;
|
||||
}
|
||||
|
||||
struct PyArrayPayload {
|
||||
array a;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
py::array_t<T> mlx_array_to_np_t(const array& src) {
|
||||
// Let py::capsule hold onto a copy of the array which holds a shared ptr to
|
||||
// the data
|
||||
const py::capsule freeWhenDone(new PyArrayPayload({src}), [](void* payload) {
|
||||
delete reinterpret_cast<PyArrayPayload*>(payload);
|
||||
});
|
||||
// Collect strides
|
||||
std::vector<size_t> strides{src.strides().begin(), src.strides().end()};
|
||||
for (int i = 0; i < src.ndim(); i++) {
|
||||
strides[i] *= src.itemsize();
|
||||
}
|
||||
// Pack the capsule with the array
|
||||
py::array_t<T> out(src.shape(), strides, src.data<T>(), freeWhenDone);
|
||||
// Mark array as read-only
|
||||
py::detail::array_proxy(out.ptr())->flags &=
|
||||
~py::detail::npy_api::NPY_ARRAY_WRITEABLE_;
|
||||
// Return array
|
||||
return py::array_t(src.shape(), strides, src.data<T>(), out);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
py::array mlx_array_to_np_t(const array& src, const py::dtype& dt) {
|
||||
// Let py::capsule hold onto a copy of the array which holds a shared ptr to
|
||||
// the data
|
||||
const py::capsule freeWhenDone(new PyArrayPayload({src}), [](void* payload) {
|
||||
delete reinterpret_cast<PyArrayPayload*>(payload);
|
||||
});
|
||||
// Collect strides
|
||||
std::vector<size_t> strides{src.strides().begin(), src.strides().end()};
|
||||
for (int i = 0; i < src.ndim(); i++) {
|
||||
strides[i] *= src.itemsize();
|
||||
}
|
||||
// Pack the capsule with the array
|
||||
py::array out(dt, src.shape(), strides, src.data<T>(), freeWhenDone);
|
||||
// Mark array as read-only
|
||||
py::detail::array_proxy(out.ptr())->flags &=
|
||||
~py::detail::npy_api::NPY_ARRAY_WRITEABLE_;
|
||||
// Return array
|
||||
return py::array(dt, src.shape(), strides, src.data<T>(), out);
|
||||
}
|
||||
|
||||
py::array mlx_array_to_np(const array& src) {
|
||||
// Eval if not already evaled
|
||||
if (!src.is_evaled()) {
|
||||
eval({src}, src.is_tracer());
|
||||
}
|
||||
|
||||
switch (src.dtype()) {
|
||||
case bool_:
|
||||
return mlx_array_to_np_t<bool>(src);
|
||||
case uint8:
|
||||
return mlx_array_to_np_t<uint8_t>(src);
|
||||
case uint16:
|
||||
return mlx_array_to_np_t<uint16_t>(src);
|
||||
case uint32:
|
||||
return mlx_array_to_np_t<uint32_t>(src);
|
||||
case uint64:
|
||||
return mlx_array_to_np_t<uint64_t>(src);
|
||||
case int8:
|
||||
return mlx_array_to_np_t<int8_t>(src);
|
||||
case int16:
|
||||
return mlx_array_to_np_t<int16_t>(src);
|
||||
case int32:
|
||||
return mlx_array_to_np_t<int32_t>(src);
|
||||
case int64:
|
||||
return mlx_array_to_np_t<int64_t>(src);
|
||||
case float16:
|
||||
return mlx_array_to_np_t<float16_t>(src, py::dtype("float16"));
|
||||
case float32:
|
||||
return mlx_array_to_np_t<float>(src);
|
||||
case bfloat16: {
|
||||
auto a = astype(src, float32);
|
||||
eval({a}, src.is_tracer());
|
||||
return mlx_array_to_np_t<float>(a);
|
||||
}
|
||||
case complex64:
|
||||
return mlx_array_to_np_t<complex64_t>(src, py::dtype("complex64"));
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Numpy -> MLX
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -479,6 +377,61 @@ array np_array_to_mlx(py::array np_array, std::optional<Dtype> dtype) {
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Python Buffer Protocol (MLX -> Numpy)
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
std::optional<std::string> buffer_format(const array& a) {
|
||||
// https://docs.python.org/3.10/library/struct.html#format-characters
|
||||
switch (a.dtype()) {
|
||||
case bool_:
|
||||
return pybind11::format_descriptor<bool>::format();
|
||||
case uint8:
|
||||
return pybind11::format_descriptor<uint8_t>::format();
|
||||
case uint16:
|
||||
return pybind11::format_descriptor<uint16_t>::format();
|
||||
case uint32:
|
||||
return pybind11::format_descriptor<uint32_t>::format();
|
||||
case uint64:
|
||||
return pybind11::format_descriptor<uint64_t>::format();
|
||||
case int8:
|
||||
return pybind11::format_descriptor<int8_t>::format();
|
||||
case int16:
|
||||
return pybind11::format_descriptor<int16_t>::format();
|
||||
case int32:
|
||||
return pybind11::format_descriptor<int32_t>::format();
|
||||
case int64:
|
||||
return pybind11::format_descriptor<int64_t>::format();
|
||||
case float16:
|
||||
// https://github.com/pybind/pybind11/issues/4998
|
||||
return "e";
|
||||
case float32: {
|
||||
return pybind11::format_descriptor<float>::format();
|
||||
}
|
||||
case bfloat16:
|
||||
// not supported by python buffer protocol or numpy.
|
||||
// musst be null according to
|
||||
// https://docs.python.org/3.10/c-api/buffer.html#c.PyBUF_FORMAT
|
||||
return {};
|
||||
case complex64:
|
||||
return pybind11::format_descriptor<std::complex<float>>::format();
|
||||
default: {
|
||||
std::ostringstream os;
|
||||
os << "bad dtype: " << a.dtype();
|
||||
throw std::runtime_error(os.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<size_t> buffer_strides(const array& a) {
|
||||
std::vector<size_t> py_strides;
|
||||
py_strides.reserve(a.strides().size());
|
||||
for (const size_t stride : a.strides()) {
|
||||
py_strides.push_back(stride * a.itemsize());
|
||||
}
|
||||
return py_strides;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Module
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -546,7 +499,10 @@ void init_array(py::module_& m) {
|
||||
m.attr("complex64") = py::cast(complex64);
|
||||
|
||||
auto array_class = py::class_<array>(
|
||||
m, "array", R"pbdoc(An N-dimensional array object.)pbdoc");
|
||||
m,
|
||||
"array",
|
||||
R"pbdoc(An N-dimensional array object.)pbdoc",
|
||||
py::buffer_protocol());
|
||||
|
||||
{
|
||||
py::options options;
|
||||
@@ -564,6 +520,19 @@ void init_array(py::module_& m) {
|
||||
}
|
||||
|
||||
array_class
|
||||
.def_buffer([](array& a) {
|
||||
// Eval if not already evaled
|
||||
if (!a.is_evaled()) {
|
||||
eval({a}, a.is_tracer());
|
||||
}
|
||||
return pybind11::buffer_info(
|
||||
a.data<void>(),
|
||||
a.itemsize(),
|
||||
buffer_format(a).value_or(nullptr),
|
||||
a.ndim(),
|
||||
a.shape(),
|
||||
buffer_strides(a));
|
||||
})
|
||||
.def_property_readonly(
|
||||
"size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc")
|
||||
.def_property_readonly(
|
||||
@@ -620,7 +589,6 @@ void init_array(py::module_& m) {
|
||||
The value type of the list corresponding to the last dimension is either
|
||||
``bool``, ``int`` or ``float`` depending on the ``dtype`` of the array.
|
||||
)pbdoc")
|
||||
.def("__array__", &mlx_array_to_np)
|
||||
.def(
|
||||
"astype",
|
||||
&astype,
|
||||
|
Reference in New Issue
Block a user