diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 18a8c5599..12fdd47bc 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -13,6 +13,9 @@ MLX was developed with contributions from the following individuals: - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. +- Luca Arnaboldi: Added `Ceil` and `Floor` ops. + Implemented pickling, copy and deepcopy for Python arrays. + diff --git a/python/src/array.cpp b/python/src/array.cpp index 4395d50e6..4d8c22748 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -774,6 +774,25 @@ void init_array(py::module_& m) { return a.shape(0); }) .def("__iter__", [](const array& a) { return ArrayPythonIterator(a); }) + .def(py::pickle( + [](array& a) { // __getstate__ + return py::make_tuple( + dtype_to_array_protocol(a.dtype()), tolist(a)); + }, + [](py::tuple t) { // __setstate__ + if (t.size() != 2 or !py::isinstance(t[0]) or + !py::isinstance(t[1])) + throw std::invalid_argument( + "Invalide state for __setstate__. Expected a tuple of length 2 with a string and a list as elements."); + + return array_from_list( + t[1], dtype_from_array_protocol(t[0].cast())); + })) + .def("__copy__", [](const array& self) { return array(self); }) + .def( + "__deepcopy__", + [](const array& self, py::dict) { return array(self); }, + "memo"_a) .def( "__add__", [](const array& a, const ScalarOrArray v) { diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 507675d6e..40d0923f4 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -649,6 +649,58 @@ class TestArray(mlx_tests.MLXTestCase): self.assertEqual(y.tolist(), [3.0, 4.0]) self.assertEqual(z.tolist(), [5.0, 6.0]) + def test_array_pickle(self): + dtypes = [ + mx.int8, + mx.int16, + mx.int32, + mx.int64, + mx.uint8, + mx.uint16, + mx.uint32, + mx.uint64, + # mx.float16, + mx.float32, + # mx.bfloat16, + mx.complex64, + ] + import pickle + + for dtype in dtypes: + x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype) + state = pickle.dumps(x) + y = pickle.loads(state) + self.assertEqualArray(x, y) + self.assertEqual(x.dtype, y.dtype) + + def test_array_copy(self): + dtypes = [ + mx.int8, + mx.int16, + mx.int32, + mx.int64, + mx.uint8, + mx.uint16, + mx.uint32, + mx.uint64, + # mx.float16, + mx.float32, + # mx.bfloat16, + mx.complex64, + ] + + from copy import copy, deepcopy + + for copy_function in [copy, deepcopy]: + for dtype in dtypes: + x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype) + y = copy_function(x) + self.assertEqualArray(x, y) + + y -= 1 + print(x, y) + self.assertEqualArray(x - 1, y) + def test_indexing(self): # Basic content check, slice indexing a_npy = np.arange(64, dtype=np.float32)