Pickle with NumPy arrays

This commit is contained in:
Luca Arnaboldi 2024-03-04 13:00:18 +01:00
parent c02602a4a1
commit 6b6b4f0a5f
2 changed files with 28 additions and 34 deletions

View File

@ -458,6 +458,21 @@ std::vector<size_t> buffer_strides(const array& a) {
return py_strides; return py_strides;
} }
py::buffer_info buffer_info(array& a) {
// Eval if not already evaled
if (!a.is_evaled()) {
py::gil_scoped_release nogil;
a.eval();
}
return pybind11::buffer_info(
a.data<void>(),
a.itemsize(),
buffer_format(a).value_or("B"), // we use "B" because pybind uses a
// std::string which can't be null
a.ndim(),
a.shape(),
buffer_strides(a));
}
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Module // Module
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@ -647,21 +662,7 @@ void init_array(py::module_& m) {
.def("__iter__", [](const ArrayPythonIterator& it) { return it; }); .def("__iter__", [](const ArrayPythonIterator& it) { return it; });
array_class array_class
.def_buffer([](array& a) { .def_buffer([](array& a) -> py::buffer_info { return buffer_info(a); })
// Eval if not already evaled
if (!a.is_evaled()) {
py::gil_scoped_release nogil;
a.eval();
}
return pybind11::buffer_info(
a.data<void>(),
a.itemsize(),
buffer_format(a).value_or("B"), // we use "B" because pybind uses a
// std::string which can't be null
a.ndim(),
a.shape(),
buffer_strides(a));
})
.def_property_readonly( .def_property_readonly(
"size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc") "size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc")
.def_property_readonly( .def_property_readonly(
@ -775,17 +776,13 @@ void init_array(py::module_& m) {
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); }) .def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
.def(py::pickle( .def(py::pickle(
[](array& a) { // __getstate__ [](array& a) { // __getstate__
return py::make_tuple( return py::array(buffer_info(a));
dtype_to_array_protocol(a.dtype()), tolist(a));
}, },
[](py::tuple t) { // __setstate__ [](py::array npa) { // __setstate__
if (t.size() != 2 or !py::isinstance<py::str>(t[0]) or if (not py::isinstance<py::array>(npa)) {
!py::isinstance<py::list>(t[1])) throw std::runtime_error("Invalid state!");
throw std::invalid_argument( }
"Invalide state for __setstate__. Expected a tuple of length 2 with a string and a list as elements."); return np_array_to_mlx(npa, std::nullopt);
return array_from_list(
t[1], dtype_from_array_protocol(t[0].cast<std::string>()));
})) }))
.def("__copy__", [](const array& self) { return array(self); }) .def("__copy__", [](const array& self) { return array(self); })
.def( .def(

View File

@ -1,9 +1,10 @@
# Copyright © 2023 Apple Inc. # Copyright © 2024 Apple Inc.
import operator import operator
import pickle import pickle
import unittest import unittest
import weakref import weakref
from copy import copy, deepcopy
from itertools import permutations from itertools import permutations
import mlx.core as mx import mlx.core as mx
@ -670,17 +671,15 @@ class TestArray(mlx_tests.MLXTestCase):
mx.uint64, mx.uint64,
mx.float16, mx.float16,
mx.float32, mx.float32,
mx.bfloat16, # mx.bfloat16,
mx.complex64, mx.complex64,
] ]
import pickle
for dtype in dtypes: for dtype in dtypes:
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype) x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
state = pickle.dumps(x) state = pickle.dumps(x)
y = pickle.loads(state) y = pickle.loads(state)
self.assertEqualArray(x, y) self.assertEqualArray(y, x)
self.assertEqual(x.dtype, y.dtype)
def test_array_copy(self): def test_array_copy(self):
dtypes = [ dtypes = [
@ -698,16 +697,14 @@ class TestArray(mlx_tests.MLXTestCase):
mx.complex64, mx.complex64,
] ]
from copy import copy, deepcopy
for copy_function in [copy, deepcopy]: for copy_function in [copy, deepcopy]:
for dtype in dtypes: for dtype in dtypes:
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype) x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
y = copy_function(x) y = copy_function(x)
self.assertEqualArray(x, y) self.assertEqualArray(y, x)
y -= 1 y -= 1
self.assertEqualArray(x - 1, y) self.assertEqualArray(y, x - 1)
def test_indexing(self): def test_indexing(self):
# Basic content check, slice indexing # Basic content check, slice indexing