Make array conform to the Python Buffer Protocol (#323)

Daniel Strobusch, 2024-01-06 00:58:33 +01:00, committed by GitHub
parent dfdb284e16
commit 1331fa19f6
10 changed files with 343 additions and 109 deletions

View File

@@ -62,6 +62,7 @@ jobs:
       pip install --upgrade pybind11[global]
       pip install numpy
       pip install torch
+      pip install tensorflow
       pip install unittest-xml-reporting
   - run:
       name: Build python package

View File

@@ -61,7 +61,7 @@ variety of examples, including:

 ## Quickstart

 See the [quick start
-guide](https://ml-explore.github.io/mlx/build/html/quick_start.html)
+guide](https://ml-explore.github.io/mlx/build/html/usage/quick_start.html)
 in the documentation.

 ## Installation

View File

@@ -35,9 +35,10 @@ are the CPU and GPU.
    :caption: Usage
    :maxdepth: 1

-   quick_start
-   unified_memory
-   using_streams
+   usage/quick_start
+   usage/unified_memory
+   usage/using_streams
+   usage/numpy

 .. toctree::
    :caption: Examples

View File: docs/src/usage/numpy.rst (new file, 103 lines)

@@ -0,0 +1,103 @@
.. _numpy:

Conversion to NumPy and Other Frameworks
========================================

MLX array implements the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
Let's convert an array to NumPy and back.

.. code-block:: python

  import mlx.core as mx
  import numpy as np

  a = mx.arange(3)
  b = np.array(a)  # copy of a
  c = mx.array(b)  # copy of b

.. note::

    Since NumPy does not support ``bfloat16`` arrays, you will need to convert to ``float16`` or ``float32`` first:
    ``np.array(a.astype(mx.float32))``.
    Otherwise, you will receive an error like: ``Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.``
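For example, a ``bfloat16`` round trip can go through an explicit ``float32`` copy (a small illustration; the intermediate dtype is a choice, ``float16`` works as well):

.. code-block:: python

  import mlx.core as mx
  import numpy as np

  a = mx.arange(3).astype(mx.bfloat16)
  b = np.array(a.astype(mx.float32))   # convert before crossing the buffer
  c = mx.array(b).astype(mx.bfloat16)  # restore the original dtype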
By default, NumPy copies data to a new array. This can be prevented by creating an array view:

.. code-block:: python

  a = mx.arange(3)
  a_view = np.array(a, copy=False)
  print(a_view.flags.owndata)  # False
  a_view[0] = 1
  print(a[0].item())  # 1

A NumPy array view is a normal NumPy array, except that it does not own its memory.
This means writing to the view is reflected in the original array.

While this is quite powerful for avoiding copies, it should be noted that external changes to the memory of arrays cannot be reflected in gradients.

Let's demonstrate this with an example:

.. code-block:: python

  def f(x):
      x_view = np.array(x, copy=False)
      x_view[:] *= x_view  # modify memory without telling mx
      return x.sum()

  x = mx.array([3.0])
  y, df = mx.value_and_grad(f)(x)
  print("f(x) = x² =", y.item())  # 9.0
  print("f'(x) = 2x !=", df.item())  # 1.0

The function ``f`` indirectly modifies the array ``x`` through a memory view.
However, this modification is not reflected in the gradient, as seen in the last line outputting ``1.0``,
which is the gradient of the sum operation alone.
The squaring of ``x`` happens outside of MLX, so no gradient for it is recorded.

The same issue arises when arrays are converted and copied:
a function such as ``mx.array(np.array(x)**2).sum()`` also yields an incorrect gradient,
even though no in-place operations on MLX memory are performed.
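A minimal sketch of that case (the exact gradient value aside, the point is that the round trip through NumPy severs the dependence on ``x``):

.. code-block:: python

  def g(x):
      # the square happens on a NumPy copy, invisible to MLX
      return mx.array(np.array(x) ** 2).sum()

  x = mx.array([3.0])
  y, dg = mx.value_and_grad(g)(x)
  print(y.item())   # 9.0
  print(dg.item())  # not 6.0, the true derivative 2x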
PyTorch
-------

PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`.

.. code-block:: python

  import mlx.core as mx
  import torch

  a = mx.arange(3)
  b = torch.tensor(memoryview(a))
  c = mx.array(b.numpy())

Conversion from PyTorch tensors back to arrays must be done via intermediate NumPy arrays with ``numpy()``.

JAX
---

JAX fully supports the buffer protocol.

.. code-block:: python

  import mlx.core as mx
  import jax.numpy as jnp

  a = mx.arange(3)
  b = jnp.array(a)
  c = mx.array(b)

TensorFlow
----------

TensorFlow supports the buffer protocol, but it requires an explicit :obj:`memoryview`.

.. code-block:: python

  import mlx.core as mx
  import tensorflow as tf

  a = mx.arange(3)
  b = tf.constant(memoryview(a))
  c = mx.array(b)

View File

@@ -278,108 +278,6 @@ array array_from_list(T pl, std::optional<Dtype> dtype) {
   return stack(arrays);
 }

-///////////////////////////////////////////////////////////////////////////////
-// MLX -> Numpy
-///////////////////////////////////////////////////////////////////////////////
-
-size_t elem_to_loc(
-    int elem,
-    const std::vector<int>& shape,
-    const std::vector<size_t>& strides) {
-  size_t loc = 0;
-  for (int i = shape.size() - 1; i >= 0; --i) {
-    auto q_and_r = ldiv(elem, shape[i]);
-    loc += q_and_r.rem * strides[i];
-    elem = q_and_r.quot;
-  }
-  return loc;
-}
-
-struct PyArrayPayload {
-  array a;
-};
-
-template <typename T>
-py::array_t<T> mlx_array_to_np_t(const array& src) {
-  // Let py::capsule hold onto a copy of the array which holds a shared ptr to
-  // the data
-  const py::capsule freeWhenDone(new PyArrayPayload({src}), [](void* payload) {
-    delete reinterpret_cast<PyArrayPayload*>(payload);
-  });
-
-  // Collect strides
-  std::vector<size_t> strides{src.strides().begin(), src.strides().end()};
-  for (int i = 0; i < src.ndim(); i++) {
-    strides[i] *= src.itemsize();
-  }
-
-  // Pack the capsule with the array
-  py::array_t<T> out(src.shape(), strides, src.data<T>(), freeWhenDone);
-
-  // Mark array as read-only
-  py::detail::array_proxy(out.ptr())->flags &=
-      ~py::detail::npy_api::NPY_ARRAY_WRITEABLE_;
-
-  // Return array
-  return py::array_t(src.shape(), strides, src.data<T>(), out);
-}
-
-template <typename T>
-py::array mlx_array_to_np_t(const array& src, const py::dtype& dt) {
-  // Let py::capsule hold onto a copy of the array which holds a shared ptr to
-  // the data
-  const py::capsule freeWhenDone(new PyArrayPayload({src}), [](void* payload) {
-    delete reinterpret_cast<PyArrayPayload*>(payload);
-  });
-
-  // Collect strides
-  std::vector<size_t> strides{src.strides().begin(), src.strides().end()};
-  for (int i = 0; i < src.ndim(); i++) {
-    strides[i] *= src.itemsize();
-  }
-
-  // Pack the capsule with the array
-  py::array out(dt, src.shape(), strides, src.data<T>(), freeWhenDone);
-
-  // Mark array as read-only
-  py::detail::array_proxy(out.ptr())->flags &=
-      ~py::detail::npy_api::NPY_ARRAY_WRITEABLE_;
-
-  // Return array
-  return py::array(dt, src.shape(), strides, src.data<T>(), out);
-}
-
-py::array mlx_array_to_np(const array& src) {
-  // Eval if not already evaled
-  if (!src.is_evaled()) {
-    eval({src}, src.is_tracer());
-  }
-  switch (src.dtype()) {
-    case bool_:
-      return mlx_array_to_np_t<bool>(src);
-    case uint8:
-      return mlx_array_to_np_t<uint8_t>(src);
-    case uint16:
-      return mlx_array_to_np_t<uint16_t>(src);
-    case uint32:
-      return mlx_array_to_np_t<uint32_t>(src);
-    case uint64:
-      return mlx_array_to_np_t<uint64_t>(src);
-    case int8:
-      return mlx_array_to_np_t<int8_t>(src);
-    case int16:
-      return mlx_array_to_np_t<int16_t>(src);
-    case int32:
-      return mlx_array_to_np_t<int32_t>(src);
-    case int64:
-      return mlx_array_to_np_t<int64_t>(src);
-    case float16:
-      return mlx_array_to_np_t<float16_t>(src, py::dtype("float16"));
-    case float32:
-      return mlx_array_to_np_t<float>(src);
-    case bfloat16: {
-      auto a = astype(src, float32);
-      eval({a}, src.is_tracer());
-      return mlx_array_to_np_t<float>(a);
-    }
-    case complex64:
-      return mlx_array_to_np_t<complex64_t>(src, py::dtype("complex64"));
-  }
-}
 ///////////////////////////////////////////////////////////////////////////////
 // Numpy -> MLX
 ///////////////////////////////////////////////////////////////////////////////
@@ -479,6 +377,61 @@ array np_array_to_mlx(py::array np_array, std::optional<Dtype> dtype) {
   }
 }
+///////////////////////////////////////////////////////////////////////////////
+// Python Buffer Protocol (MLX -> Numpy)
+///////////////////////////////////////////////////////////////////////////////
+
+std::optional<std::string> buffer_format(const array& a) {
+  // https://docs.python.org/3.10/library/struct.html#format-characters
+  switch (a.dtype()) {
+    case bool_:
+      return pybind11::format_descriptor<bool>::format();
+    case uint8:
+      return pybind11::format_descriptor<uint8_t>::format();
+    case uint16:
+      return pybind11::format_descriptor<uint16_t>::format();
+    case uint32:
+      return pybind11::format_descriptor<uint32_t>::format();
+    case uint64:
+      return pybind11::format_descriptor<uint64_t>::format();
+    case int8:
+      return pybind11::format_descriptor<int8_t>::format();
+    case int16:
+      return pybind11::format_descriptor<int16_t>::format();
+    case int32:
+      return pybind11::format_descriptor<int32_t>::format();
+    case int64:
+      return pybind11::format_descriptor<int64_t>::format();
+    case float16:
+      // https://github.com/pybind/pybind11/issues/4998
+      return "e";
+    case float32:
+      return pybind11::format_descriptor<float>::format();
+    case bfloat16:
+      // Not supported by the Python buffer protocol or NumPy;
+      // must be null according to
+      // https://docs.python.org/3.10/c-api/buffer.html#c.PyBUF_FORMAT
+      return {};
+    case complex64:
+      return pybind11::format_descriptor<std::complex<float>>::format();
+    default: {
+      std::ostringstream os;
+      os << "bad dtype: " << a.dtype();
+      throw std::runtime_error(os.str());
+    }
+  }
+}
+
+std::vector<size_t> buffer_strides(const array& a) {
+  std::vector<size_t> py_strides;
+  py_strides.reserve(a.strides().size());
+  for (const size_t stride : a.strides()) {
+    py_strides.push_back(stride * a.itemsize());
+  }
+  return py_strides;
+}
 ///////////////////////////////////////////////////////////////////////////////
 // Module
 ///////////////////////////////////////////////////////////////////////////////
@@ -546,7 +499,10 @@ void init_array(py::module_& m) {
   m.attr("complex64") = py::cast(complex64);

   auto array_class = py::class_<array>(
-      m, "array", R"pbdoc(An N-dimensional array object.)pbdoc");
+      m,
+      "array",
+      R"pbdoc(An N-dimensional array object.)pbdoc",
+      py::buffer_protocol());

   {
     py::options options;
@@ -564,6 +520,19 @@ void init_array(py::module_& m) {
   }

   array_class
+      .def_buffer([](array& a) {
+        // Eval if not already evaled
+        if (!a.is_evaled()) {
+          eval({a}, a.is_tracer());
+        }
+        return pybind11::buffer_info(
+            a.data<void>(),
+            a.itemsize(),
+            // buffer_info takes std::string, so an empty string stands in for
+            // the null format of unsupported dtypes (bfloat16)
+            buffer_format(a).value_or(""),
+            a.ndim(),
+            a.shape(),
+            buffer_strides(a));
+      })
       .def_property_readonly(
           "size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc")
       .def_property_readonly(
@@ -620,7 +589,6 @@ void init_array(py::module_& m) {
           The value type of the list corresponding to the last dimension is either
           ``bool``, ``int`` or ``float`` depending on the ``dtype`` of the array.
           )pbdoc")
-      .def("__array__", &mlx_array_to_np)
       .def(
           "astype",
           &astype,
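With ``py::buffer_protocol()`` and the ``def_buffer`` handler in place, any buffer-protocol consumer can read MLX memory directly, which is what replaces the removed ``__array__`` path. A minimal sketch of what this enables from Python (the printed values assume a C-contiguous ``float32`` array, whose byte strides are the element strides scaled by the 4-byte item size):

.. code-block:: python

  import mlx.core as mx
  import numpy as np

  a = mx.ones((3, 4), dtype=mx.float32)
  mv = memoryview(a)           # served by the def_buffer binding above
  print(mv.format)             # 'f', the struct format character for float32
  print(mv.shape)              # (3, 4)
  print(mv.strides)            # (16, 4), element strides times itemsize
  b = np.array(a, copy=False)  # zero-copy view over the same memory
  print(b.flags.owndata)       # False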

View File

@@ -2,12 +2,20 @@
 import operator
 import unittest
+import weakref
 from itertools import permutations

 import mlx.core as mx
 import mlx_tests
 import numpy as np

+try:
+    import tensorflow as tf
+
+    has_tf = True
+except ImportError:
+    has_tf = False
+
 class TestVersion(mlx_tests.MLXTestCase):
     def test_version(self):
@@ -1100,7 +1108,7 @@ class TestArray(mlx_tests.MLXTestCase):
         # Check that we get read-only array that does not own the underlying data
         self.assertFalse(a_np.flags.owndata)
-        self.assertFalse(a_np.flags.writeable)
+        self.assertTrue(a_np.flags.writeable)

         # Check contents
         self.assertTrue(np.array_equal(np.ones((2, 2), dtype=np.float32), a_np))
@@ -1109,6 +1117,157 @@ class TestArray(mlx_tests.MLXTestCase):
         # Check strides
         self.assertSequenceEqual(b_np.strides, (0, 8, 4))
+    def test_np_array_conversion_copies_by_default(self):
+        a_mx = mx.ones((2, 2))
+        a_np = np.array(a_mx)
+        self.assertTrue(a_np.flags.owndata)
+        self.assertTrue(a_np.flags.writeable)
+
+    def test_buffer_protocol(self):
+        dtypes_list = [
+            (mx.bool_, np.bool_, None),
+            (mx.uint8, np.uint8, np.iinfo),
+            (mx.uint16, np.uint16, np.iinfo),
+            (mx.uint32, np.uint32, np.iinfo),
+            (mx.uint64, np.uint64, np.iinfo),
+            (mx.int8, np.int8, np.iinfo),
+            (mx.int16, np.int16, np.iinfo),
+            (mx.int32, np.int32, np.iinfo),
+            (mx.int64, np.int64, np.iinfo),
+            (mx.float16, np.float16, np.finfo),
+            (mx.float32, np.float32, np.finfo),
+            (mx.complex64, np.complex64, np.finfo),
+        ]
+
+        for mlx_dtype, np_dtype, info_fn in dtypes_list:
+            a_np = np.random.uniform(low=0, high=100, size=(3, 4)).astype(np_dtype)
+            if info_fn is not None:
+                info = info_fn(np_dtype)
+                a_np[0, 0] = info.min
+                a_np[0, 1] = info.max
+            a_mx = mx.array(a_np)
+            for f in [lambda x: x, lambda x: x.T]:
+                mv_mx = memoryview(f(a_mx))
+                mv_np = memoryview(f(a_np))
+                self.assertEqual(mv_mx.strides, mv_np.strides, f"{mlx_dtype}{np_dtype}")
+                self.assertEqual(mv_mx.shape, mv_np.shape, f"{mlx_dtype}{np_dtype}")
+                # The correct buffer format for an 8-byte (unsigned) 'long long' is Q/q, see
+                # https://docs.python.org/3.10/library/struct.html#format-characters
+                # NumPy returns L/l instead; since 'long' is equivalent to 'long long'
+                # on 64-bit machines, q and l are equivalent, see
+                # https://github.com/pybind/pybind11/issues/1908
+                if np_dtype == np.uint64:
+                    self.assertEqual(mv_mx.format, "Q", f"{mlx_dtype}{np_dtype}")
+                elif np_dtype == np.int64:
+                    self.assertEqual(mv_mx.format, "q", f"{mlx_dtype}{np_dtype}")
+                else:
+                    self.assertEqual(
+                        mv_mx.format, mv_np.format, f"{mlx_dtype}{np_dtype}"
+                    )
+                self.assertFalse(mv_mx.readonly)
+                back_to_npy = np.array(mv_mx, copy=False)
+                self.assertEqualArray(
+                    back_to_npy,
+                    f(a_np),
+                    atol=0,
+                    rtol=0,
+                    msg=f"{mlx_dtype}{np_dtype}",
+                )
+
+        # extra test for bfloat16, which is not NumPy-convertible
+        a_mx = mx.random.uniform(low=0, high=100, shape=(3, 4), dtype=mx.bfloat16)
+        mv_mx = memoryview(a_mx)
+        self.assertEqual(mv_mx.strides, (8, 2))
+        self.assertEqual(mv_mx.shape, (3, 4))
+        self.assertEqual(mv_mx.format, "")
+        with self.assertRaises(RuntimeError) as cm:
+            np.array(a_mx)
+        e = cm.exception
+        self.assertTrue("Item size 2 for PEP 3118 buffer format string" in str(e))
+
+    def test_buffer_protocol_ref_counting(self):
+        a = mx.arange(3)
+        wr = weakref.ref(a)
+        self.assertIsNotNone(wr())
+
+        mv = memoryview(a)
+        a = None
+        self.assertIsNotNone(wr())
+        mv = None
+        self.assertIsNone(wr())
+
+    def test_array_view_ref_counting(self):
+        a = mx.arange(3)
+        wr = weakref.ref(a)
+        self.assertIsNotNone(wr())
+
+        a_np = np.array(a, copy=False)
+        a = None
+        self.assertIsNotNone(wr())
+        a_np = None
+        self.assertIsNone(wr())
+
+    @unittest.skipIf(not has_tf, "requires TensorFlow")
+    def test_buffer_protocol_tf(self):
+        dtypes_list = [
+            (mx.bool_, tf.bool, np.bool_),
+            (mx.uint8, tf.uint8, np.uint8),
+            (mx.uint16, tf.uint16, np.uint16),
+            (mx.uint32, tf.uint32, np.uint32),
+            (mx.uint64, tf.uint64, np.uint64),
+            (mx.int8, tf.int8, np.int8),
+            (mx.int16, tf.int16, np.int16),
+            (mx.int32, tf.int32, np.int32),
+            (mx.int64, tf.int64, np.int64),
+            (mx.float16, tf.float16, np.float16),
+            (mx.float32, tf.float32, np.float32),
+            (mx.complex64, tf.complex64, np.complex64),
+        ]
+
+        for mlx_dtype, tf_dtype, np_dtype in dtypes_list:
+            a_np = np.random.uniform(low=0, high=100, size=(3, 4)).astype(np_dtype)
+            a_tf = tf.constant(a_np, dtype=tf_dtype)
+            a_mx = mx.array(a_tf)
+            for f in [
+                lambda x: x,
+                lambda x: tf.transpose(x) if isinstance(x, tf.Tensor) else x.T,
+            ]:
+                mv_mx = memoryview(f(a_mx))
+                mv_tf = memoryview(f(a_tf))
+                if (mv_mx.c_contiguous and mv_tf.c_contiguous) or (
+                    mv_mx.f_contiguous and mv_tf.f_contiguous
+                ):
+                    self.assertEqual(
+                        mv_mx.strides, mv_tf.strides, f"{mlx_dtype}{tf_dtype}"
+                    )
+                self.assertEqual(mv_mx.shape, mv_tf.shape, f"{mlx_dtype}{tf_dtype}")
+                self.assertFalse(mv_mx.readonly)
+                back_to_npy = np.array(mv_mx)
+                self.assertEqualArray(
+                    back_to_npy,
+                    f(a_tf),
+                    atol=0,
+                    rtol=0,
+                    msg=f"{mlx_dtype}{tf_dtype}",
+                )
 if __name__ == "__main__":
     unittest.main()

View File

@@ -222,6 +222,8 @@ TEST_CASE("test array types") {
   // complex64
   {
+    CHECK_EQ(sizeof(complex64_t), sizeof(std::complex<float>));
+
    complex64_t v = {1.0f, 1.0f};
    array x(v);
    CHECK_EQ(x.dtype(), complex64);