mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 23:51:14 +08:00
Pickle with NumPy arrays
This commit is contained in:
parent
c02602a4a1
commit
6b6b4f0a5f
@ -458,6 +458,21 @@ std::vector<size_t> buffer_strides(const array& a) {
|
||||
return py_strides;
|
||||
}
|
||||
|
||||
py::buffer_info buffer_info(array& a) {
|
||||
// Eval if not already evaled
|
||||
if (!a.is_evaled()) {
|
||||
py::gil_scoped_release nogil;
|
||||
a.eval();
|
||||
}
|
||||
return pybind11::buffer_info(
|
||||
a.data<void>(),
|
||||
a.itemsize(),
|
||||
buffer_format(a).value_or("B"), // we use "B" because pybind uses a
|
||||
// std::string which can't be null
|
||||
a.ndim(),
|
||||
a.shape(),
|
||||
buffer_strides(a));
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Module
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@ -647,21 +662,7 @@ void init_array(py::module_& m) {
|
||||
.def("__iter__", [](const ArrayPythonIterator& it) { return it; });
|
||||
|
||||
array_class
|
||||
.def_buffer([](array& a) {
|
||||
// Eval if not already evaled
|
||||
if (!a.is_evaled()) {
|
||||
py::gil_scoped_release nogil;
|
||||
a.eval();
|
||||
}
|
||||
return pybind11::buffer_info(
|
||||
a.data<void>(),
|
||||
a.itemsize(),
|
||||
buffer_format(a).value_or("B"), // we use "B" because pybind uses a
|
||||
// std::string which can't be null
|
||||
a.ndim(),
|
||||
a.shape(),
|
||||
buffer_strides(a));
|
||||
})
|
||||
.def_buffer([](array& a) -> py::buffer_info { return buffer_info(a); })
|
||||
.def_property_readonly(
|
||||
"size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc")
|
||||
.def_property_readonly(
|
||||
@ -775,17 +776,13 @@ void init_array(py::module_& m) {
|
||||
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
|
||||
.def(py::pickle(
|
||||
[](array& a) { // __getstate__
|
||||
return py::make_tuple(
|
||||
dtype_to_array_protocol(a.dtype()), tolist(a));
|
||||
return py::array(buffer_info(a));
|
||||
},
|
||||
[](py::tuple t) { // __setstate__
|
||||
if (t.size() != 2 or !py::isinstance<py::str>(t[0]) or
|
||||
!py::isinstance<py::list>(t[1]))
|
||||
throw std::invalid_argument(
|
||||
"Invalide state for __setstate__. Expected a tuple of length 2 with a string and a list as elements.");
|
||||
|
||||
return array_from_list(
|
||||
t[1], dtype_from_array_protocol(t[0].cast<std::string>()));
|
||||
[](py::array npa) { // __setstate__
|
||||
if (not py::isinstance<py::array>(npa)) {
|
||||
throw std::runtime_error("Invalid state!");
|
||||
}
|
||||
return np_array_to_mlx(npa, std::nullopt);
|
||||
}))
|
||||
.def("__copy__", [](const array& self) { return array(self); })
|
||||
.def(
|
||||
|
@ -1,9 +1,10 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import operator
|
||||
import pickle
|
||||
import unittest
|
||||
import weakref
|
||||
from copy import copy, deepcopy
|
||||
from itertools import permutations
|
||||
|
||||
import mlx.core as mx
|
||||
@ -670,17 +671,15 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
mx.uint64,
|
||||
mx.float16,
|
||||
mx.float32,
|
||||
mx.bfloat16,
|
||||
# mx.bfloat16,
|
||||
mx.complex64,
|
||||
]
|
||||
import pickle
|
||||
|
||||
for dtype in dtypes:
|
||||
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
|
||||
state = pickle.dumps(x)
|
||||
y = pickle.loads(state)
|
||||
self.assertEqualArray(x, y)
|
||||
self.assertEqual(x.dtype, y.dtype)
|
||||
self.assertEqualArray(y, x)
|
||||
|
||||
def test_array_copy(self):
|
||||
dtypes = [
|
||||
@ -698,16 +697,14 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
mx.complex64,
|
||||
]
|
||||
|
||||
from copy import copy, deepcopy
|
||||
|
||||
for copy_function in [copy, deepcopy]:
|
||||
for dtype in dtypes:
|
||||
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
|
||||
y = copy_function(x)
|
||||
self.assertEqualArray(x, y)
|
||||
self.assertEqualArray(y, x)
|
||||
|
||||
y -= 1
|
||||
self.assertEqualArray(x - 1, y)
|
||||
self.assertEqualArray(y, x - 1)
|
||||
|
||||
def test_indexing(self):
|
||||
# Basic content check, slice indexing
|
||||
|
Loading…
Reference in New Issue
Block a user