From 1331fa19f6393daf8ab0e0e62ad681fbb97a1fa4 Mon Sep 17 00:00:00 2001
From: Daniel Strobusch <1847260+dastrobu@users.noreply.github.com>
Date: Sat, 6 Jan 2024 00:58:33 +0100
Subject: [PATCH] Make array conform to the Python Buffer Protocol (#323)

---
 .circleci/config.yml                    |   1 +
 README.md                               |   2 +-
 docs/src/index.rst                      |   7 +-
 docs/src/usage/numpy.rst                | 103 ++++++++++++++
 docs/src/{ => usage}/quick_start.rst    |   0
 docs/src/{ => usage}/unified_memory.rst |   0
 docs/src/{ => usage}/using_streams.rst  |   0
 python/src/array.cpp                    | 176 ++++++++++--------------
 python/tests/test_array.py              | 161 +++++++++++++++++++++-
 tests/array_tests.cpp                   |   2 +
 10 files changed, 343 insertions(+), 109 deletions(-)
 create mode 100644 docs/src/usage/numpy.rst
 rename docs/src/{ => usage}/quick_start.rst (100%)
 rename docs/src/{ => usage}/unified_memory.rst (100%)
 rename docs/src/{ => usage}/using_streams.rst (100%)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 8d7d1627a..13cebce75 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -62,6 +62,7 @@ jobs:
             pip install --upgrade pybind11[global]
             pip install numpy
             pip install torch
+            pip install tensorflow
             pip install unittest-xml-reporting
       - run:
           name: Build python package
diff --git a/README.md b/README.md
index bb68d38cb..047bf1041 100644
--- a/README.md
+++ b/README.md
@@ -61,7 +61,7 @@ variety of examples, including:
 ## Quickstart
 
 See the [quick start
-guide](https://ml-explore.github.io/mlx/build/html/quick_start.html)
+guide](https://ml-explore.github.io/mlx/build/html/usage/quick_start.html)
 in the documentation.
 
 ## Installation
diff --git a/docs/src/index.rst b/docs/src/index.rst
index 9f0445a18..f1fe468ca 100644
--- a/docs/src/index.rst
+++ b/docs/src/index.rst
@@ -35,9 +35,10 @@ are the CPU and GPU.
    :caption: Usage
    :maxdepth: 1
 
-   quick_start
-   unified_memory
-   using_streams
+   usage/quick_start
+   usage/unified_memory
+   usage/using_streams
+   usage/numpy
 
 .. toctree::
    :caption: Examples
diff --git a/docs/src/usage/numpy.rst b/docs/src/usage/numpy.rst
new file mode 100644
index 000000000..ef075ad0c
--- /dev/null
+++ b/docs/src/usage/numpy.rst
@@ -0,0 +1,103 @@
+.. _numpy:
+
+Conversion to NumPy and Other Frameworks
+========================================
+
+MLX arrays implement the `Python Buffer Protocol <https://docs.python.org/3/c-api/buffer.html>`_.
+Let's convert an array to NumPy and back.
+
+.. code-block:: python
+
+  import mlx.core as mx
+  import numpy as np
+
+  a = mx.arange(3)
+  b = np.array(a) # copy of a
+  c = mx.array(b) # copy of b
+
+.. note::
+
+  Since NumPy does not support ``bfloat16`` arrays, you will need to convert to ``float16`` or ``float32`` first:
+  ``np.array(a.astype(mx.float32))``.
+  Otherwise, you will receive an error like: ``Item size 2 for PEP 3118 buffer format string does not match the dtype V item size 0.``
+
+By default, NumPy copies the data into a new array. The copy can be avoided by creating an array view:
+
+.. code-block:: python
+
+  a = mx.arange(3)
+  a_view = np.array(a, copy=False)
+  print(a_view.flags.owndata) # False
+  a_view[0] = 1
+  print(a[0].item()) # 1
+
+A NumPy array view is a normal NumPy array, except that it does not own its memory.
+This means writing to the view is reflected in the original array.
+
+While views are a powerful way to avoid copying arrays, keep in mind that external changes to an array's memory are not reflected in gradients.
+
+Let's demonstrate this with an example:
+
+.. code-block:: python
+
+  def f(x):
+      x_view = np.array(x, copy=False)
+      x_view[:] *= x_view # modify memory without telling mx
+      return x.sum()
+
+  x = mx.array([3.0])
+  y, df = mx.value_and_grad(f)(x)
+  print("f(x) = x² =", y.item()) # 9.0
+  print("f'(x) = 2x !=", df.item()) # 1.0
+
+The function ``f`` indirectly modifies the array ``x`` through a memory view.
+However, this modification is not reflected in the gradient: the last line prints ``1.0``,
+the gradient of the sum operation alone.
+The squaring of ``x`` happens outside of MLX, so it contributes nothing to the gradient.
+Note that the same issue arises when arrays are converted and copied:
+a function such as ``mx.array(np.array(x)**2).sum()`` also yields an incorrect gradient,
+even though no in-place operations on MLX memory are executed.
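+As a minimal sketch of that copy-based pitfall (the helper ``g`` below is
+illustrative only and not part of this patch):
+
+.. code-block:: python
+
+  def g(x):
+      # the squaring runs in NumPy, outside of MLX, so grad cannot see it
+      return mx.array(np.array(x) ** 2).sum()
+
+  x = mx.array([3.0])
+  y, dg = mx.value_and_grad(g)(x)
+  print(y.item())  # 9.0
+  print(dg.item())  # 1.0, the gradient of the sum alone (2x would give 6.0)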
+
+PyTorch
+-------
+
+PyTorch supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
+
+.. code-block:: python
+
+  import mlx.core as mx
+  import torch
+
+  a = mx.arange(3)
+  b = torch.tensor(memoryview(a))
+  c = mx.array(b.numpy())
+
+Converting PyTorch tensors back to MLX arrays must go through an intermediate NumPy array obtained with ``numpy()``.
+
+JAX
+---
+
+JAX fully supports the buffer protocol.
+
+.. code-block:: python
+
+  import mlx.core as mx
+  import jax.numpy as jnp
+
+  a = mx.arange(3)
+  b = jnp.array(a)
+  c = mx.array(b)
+
+TensorFlow
+----------
+
+TensorFlow supports the buffer protocol, but it requires an explicit :obj:`memoryview`.
+
+.. code-block:: python
+
+  import mlx.core as mx
+  import tensorflow as tf
+
+  a = mx.arange(3)
+  b = tf.constant(memoryview(a))
+  c = mx.array(b)
diff --git a/docs/src/quick_start.rst b/docs/src/usage/quick_start.rst
similarity index 100%
rename from docs/src/quick_start.rst
rename to docs/src/usage/quick_start.rst
diff --git a/docs/src/unified_memory.rst b/docs/src/usage/unified_memory.rst
similarity index 100%
rename from docs/src/unified_memory.rst
rename to docs/src/usage/unified_memory.rst
diff --git a/docs/src/using_streams.rst b/docs/src/usage/using_streams.rst
similarity index 100%
rename from docs/src/using_streams.rst
rename to docs/src/usage/using_streams.rst
diff --git a/python/src/array.cpp b/python/src/array.cpp
index 5e57eb4c2..5ce09dd90 100644
--- a/python/src/array.cpp
+++ b/python/src/array.cpp
@@ -278,108 +278,6 @@ array array_from_list(T pl, std::optional<Dtype> dtype) {
   return stack(arrays);
 }
 
-///////////////////////////////////////////////////////////////////////////////
-// MLX -> Numpy
-///////////////////////////////////////////////////////////////////////////////
-
-size_t elem_to_loc(
-    int elem,
-    const std::vector<int>& shape,
-    const std::vector<size_t>& strides) {
-  size_t loc = 0;
-  for (int i = shape.size() - 1; i >= 0; --i) {
-    auto q_and_r = ldiv(elem, shape[i]);
-    loc += q_and_r.rem * strides[i];
-    elem = q_and_r.quot;
-  }
-  return loc;
-}
-
-struct PyArrayPayload {
-  array a;
-};
-
-template <typename T>
-py::array_t<T> mlx_array_to_np_t(const array& src) {
-  // Let py::capsule hold onto a copy of the array which holds a shared ptr to
-  // the data
-  const py::capsule freeWhenDone(new PyArrayPayload({src}), [](void* payload) {
-    delete reinterpret_cast<PyArrayPayload*>(payload);
-  });
-  // Collect strides
-  std::vector<size_t> strides{src.strides().begin(), src.strides().end()};
-  for (int i = 0; i < src.ndim(); i++) {
-    strides[i] *= src.itemsize();
-  }
-  // Pack the capsule with the array
-  py::array_t<T> out(src.shape(), strides, src.data<T>(), freeWhenDone);
-  // Mark array as read-only
-  py::detail::array_proxy(out.ptr())->flags &=
-      ~py::detail::npy_api::NPY_ARRAY_WRITEABLE_;
-  // Return array
-  return py::array_t<T>(src.shape(), strides, src.data<T>(), out);
-}
-
-template <typename T>
-py::array mlx_array_to_np_t(const array& src, const py::dtype& dt) {
-  // Let py::capsule hold onto a copy of the array which holds a shared ptr to
-  // the data
-  const py::capsule freeWhenDone(new PyArrayPayload({src}), [](void* payload) {
-    delete reinterpret_cast<PyArrayPayload*>(payload);
-  });
-  // Collect strides
-  std::vector<size_t> strides{src.strides().begin(), src.strides().end()};
-  for (int i = 0; i < src.ndim(); i++) {
-    strides[i] *= src.itemsize();
-  }
-  // Pack the capsule with the array
-  py::array out(dt, src.shape(), strides, src.data<T>(), freeWhenDone);
-  // Mark array as read-only
-  py::detail::array_proxy(out.ptr())->flags &=
-      ~py::detail::npy_api::NPY_ARRAY_WRITEABLE_;
-  // Return array
-  return py::array(dt, src.shape(), strides, src.data<T>(), out);
-}
-
-py::array mlx_array_to_np(const array& src) {
-  // Eval if not already evaled
-  if (!src.is_evaled()) {
-    eval({src}, src.is_tracer());
-  }
-
-  switch (src.dtype()) {
-    case bool_:
-      return mlx_array_to_np_t<bool>(src);
-    case uint8:
-      return mlx_array_to_np_t<uint8_t>(src);
-    case uint16:
-      return mlx_array_to_np_t<uint16_t>(src);
-    case uint32:
-      return mlx_array_to_np_t<uint32_t>(src);
-    case uint64:
-      return mlx_array_to_np_t<uint64_t>(src);
-    case int8:
-      return mlx_array_to_np_t<int8_t>(src);
-    case int16:
-      return mlx_array_to_np_t<int16_t>(src);
-    case int32:
-      return mlx_array_to_np_t<int32_t>(src);
-    case int64:
-      return mlx_array_to_np_t<int64_t>(src);
-    case float16:
-      return mlx_array_to_np_t<float16_t>(src, py::dtype("float16"));
-    case float32:
-      return mlx_array_to_np_t<float>(src);
-    case bfloat16: {
-      auto a = astype(src, float32);
-      eval({a}, src.is_tracer());
-      return mlx_array_to_np_t<float>(a);
-    }
-    case complex64:
-      return mlx_array_to_np_t<std::complex<float>>(src, py::dtype("complex64"));
-  }
-}
-
 ///////////////////////////////////////////////////////////////////////////////
 // Numpy -> MLX
 ///////////////////////////////////////////////////////////////////////////////
@@ -479,6 +377,61 @@ array np_array_to_mlx(py::array np_array, std::optional<Dtype> dtype) {
   }
 }
 
+///////////////////////////////////////////////////////////////////////////////
+// Python Buffer Protocol (MLX -> Numpy)
+///////////////////////////////////////////////////////////////////////////////
+
+std::optional<std::string> buffer_format(const array& a) {
+  // https://docs.python.org/3.10/library/struct.html#format-characters
+  switch (a.dtype()) {
+    case bool_:
+      return pybind11::format_descriptor<bool>::format();
+    case uint8:
+      return pybind11::format_descriptor<uint8_t>::format();
+    case uint16:
+      return pybind11::format_descriptor<uint16_t>::format();
+    case uint32:
+      return pybind11::format_descriptor<uint32_t>::format();
+    case uint64:
+      return pybind11::format_descriptor<uint64_t>::format();
+    case int8:
+      return pybind11::format_descriptor<int8_t>::format();
+    case int16:
+      return pybind11::format_descriptor<int16_t>::format();
+    case int32:
+      return pybind11::format_descriptor<int32_t>::format();
+    case int64:
+      return pybind11::format_descriptor<int64_t>::format();
+    case float16:
+      // https://github.com/pybind/pybind11/issues/4998
+      return "e";
+    case float32: {
+      return pybind11::format_descriptor<float>::format();
+    }
+    case bfloat16:
+      // not supported by the Python buffer protocol or NumPy;
+      // must be null according to
+      // https://docs.python.org/3.10/c-api/buffer.html#c.PyBUF_FORMAT
+      return {};
+    case complex64:
+      return pybind11::format_descriptor<std::complex<float>>::format();
+    default: {
+      std::ostringstream os;
+      os << "bad dtype: " << a.dtype();
+      throw std::runtime_error(os.str());
+    }
+  }
+}
+
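+// Convert MLX strides, which count elements, into the byte strides that the
+// Python buffer protocol (PEP 3118) expects.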
+std::vector<size_t> buffer_strides(const array& a) {
+  std::vector<size_t> py_strides;
+  py_strides.reserve(a.strides().size());
+  for (const size_t stride : a.strides()) {
+    py_strides.push_back(stride * a.itemsize());
+  }
+  return py_strides;
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // Module
 ///////////////////////////////////////////////////////////////////////////////
@@ -546,7 +499,10 @@ void init_array(py::module_& m) {
   m.attr("complex64") = py::cast(complex64);
 
   auto array_class = py::class_<array>(
-      m, "array", R"pbdoc(An N-dimensional array object.)pbdoc");
+      m,
+      "array",
+      R"pbdoc(An N-dimensional array object.)pbdoc",
+      py::buffer_protocol());
 
   {
     py::options options;
@@ -564,6 +520,19 @@ void init_array(py::module_& m) {
   }
 
   array_class
+      .def_buffer([](array& a) {
+        // Eval if not already evaled
+        if (!a.is_evaled()) {
+          eval({a}, a.is_tracer());
+        }
+        return pybind11::buffer_info(
+            a.data<void>(),
+            a.itemsize(),
+            buffer_format(a).value_or(nullptr),
+            a.ndim(),
+            a.shape(),
+            buffer_strides(a));
+      })
       .def_property_readonly(
           "size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc")
       .def_property_readonly(
@@ -620,7 +589,6 @@ void init_array(py::module_& m) {
         The value type of the list corresponding to the last dimension is either
         ``bool``, ``int`` or ``float`` depending on the ``dtype`` of the array.
       )pbdoc")
-      .def("__array__", &mlx_array_to_np)
       .def(
           "astype",
           &astype,
diff --git a/python/tests/test_array.py b/python/tests/test_array.py
index 42d41a550..eee570920 100644
--- a/python/tests/test_array.py
+++ b/python/tests/test_array.py
@@ -2,12 +2,20 @@
 
 import operator
 import unittest
+import weakref
 from itertools import permutations
 
 import mlx.core as mx
 import mlx_tests
 import numpy as np
 
+try:
+    import tensorflow as tf
+
+    has_tf = True
+except ImportError as e:
+    has_tf = False
+
 
 class TestVersion(mlx_tests.MLXTestCase):
     def test_version(self):
@@ -1100,7 +1108,7 @@ class TestArray(mlx_tests.MLXTestCase):
 
         # Check that we get read-only array that does not own the underlying data
         self.assertFalse(a_np.flags.owndata)
-        self.assertFalse(a_np.flags.writeable)
+        self.assertTrue(a_np.flags.writeable)
 
         # Check contents
         self.assertTrue(np.array_equal(np.ones((2, 2), dtype=np.float32), a_np))
@@ -1109,6 +1117,157 @@ class TestArray(mlx_tests.MLXTestCase):
         # Check strides
         self.assertSequenceEqual(b_np.strides, (0, 8, 4))
 
+    def test_np_array_conversion_copies_by_default(self):
+        a_mx = mx.ones((2, 2))
+        a_np = np.array(a_mx)
+        self.assertTrue(a_np.flags.owndata)
+        self.assertTrue(a_np.flags.writeable)
+
+    def test_buffer_protocol(self):
+        dtypes_list = [
+            (mx.bool_, np.bool_, None),
+            (mx.uint8, np.uint8, np.iinfo),
+            (mx.uint16, np.uint16, np.iinfo),
+            (mx.uint32, np.uint32, np.iinfo),
+            (mx.uint64, np.uint64, np.iinfo),
+            (mx.int8, np.int8, np.iinfo),
+            (mx.int16, np.int16, np.iinfo),
+            (mx.int32, np.int32, np.iinfo),
+            (mx.int64, np.int64, np.iinfo),
+            (mx.float16, np.float16, np.finfo),
+            (mx.float32, np.float32, np.finfo),
+            (mx.complex64, np.complex64, np.finfo),
+        ]
+
+        for mlx_dtype, np_dtype, info_fn in dtypes_list:
+            a_np = np.random.uniform(low=0, high=100, size=(3, 4)).astype(np_dtype)
+            if info_fn is not None:
+                info = info_fn(np_dtype)
+                a_np[0, 0] = info.min
+                a_np[0, 1] = info.max
+            a_mx = mx.array(a_np)
+            for f in [lambda x: x, lambda x: x.T]:
+                mv_mx = memoryview(f(a_mx))
+                mv_np = memoryview(f(a_np))
+                self.assertEqual(mv_mx.strides, mv_np.strides, f"{mlx_dtype}{np_dtype}")
+                self.assertEqual(mv_mx.shape, mv_np.shape, f"{mlx_dtype}{np_dtype}")
+                # The correct buffer format for an 8-byte (unsigned) 'long long' is Q/q; see
+                # https://docs.python.org/3.10/library/struct.html#format-characters
+                # NumPy returns L/l instead, since 'long' and 'long long' are equivalent on
+                # 64-bit machines, so Q/q and L/l are interchangeable there; see
+                # https://github.com/pybind/pybind11/issues/1908
+                if np_dtype == np.uint64:
+                    self.assertEqual(mv_mx.format, "Q", f"{mlx_dtype}{np_dtype}")
+                elif np_dtype == np.int64:
+                    self.assertEqual(mv_mx.format, "q", f"{mlx_dtype}{np_dtype}")
+                else:
+                    self.assertEqual(
+                        mv_mx.format, mv_np.format, f"{mlx_dtype}{np_dtype}"
+                    )
+                self.assertFalse(mv_mx.readonly)
+                back_to_npy = np.array(mv_mx, copy=False)
+                self.assertEqualArray(
+                    back_to_npy,
+                    f(a_np),
+                    atol=0,
+                    rtol=0,
+                    msg=f"{mlx_dtype}{np_dtype}",
+                )
+
+        # extra test for bfloat16, which is not NumPy-convertible
+        a_mx = mx.random.uniform(low=0, high=100, shape=(3, 4), dtype=mx.bfloat16)
+        mv_mx = memoryview(a_mx)
+        self.assertEqual(mv_mx.strides, (8, 2))
+        self.assertEqual(mv_mx.shape, (3, 4))
+        self.assertEqual(mv_mx.format, "")
+        with self.assertRaises(RuntimeError) as cm:
+            np.array(a_mx)
+        e = cm.exception
+        self.assertTrue("Item size 2 for PEP 3118 buffer format string" in str(e))
+
+    def test_buffer_protocol_ref_counting(self):
+        a = mx.arange(3)
+        wr = weakref.ref(a)
+        self.assertIsNotNone(wr())
+        mv = memoryview(a)
+        a = None
+        self.assertIsNotNone(wr())
+        mv = None
+        self.assertIsNone(wr())
+
+    def test_array_view_ref_counting(self):
+        a = mx.arange(3)
+        wr = weakref.ref(a)
+        self.assertIsNotNone(wr())
+        a_np = np.array(a, copy=False)
+        a = None
+        self.assertIsNotNone(wr())
+        a_np = None
+        self.assertIsNone(wr())
+
+    @unittest.skipIf(not has_tf, "requires TensorFlow")
+    def test_buffer_protocol_tf(self):
+        dtypes_list = [
+            (mx.bool_, tf.bool, np.bool_),
+            (mx.uint8, tf.uint8, np.uint8),
+            (mx.uint16, tf.uint16, np.uint16),
+            (mx.uint32, tf.uint32, np.uint32),
+            (mx.uint64, tf.uint64, np.uint64),
+            (mx.int8, tf.int8, np.int8),
+            (mx.int16, tf.int16, np.int16),
+            (mx.int32, tf.int32, np.int32),
+            (mx.int64, tf.int64, np.int64),
+            (mx.float16, tf.float16, np.float16),
+            (mx.float32, tf.float32, np.float32),
+            (mx.complex64, tf.complex64, np.complex64),
+        ]
+
+        for mlx_dtype, tf_dtype, np_dtype in dtypes_list:
+            a_np = np.random.uniform(low=0, high=100, size=(3, 4)).astype(np_dtype)
+            a_tf = tf.constant(a_np, dtype=tf_dtype)
+            a_mx = mx.array(a_tf)
+            for f in [
+                lambda x: x,
+                lambda x: tf.transpose(x) if isinstance(x, tf.Tensor) else x.T,
+            ]:
+                mv_mx = memoryview(f(a_mx))
+                mv_tf = memoryview(f(a_tf))
+                if (mv_mx.c_contiguous and mv_tf.c_contiguous) or (
+                    mv_mx.f_contiguous and mv_tf.f_contiguous
+                ):
+                    self.assertEqual(
+                        mv_mx.strides, mv_tf.strides, f"{mlx_dtype}{tf_dtype}"
+                    )
+                self.assertEqual(mv_mx.shape, mv_tf.shape, f"{mlx_dtype}{tf_dtype}")
+                self.assertFalse(mv_mx.readonly)
+                back_to_npy = np.array(mv_mx)
+                self.assertEqualArray(
+                    back_to_npy,
+                    f(a_tf),
+                    atol=0,
+                    rtol=0,
+                    msg=f"{mlx_dtype}{tf_dtype}",
+                )
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/array_tests.cpp b/tests/array_tests.cpp
index 334b86115..080d53daa 100644
--- a/tests/array_tests.cpp
+++ b/tests/array_tests.cpp
@@ -222,6 +222,8 @@ TEST_CASE("test array types") {
 
   // complex64
   {
+    CHECK_EQ(sizeof(complex64_t), sizeof(std::complex<float>));
+
     complex64_t v = {1.0f, 1.0f};
     array x(v);
     CHECK_EQ(x.dtype(), complex64);