diff --git a/python/src/array.cpp b/python/src/array.cpp index 27cdad726..8bca3ba3e 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -458,6 +458,21 @@ std::vector buffer_strides(const array& a) { return py_strides; } +py::buffer_info buffer_info(array& a) { + // Eval if not already evaled + if (!a.is_evaled()) { + py::gil_scoped_release nogil; + a.eval(); + } + return pybind11::buffer_info( + a.data(), + a.itemsize(), + buffer_format(a).value_or("B"), // we use "B" because pybind uses a + // std::string which can't be null + a.ndim(), + a.shape(), + buffer_strides(a)); +} /////////////////////////////////////////////////////////////////////////////// // Module /////////////////////////////////////////////////////////////////////////////// @@ -647,21 +662,7 @@ void init_array(py::module_& m) { .def("__iter__", [](const ArrayPythonIterator& it) { return it; }); array_class - .def_buffer([](array& a) { - // Eval if not already evaled - if (!a.is_evaled()) { - py::gil_scoped_release nogil; - a.eval(); - } - return pybind11::buffer_info( - a.data(), - a.itemsize(), - buffer_format(a).value_or("B"), // we use "B" because pybind uses a - // std::string which can't be null - a.ndim(), - a.shape(), - buffer_strides(a)); - }) + .def_buffer([](array& a) -> py::buffer_info { return buffer_info(a); }) .def_property_readonly( "size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc") .def_property_readonly( @@ -775,17 +776,13 @@ void init_array(py::module_& m) { .def("__iter__", [](const array& a) { return ArrayPythonIterator(a); }) .def(py::pickle( [](array& a) { // __getstate__ - return py::make_tuple( - dtype_to_array_protocol(a.dtype()), tolist(a)); + return py::array(buffer_info(a)); }, - [](py::tuple t) { // __setstate__ - if (t.size() != 2 or !py::isinstance(t[0]) or - !py::isinstance(t[1])) - throw std::invalid_argument( - "Invalide state for __setstate__. Expected a tuple of length 2 with a string and a list as elements."); - - return array_from_list( - t[1], dtype_from_array_protocol(t[0].cast())); + [](py::array npa) { // __setstate__ + if (not py::isinstance(npa)) { + throw std::runtime_error("Invalid state!"); + } + return np_array_to_mlx(npa, std::nullopt); })) .def("__copy__", [](const array& self) { return array(self); }) .def( diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 8b034e0a5..4409baa51 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1,9 +1,10 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2024 Apple Inc. import operator import pickle import unittest import weakref +from copy import copy, deepcopy from itertools import permutations import mlx.core as mx @@ -670,17 +671,15 @@ class TestArray(mlx_tests.MLXTestCase): mx.uint64, mx.float16, mx.float32, - mx.bfloat16, + # mx.bfloat16, mx.complex64, ] - import pickle for dtype in dtypes: x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype) state = pickle.dumps(x) y = pickle.loads(state) - self.assertEqualArray(x, y) - self.assertEqual(x.dtype, y.dtype) + self.assertEqualArray(y, x) def test_array_copy(self): dtypes = [ @@ -698,16 +697,14 @@ class TestArray(mlx_tests.MLXTestCase): mx.complex64, ] - from copy import copy, deepcopy - for copy_function in [copy, deepcopy]: for dtype in dtypes: x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype) y = copy_function(x) - self.assertEqualArray(x, y) + self.assertEqualArray(y, x) y -= 1 - self.assertEqualArray(x - 1, y) + self.assertEqualArray(y, x - 1) def test_indexing(self): # Basic content check, slice indexing