Implementation of pickle, copy and deepcopy for Python arrays (#300 & #367). (#713)

* Implemented pickling and copy for Python arrays (#300 & #367)

* Fixing typos

* Pickle with NumPy arrays

* Pickle: workaround for bfloat16

* Revert "Pickle: workaround for bfloat16"

This reverts commit 25afe6bc09.

* Added an error when pickling bfloat16

* Update python/tests/test_array.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/tests/test_array.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/array.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/array.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* clang-format applied

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Luca Arnaboldi 2024-03-06 17:02:41 +01:00 committed by GitHub
parent e39bebe13e
commit cbefd9129e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 90 additions and 16 deletions

View File

@ -13,6 +13,7 @@ MLX was developed with contributions from the following individuals:
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support. - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``. - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops. - Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention` - Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
<a href="https://github.com/ml-explore/mlx/graphs/contributors"> <a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" /> <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

View File

@ -458,6 +458,21 @@ std::vector<size_t> buffer_strides(const array& a) {
return py_strides; return py_strides;
} }
// Build the pybind11 buffer description (data pointer, item size, format,
// rank, shape and strides) for `a`. The array is evaluated first if it has
// not been evaluated yet.
py::buffer_info buffer_info(array& a) {
  // The data pointer is read below, so make sure the array has been
  // evaluated; the GIL is released while eval() runs.
  if (!a.is_evaled()) {
    py::gil_scoped_release nogil;
    a.eval();
  }

  // pybind11 keeps the format in a std::string, which cannot be null, so
  // "B" is used whenever buffer_format() yields no value.
  const auto format = buffer_format(a).value_or("B");

  return py::buffer_info(
      a.data<void>(),
      a.itemsize(),
      format,
      a.ndim(),
      a.shape(),
      buffer_strides(a));
}
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
// Module // Module
/////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////
@ -647,21 +662,7 @@ void init_array(py::module_& m) {
.def("__iter__", [](const ArrayPythonIterator& it) { return it; }); .def("__iter__", [](const ArrayPythonIterator& it) { return it; });
array_class array_class
.def_buffer([](array& a) { .def_buffer([](array& a) -> py::buffer_info { return buffer_info(a); })
// Eval if not already evaled
if (!a.is_evaled()) {
py::gil_scoped_release nogil;
a.eval();
}
return pybind11::buffer_info(
a.data<void>(),
a.itemsize(),
buffer_format(a).value_or("B"), // we use "B" because pybind uses a
// std::string which can't be null
a.ndim(),
a.shape(),
buffer_strides(a));
})
.def_property_readonly( .def_property_readonly(
"size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc") "size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc")
.def_property_readonly( .def_property_readonly(
@ -773,6 +774,26 @@ void init_array(py::module_& m) {
return a.shape(0); return a.shape(0);
}) })
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); }) .def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
// Pickle support. The pickled state is a NumPy array (py::array) created
// from the array's buffer_info(); unpickling converts that state back with
// np_array_to_mlx().
.def(py::pickle(
[](array& a) { // __getstate__
// bfloat16 arrays are rejected up front with an explicit error.
// NOTE(review): presumably because there is no NumPy-compatible buffer
// format for bfloat16 -- confirm against buffer_format().
if (a.dtype() == bfloat16) {
throw std::runtime_error(
"[array.__getstate__] Not supported for bfloat16.");
}
return py::array(buffer_info(a));
},
[](py::array npa) { // __setstate__
// NOTE(review): the parameter is already declared as py::array, so this
// isinstance check looks redundant -- confirm before relying on it.
if (not py::isinstance<py::array>(npa)) {
throw std::runtime_error(
"[array.__setstate__] Received invalid state.");
}
// std::nullopt is passed as the second argument of np_array_to_mlx.
// NOTE(review): presumably "no explicit dtype requested" -- confirm.
return np_array_to_mlx(npa, std::nullopt);
}))
// copy.copy() hook: returns a copy-constructed array.
.def("__copy__", [](const array& self) { return array(self); })
// copy.deepcopy() hook: the memo dict is accepted but not used, and the
// result is produced the same way as for __copy__.
.def(
"__deepcopy__",
[](const array& self, py::dict) { return array(self); },
"memo"_a)
.def( .def(
"__add__", "__add__",
[](const array& a, const ScalarOrArray v) { [](const array& a, const ScalarOrArray v) {

View File

@ -1,9 +1,10 @@
# Copyright © 2023 Apple Inc. # Copyright © 2023-2024 Apple Inc.
import operator import operator
import pickle import pickle
import unittest import unittest
import weakref import weakref
from copy import copy, deepcopy
from itertools import permutations from itertools import permutations
import mlx.core as mx import mlx.core as mx
@ -658,6 +659,57 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqual(y.tolist(), [3.0, 4.0]) self.assertEqual(y.tolist(), [3.0, 4.0])
self.assertEqual(z.tolist(), [5.0, 6.0]) self.assertEqual(z.tolist(), [5.0, 6.0])
def test_array_pickle(self):
    """Round-trip arrays through pickle; dumping bfloat16 must raise."""
    nested = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
    picklable_dtypes = (
        mx.int8,
        mx.int16,
        mx.int32,
        mx.int64,
        mx.uint8,
        mx.uint16,
        mx.uint32,
        mx.uint64,
        mx.float16,
        mx.float32,
        mx.complex64,
    )

    for dt in picklable_dtypes:
        original = mx.array(nested, dtype=dt)
        restored = pickle.loads(pickle.dumps(original))
        self.assertEqualArray(restored, original)

    # bfloat16 is the unsupported dtype: pickling it is expected to fail
    # with a RuntimeError instead of producing a state.
    unsupported = mx.array(nested, dtype=mx.bfloat16)
    with self.assertRaises(RuntimeError):
        pickle.dumps(unsupported)
def test_array_copy(self):
    """Check copy() and deepcopy() results for every dtype.

    Each duplicate must compare equal to its source, and after an
    in-place decrement it must equal ``source - 1``.
    """
    nested = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
    all_dtypes = (
        mx.int8,
        mx.int16,
        mx.int32,
        mx.int64,
        mx.uint8,
        mx.uint16,
        mx.uint32,
        mx.uint64,
        mx.float16,
        mx.float32,
        mx.bfloat16,
        mx.complex64,
    )

    for duplicate in (copy, deepcopy):
        for dt in all_dtypes:
            source = mx.array(nested, dtype=dt)
            clone = duplicate(source)
            self.assertEqualArray(clone, source)

            # Update the duplicate in place and compare against the source.
            clone -= 1
            self.assertEqualArray(clone, source - 1)
def test_indexing(self): def test_indexing(self):
# Basic content check, slice indexing # Basic content check, slice indexing
a_npy = np.arange(64, dtype=np.float32) a_npy = np.arange(64, dtype=np.float32)