Implemented pickling and copy for Python arrays(#300 & #367)

This commit is contained in:
Luca Arnaboldi 2024-02-20 09:25:07 +01:00
parent 146bd69470
commit 8ba3625a40
3 changed files with 74 additions and 0 deletions

View File

@ -13,6 +13,9 @@ MLX was developed with contributions from the following individuals:
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops.
Implemented pickling, copy and deepcopy for Python arrays.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>

View File

@ -774,6 +774,25 @@ void init_array(py::module_& m) {
return a.shape(0);
})
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
.def(py::pickle(
[](array& a) { // __getstate__
return py::make_tuple(
dtype_to_array_protocol(a.dtype()), tolist(a));
},
[](py::tuple t) { // __setstate__
if (t.size() != 2 or !py::isinstance<py::str>(t[0]) or
!py::isinstance<py::list>(t[1]))
throw std::invalid_argument(
"Invalide state for __setstate__. Expected a tuple of length 2 with a string and a list as elements.");
return array_from_list(
t[1], dtype_from_array_protocol(t[0].cast<std::string>()));
}))
.def("__copy__", [](const array& self) { return array(self); })
.def(
"__deepcopy__",
[](const array& self, py::dict) { return array(self); },
"memo"_a)
.def(
"__add__",
[](const array& a, const ScalarOrArray v) {

View File

@ -649,6 +649,58 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqual(y.tolist(), [3.0, 4.0])
self.assertEqual(z.tolist(), [5.0, 6.0])
def test_array_pickle(self):
dtypes = [
mx.int8,
mx.int16,
mx.int32,
mx.int64,
mx.uint8,
mx.uint16,
mx.uint32,
mx.uint64,
# mx.float16,
mx.float32,
# mx.bfloat16,
mx.complex64,
]
import pickle
for dtype in dtypes:
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
state = pickle.dumps(x)
y = pickle.loads(state)
self.assertEqualArray(x, y)
self.assertEqual(x.dtype, y.dtype)
def test_array_copy(self):
dtypes = [
mx.int8,
mx.int16,
mx.int32,
mx.int64,
mx.uint8,
mx.uint16,
mx.uint32,
mx.uint64,
# mx.float16,
mx.float32,
# mx.bfloat16,
mx.complex64,
]
from copy import copy, deepcopy
for copy_function in [copy, deepcopy]:
for dtype in dtypes:
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
y = copy_function(x)
self.assertEqualArray(x, y)
y -= 1
print(x, y)
self.assertEqualArray(x - 1, y)
def test_indexing(self):
# Basic content check, slice indexing
a_npy = np.arange(64, dtype=np.float32)