diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 18a8c5599..12fdd47bc 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -13,6 +13,9 @@ MLX was developed with contributions from the following individuals:
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
+- Luca Arnaboldi: Added `Ceil` and `Floor` ops.
+ Implemented pickling, copy and deepcopy for Python arrays.
+
diff --git a/python/src/array.cpp b/python/src/array.cpp
index 4395d50e6..4d8c22748 100644
--- a/python/src/array.cpp
+++ b/python/src/array.cpp
@@ -774,6 +774,25 @@ void init_array(py::module_& m) {
return a.shape(0);
})
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
+ .def(py::pickle(
+ [](array& a) { // __getstate__
+ return py::make_tuple(
+ dtype_to_array_protocol(a.dtype()), tolist(a));
+ },
+ [](py::tuple t) { // __setstate__
+ if (t.size() != 2 or !py::isinstance(t[0]) or
+ !py::isinstance(t[1]))
+ throw std::invalid_argument(
+ "Invalide state for __setstate__. Expected a tuple of length 2 with a string and a list as elements.");
+
+ return array_from_list(
+ t[1], dtype_from_array_protocol(t[0].cast()));
+ }))
+ .def("__copy__", [](const array& self) { return array(self); })
+ .def(
+ "__deepcopy__",
+ [](const array& self, py::dict) { return array(self); },
+ "memo"_a)
.def(
"__add__",
[](const array& a, const ScalarOrArray v) {
diff --git a/python/tests/test_array.py b/python/tests/test_array.py
index 507675d6e..40d0923f4 100644
--- a/python/tests/test_array.py
+++ b/python/tests/test_array.py
@@ -649,6 +649,58 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqual(y.tolist(), [3.0, 4.0])
self.assertEqual(z.tolist(), [5.0, 6.0])
+ def test_array_pickle(self):
+ dtypes = [
+ mx.int8,
+ mx.int16,
+ mx.int32,
+ mx.int64,
+ mx.uint8,
+ mx.uint16,
+ mx.uint32,
+ mx.uint64,
+ # mx.float16,
+ mx.float32,
+ # mx.bfloat16,
+ mx.complex64,
+ ]
+ import pickle
+
+ for dtype in dtypes:
+ x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
+ state = pickle.dumps(x)
+ y = pickle.loads(state)
+ self.assertEqualArray(x, y)
+ self.assertEqual(x.dtype, y.dtype)
+
+ def test_array_copy(self):
+ dtypes = [
+ mx.int8,
+ mx.int16,
+ mx.int32,
+ mx.int64,
+ mx.uint8,
+ mx.uint16,
+ mx.uint32,
+ mx.uint64,
+ # mx.float16,
+ mx.float32,
+ # mx.bfloat16,
+ mx.complex64,
+ ]
+
+ from copy import copy, deepcopy
+
+ for copy_function in [copy, deepcopy]:
+ for dtype in dtypes:
+ x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
+ y = copy_function(x)
+ self.assertEqualArray(x, y)
+
+ y -= 1
+ print(x, y)
+ self.assertEqualArray(x - 1, y)
+
def test_indexing(self):
# Basic content check, slice indexing
a_npy = np.arange(64, dtype=np.float32)