diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 289c1a067..e841d0d0c 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -13,6 +13,7 @@ MLX was developed with contributions from the following individuals: - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support. - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``. - Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops. +- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays. - Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention` diff --git a/python/src/array.cpp b/python/src/array.cpp index 8e0cf605b..838970d84 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -458,6 +458,21 @@ std::vector buffer_strides(const array& a) { return py_strides; } +py::buffer_info buffer_info(array& a) { + // Eval if not already evaled + if (!a.is_evaled()) { + py::gil_scoped_release nogil; + a.eval(); + } + return pybind11::buffer_info( + a.data(), + a.itemsize(), + buffer_format(a).value_or("B"), // we use "B" because pybind uses a + // std::string which can't be null + a.ndim(), + a.shape(), + buffer_strides(a)); +} /////////////////////////////////////////////////////////////////////////////// // Module /////////////////////////////////////////////////////////////////////////////// @@ -647,21 +662,7 @@ void init_array(py::module_& m) { .def("__iter__", [](const ArrayPythonIterator& it) { return it; }); array_class - .def_buffer([](array& a) { - // Eval if not already evaled - if (!a.is_evaled()) { - py::gil_scoped_release nogil; - a.eval(); - } - return pybind11::buffer_info( - a.data(), - a.itemsize(), - buffer_format(a).value_or("B"), // we use "B" because pybind uses a - // std::string which can't be null - a.ndim(), - a.shape(), - buffer_strides(a)); - }) + .def_buffer([](array& a) -> py::buffer_info { return buffer_info(a); }) .def_property_readonly( "size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc") .def_property_readonly( @@ -773,6 +774,26 @@ void init_array(py::module_& m) { return a.shape(0); }) .def("__iter__", [](const array& a) { return ArrayPythonIterator(a); }) + .def(py::pickle( + [](array& a) { // __getstate__ + if (a.dtype() == bfloat16) { + throw std::runtime_error( + "[array.__getstate__] Not supported for bfloat16."); + } + return py::array(buffer_info(a)); + }, + [](py::array npa) { // __setstate__ + if (not py::isinstance(npa)) { + throw std::runtime_error( + "[array.__setstate__] Received invalid state."); + } + return np_array_to_mlx(npa, std::nullopt); + })) + .def("__copy__", [](const array& self) { return array(self); }) + .def( + "__deepcopy__", + [](const array& self, py::dict) { return array(self); }, + "memo"_a) .def( "__add__", [](const array& a, const ScalarOrArray v) { diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 828c2c0b4..4fbb2d0ae 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1,9 +1,10 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2023-2024 Apple Inc. import operator import pickle import unittest import weakref +from copy import copy, deepcopy from itertools import permutations import mlx.core as mx @@ -658,6 +659,57 @@ class TestArray(mlx_tests.MLXTestCase): self.assertEqual(y.tolist(), [3.0, 4.0]) self.assertEqual(z.tolist(), [5.0, 6.0]) + def test_array_pickle(self): + dtypes = [ + mx.int8, + mx.int16, + mx.int32, + mx.int64, + mx.uint8, + mx.uint16, + mx.uint32, + mx.uint64, + mx.float16, + mx.float32, + mx.complex64, + ] + + for dtype in dtypes: + x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype) + state = pickle.dumps(x) + y = pickle.loads(state) + self.assertEqualArray(y, x) + + # check if it throws an error when dtype is not supported (bfloat16) + x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=mx.bfloat16) + with self.assertRaises(RuntimeError): + pickle.dumps(x) + + def test_array_copy(self): + dtypes = [ + mx.int8, + mx.int16, + mx.int32, + mx.int64, + mx.uint8, + mx.uint16, + mx.uint32, + mx.uint64, + mx.float16, + mx.float32, + mx.bfloat16, + mx.complex64, + ] + + for copy_function in [copy, deepcopy]: + for dtype in dtypes: + x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype) + y = copy_function(x) + self.assertEqualArray(y, x) + + y -= 1 + self.assertEqualArray(y, x - 1) + def test_indexing(self): # Basic content check, slice indexing a_npy = np.arange(64, dtype=np.float32)