Make array conform to the Python Buffer Protocol (#323)

Daniel Strobusch
2024-01-06 00:58:33 +01:00
committed by GitHub
parent dfdb284e16
commit 1331fa19f6
10 changed files with 343 additions and 109 deletions


@@ -278,108 +278,6 @@ array array_from_list(T pl, std::optional<Dtype> dtype) {
return stack(arrays);
}
///////////////////////////////////////////////////////////////////////////////
// MLX -> Numpy
///////////////////////////////////////////////////////////////////////////////
size_t elem_to_loc(
int elem,
const std::vector<int>& shape,
const std::vector<size_t>& strides) {
size_t loc = 0;
for (int i = shape.size() - 1; i >= 0; --i) {
auto q_and_r = ldiv(elem, shape[i]);
loc += q_and_r.rem * strides[i];
elem = q_and_r.quot;
}
return loc;
}
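Note: elem_to_loc maps a flat row-major element index to a storage offset by
peeling off one coordinate per axis with divmod. A minimal Python sketch of
the same computation (an illustration, not part of the commit):

def elem_to_loc(elem, shape, strides):
    # Walk axes from innermost to outermost, accumulating the strided offset.
    loc = 0
    for dim, stride in zip(reversed(shape), reversed(strides)):
        elem, rem = divmod(elem, dim)
        loc += rem * stride
    return loc

# In a 2x3 array with element strides (3, 1), flat element 4 sits at
# (row 1, col 1), i.e. offset 1 * 3 + 1 * 1 = 4.
assert elem_to_loc(4, [2, 3], [3, 1]) == 4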
struct PyArrayPayload {
array a;
};
template <typename T>
py::array_t<T> mlx_array_to_np_t(const array& src) {
// Let py::capsule hold onto a copy of the array which holds a shared ptr to
// the data
const py::capsule freeWhenDone(new PyArrayPayload({src}), [](void* payload) {
delete reinterpret_cast<PyArrayPayload*>(payload);
});
  // Collect strides and convert them from elements to bytes
  std::vector<size_t> strides{src.strides().begin(), src.strides().end()};
  for (int i = 0; i < src.ndim(); i++) {
    strides[i] *= src.itemsize();
  }
// Pack the capsule with the array
py::array_t<T> out(src.shape(), strides, src.data<T>(), freeWhenDone);
// Mark array as read-only
py::detail::array_proxy(out.ptr())->flags &=
~py::detail::npy_api::NPY_ARRAY_WRITEABLE_;
  // Return a read-only view that uses `out` as its base, so the result does
  // not own the underlying data
  return py::array_t<T>(src.shape(), strides, src.data<T>(), out);
}
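Note: the double construction above is the pybind11 base-object idiom: `out`
carries the capsule and has its WRITEABLE flag cleared, and the returned
array is built with `out` as its base, so it picks up the base's flags. The
effect, expressed in plain NumPy (a rough analogy, not the pybind11 API):

import numpy as np

owner = np.ones(4, dtype=np.float32)  # stands in for `out`, which holds the capsule
owner.flags.writeable = False         # the cleared NPY_ARRAY_WRITEABLE_ flag
view = owner[:]                       # stands in for the array built on top of `out`
assert not view.flags.owndata         # the view does not own the data
assert not view.flags.writeable       # and inherits read-only from its base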
template <typename T>
py::array mlx_array_to_np_t(const array& src, const py::dtype& dt) {
// Let py::capsule hold onto a copy of the array which holds a shared ptr to
// the data
const py::capsule freeWhenDone(new PyArrayPayload({src}), [](void* payload) {
delete reinterpret_cast<PyArrayPayload*>(payload);
});
  // Collect strides and convert them from elements to bytes
  std::vector<size_t> strides{src.strides().begin(), src.strides().end()};
  for (int i = 0; i < src.ndim(); i++) {
    strides[i] *= src.itemsize();
  }
// Pack the capsule with the array
py::array out(dt, src.shape(), strides, src.data<T>(), freeWhenDone);
// Mark array as read-only
py::detail::array_proxy(out.ptr())->flags &=
~py::detail::npy_api::NPY_ARRAY_WRITEABLE_;
  // Return a read-only view that uses `out` as its base, so the result does
  // not own the underlying data
  return py::array(dt, src.shape(), strides, src.data<T>(), out);
}
py::array mlx_array_to_np(const array& src) {
// Eval if not already evaled
if (!src.is_evaled()) {
eval({src}, src.is_tracer());
}
switch (src.dtype()) {
case bool_:
return mlx_array_to_np_t<bool>(src);
case uint8:
return mlx_array_to_np_t<uint8_t>(src);
case uint16:
return mlx_array_to_np_t<uint16_t>(src);
case uint32:
return mlx_array_to_np_t<uint32_t>(src);
case uint64:
return mlx_array_to_np_t<uint64_t>(src);
case int8:
return mlx_array_to_np_t<int8_t>(src);
case int16:
return mlx_array_to_np_t<int16_t>(src);
case int32:
return mlx_array_to_np_t<int32_t>(src);
case int64:
return mlx_array_to_np_t<int64_t>(src);
case float16:
return mlx_array_to_np_t<float16_t>(src, py::dtype("float16"));
case float32:
return mlx_array_to_np_t<float>(src);
case bfloat16: {
auto a = astype(src, float32);
eval({a}, src.is_tracer());
return mlx_array_to_np_t<float>(a);
}
case complex64:
return mlx_array_to_np_t<complex64_t>(src, py::dtype("complex64"));
}
}
///////////////////////////////////////////////////////////////////////////////
// Numpy -> MLX
///////////////////////////////////////////////////////////////////////////////
@@ -479,6 +377,61 @@ array np_array_to_mlx(py::array np_array, std::optional<Dtype> dtype) {
}
}
///////////////////////////////////////////////////////////////////////////////
// Python Buffer Protocol (MLX -> Numpy)
///////////////////////////////////////////////////////////////////////////////
std::optional<std::string> buffer_format(const array& a) {
// https://docs.python.org/3.10/library/struct.html#format-characters
switch (a.dtype()) {
case bool_:
return pybind11::format_descriptor<bool>::format();
case uint8:
return pybind11::format_descriptor<uint8_t>::format();
case uint16:
return pybind11::format_descriptor<uint16_t>::format();
case uint32:
return pybind11::format_descriptor<uint32_t>::format();
case uint64:
return pybind11::format_descriptor<uint64_t>::format();
case int8:
return pybind11::format_descriptor<int8_t>::format();
case int16:
return pybind11::format_descriptor<int16_t>::format();
case int32:
return pybind11::format_descriptor<int32_t>::format();
case int64:
return pybind11::format_descriptor<int64_t>::format();
case float16:
// https://github.com/pybind/pybind11/issues/4998
return "e";
case float32: {
return pybind11::format_descriptor<float>::format();
}
case bfloat16:
      // Not supported by the Python buffer protocol or NumPy; the format
      // must be null according to
      // https://docs.python.org/3.10/c-api/buffer.html#c.PyBUF_FORMAT
return {};
case complex64:
return pybind11::format_descriptor<std::complex<float>>::format();
default: {
std::ostringstream os;
os << "bad dtype: " << a.dtype();
throw std::runtime_error(os.str());
}
}
}
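Note: these are standard struct/PEP 3118 format characters whose sizes can be
sanity-checked from Python (an illustration, not part of the commit):

import struct

assert struct.calcsize("e") == 2  # float16; pybind11 has no descriptor, hence the literal
assert struct.calcsize("f") == 4  # float32
assert struct.calcsize("Q") == 8  # uint64, i.e. 'unsigned long long'
# complex64 maps to "Zf", a PEP 3118 extension that the struct module itself
# does not parse but buffer consumers such as NumPy understand.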
std::vector<size_t> buffer_strides(const array& a) {
std::vector<size_t> py_strides;
py_strides.reserve(a.strides().size());
for (const size_t stride : a.strides()) {
py_strides.push_back(stride * a.itemsize());
}
return py_strides;
}
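Note: MLX stores strides in elements while the buffer protocol expects bytes,
hence the multiplication by itemsize. NumPy, for comparison, already reports
byte strides:

import numpy as np

a = np.ones((3, 4), dtype=np.float32)
# Element strides (4, 1) scaled by itemsize 4 give byte strides (16, 4).
assert a.strides == (16, 4)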
///////////////////////////////////////////////////////////////////////////////
// Module
///////////////////////////////////////////////////////////////////////////////
@@ -546,7 +499,10 @@ void init_array(py::module_& m) {
m.attr("complex64") = py::cast(complex64);
  auto array_class = py::class_<array>(
      m,
      "array",
      R"pbdoc(An N-dimensional array object.)pbdoc",
      py::buffer_protocol());
{
py::options options;
@@ -564,6 +520,19 @@ void init_array(py::module_& m) {
}
array_class
.def_buffer([](array& a) {
// Eval if not already evaled
if (!a.is_evaled()) {
eval({a}, a.is_tracer());
}
return pybind11::buffer_info(
a.data<void>(),
a.itemsize(),
            // bfloat16 yields no format; fall back to an empty string, since
            // value_or(nullptr) would construct std::string from a null
            // pointer, which is undefined behavior
            buffer_format(a).value_or(""),
a.ndim(),
a.shape(),
buffer_strides(a));
})
.def_property_readonly(
"size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc")
.def_property_readonly(
@@ -620,7 +589,6 @@ void init_array(py::module_& m) {
The value type of the list corresponding to the last dimension is either
``bool``, ``int`` or ``float`` depending on the ``dtype`` of the array.
)pbdoc")
.def("__array__", &mlx_array_to_np)
.def(
"astype",
&astype,

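Note: with py::buffer_protocol() on the class and the def_buffer above, an
mlx.core.array can be consumed by anything that speaks PEP 3118, which is why
the dedicated __array__ binding is dropped in the last hunk. A minimal usage
sketch (illustration only; mx.ones as in the tests below):

import mlx.core as mx
import numpy as np

a = mx.ones((2, 3))
mv = memoryview(a)           # zero-copy view through the buffer protocol
b = np.array(a, copy=False)  # NumPy wraps the same memory
assert mv.shape == (2, 3) and mv.format == "f"
assert not b.flags.owndata   # b is a view over the MLX buffer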

@@ -2,12 +2,20 @@
import operator
import unittest
import weakref
from itertools import permutations
import mlx.core as mx
import mlx_tests
import numpy as np
try:
    import tensorflow as tf

    has_tf = True
except ImportError:
    has_tf = False
class TestVersion(mlx_tests.MLXTestCase):
def test_version(self):
@@ -1100,7 +1108,7 @@ class TestArray(mlx_tests.MLXTestCase):
        # Check that we get a writeable array that does not own the underlying data
        self.assertFalse(a_np.flags.owndata)
        self.assertTrue(a_np.flags.writeable)
# Check contents
self.assertTrue(np.array_equal(np.ones((2, 2), dtype=np.float32), a_np))
@@ -1109,6 +1117,157 @@ class TestArray(mlx_tests.MLXTestCase):
# Check strides
self.assertSequenceEqual(b_np.strides, (0, 8, 4))
def test_np_array_conversion_copies_by_default(self):
a_mx = mx.ones((2, 2))
a_np = np.array(a_mx)
self.assertTrue(a_np.flags.owndata)
self.assertTrue(a_np.flags.writeable)
def test_buffer_protocol(self):
dtypes_list = [
(mx.bool_, np.bool_, None),
(mx.uint8, np.uint8, np.iinfo),
(mx.uint16, np.uint16, np.iinfo),
(mx.uint32, np.uint32, np.iinfo),
(mx.uint64, np.uint64, np.iinfo),
(mx.int8, np.int8, np.iinfo),
(mx.int16, np.int16, np.iinfo),
(mx.int32, np.int32, np.iinfo),
(mx.int64, np.int64, np.iinfo),
(mx.float16, np.float16, np.finfo),
(mx.float32, np.float32, np.finfo),
(mx.complex64, np.complex64, np.finfo),
]
for mlx_dtype, np_dtype, info_fn in dtypes_list:
a_np = np.random.uniform(low=0, high=100, size=(3, 4)).astype(np_dtype)
if info_fn is not None:
info = info_fn(np_dtype)
a_np[0, 0] = info.min
a_np[0, 1] = info.max
a_mx = mx.array(a_np)
for f in [lambda x: x, lambda x: x.T]:
mv_mx = memoryview(f(a_mx))
mv_np = memoryview(f(a_np))
self.assertEqual(mv_mx.strides, mv_np.strides, f"{mlx_dtype}{np_dtype}")
self.assertEqual(mv_mx.shape, mv_np.shape, f"{mlx_dtype}{np_dtype}")
                # The correct buffer format for an 8-byte (unsigned) long long is Q/q, see
                # https://docs.python.org/3.10/library/struct.html#format-characters
                # NumPy reports L/l instead: 'long' and 'long long' are both 64 bits wide
                # on 64-bit machines, so q/l (and Q/L) describe the same layout, see
                # https://github.com/pybind/pybind11/issues/1908
if np_dtype == np.uint64:
self.assertEqual(mv_mx.format, "Q", f"{mlx_dtype}{np_dtype}")
elif np_dtype == np.int64:
self.assertEqual(mv_mx.format, "q", f"{mlx_dtype}{np_dtype}")
else:
self.assertEqual(
mv_mx.format, mv_np.format, f"{mlx_dtype}{np_dtype}"
)
self.assertFalse(mv_mx.readonly)
back_to_npy = np.array(mv_mx, copy=False)
self.assertEqualArray(
back_to_npy,
f(a_np),
atol=0,
rtol=0,
msg=f"{mlx_dtype}{np_dtype}",
)
        # Extra test for bfloat16, which is not convertible to NumPy
a_mx = mx.random.uniform(low=0, high=100, shape=(3, 4), dtype=mx.bfloat16)
mv_mx = memoryview(a_mx)
self.assertEqual(mv_mx.strides, (8, 2))
self.assertEqual(mv_mx.shape, (3, 4))
self.assertEqual(mv_mx.format, "")
with self.assertRaises(RuntimeError) as cm:
np.array(a_mx)
e = cm.exception
self.assertTrue("Item size 2 for PEP 3118 buffer format string" in str(e))
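Note: since bfloat16 exposes no buffer format, getting such an array into
NumPy requires an explicit upcast through the astype binding shown above. A
sketch (illustration only, not part of the commit):

import mlx.core as mx
import numpy as np

a = mx.random.uniform(low=0, high=100, shape=(3, 4), dtype=mx.bfloat16)
b = np.array(a.astype(mx.float32))  # upcast first; np.array(a) raises, as tested above
assert b.dtype == np.float32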
def test_buffer_protocol_ref_counting(self):
a = mx.arange(3)
wr = weakref.ref(a)
self.assertIsNotNone(wr())
mv = memoryview(a)
a = None
self.assertIsNotNone(wr())
mv = None
self.assertIsNone(wr())
def test_array_view_ref_counting(self):
a = mx.arange(3)
wr = weakref.ref(a)
self.assertIsNotNone(wr())
a_np = np.array(a, copy=False)
a = None
self.assertIsNotNone(wr())
a_np = None
self.assertIsNone(wr())
@unittest.skipIf(not has_tf, "requires TensorFlow")
def test_buffer_protocol_tf(self):
        dtypes_list = [
            (mx.bool_, tf.bool, np.bool_),
            (mx.uint8, tf.uint8, np.uint8),
            (mx.uint16, tf.uint16, np.uint16),
            (mx.uint32, tf.uint32, np.uint32),
            (mx.uint64, tf.uint64, np.uint64),
            (mx.int8, tf.int8, np.int8),
            (mx.int16, tf.int16, np.int16),
            (mx.int32, tf.int32, np.int32),
            (mx.int64, tf.int64, np.int64),
            (mx.float16, tf.float16, np.float16),
            (mx.float32, tf.float32, np.float32),
            (mx.complex64, tf.complex64, np.complex64),
        ]
for mlx_dtype, tf_dtype, np_dtype in dtypes_list:
a_np = np.random.uniform(low=0, high=100, size=(3, 4)).astype(np_dtype)
a_tf = tf.constant(a_np, dtype=tf_dtype)
a_mx = mx.array(a_tf)
for f in [
lambda x: x,
lambda x: tf.transpose(x) if isinstance(x, tf.Tensor) else x.T,
]:
mv_mx = memoryview(f(a_mx))
mv_tf = memoryview(f(a_tf))
if (mv_mx.c_contiguous and mv_tf.c_contiguous) or (
mv_mx.f_contiguous and mv_tf.f_contiguous
):
self.assertEqual(
mv_mx.strides, mv_tf.strides, f"{mlx_dtype}{tf_dtype}"
)
self.assertEqual(mv_mx.shape, mv_tf.shape, f"{mlx_dtype}{tf_dtype}")
self.assertFalse(mv_mx.readonly)
back_to_npy = np.array(mv_mx)
self.assertEqualArray(
back_to_npy,
f(a_tf),
atol=0,
rtol=0,
msg=f"{mlx_dtype}{tf_dtype}",
)
if __name__ == "__main__":
unittest.main()