mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
* Implemented pickling and copy for Python arrays(#300 & #367)
* Fixing typos
* Pickle with NumPy arrays
* Pickle: workaround for bfloat16
* Revert "Pickle: workaround for bfloat16"
This reverts commit 25afe6bc09
.
* Added an error when pickling bfloat16
* Update python/tests/test_array.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/tests/test_array.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/array.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/array.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* clang-format applied
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent
e39bebe13e
commit
cbefd9129e
@ -13,6 +13,7 @@ MLX was developed with contributions from the following individuals:
|
|||||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
|
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
|
||||||
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
|
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
|
||||||
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
|
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
|
||||||
|
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
|
||||||
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
|
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
|
||||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||||
|
@ -458,6 +458,21 @@ std::vector<size_t> buffer_strides(const array& a) {
|
|||||||
return py_strides;
|
return py_strides;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
py::buffer_info buffer_info(array& a) {
|
||||||
|
// Eval if not already evaled
|
||||||
|
if (!a.is_evaled()) {
|
||||||
|
py::gil_scoped_release nogil;
|
||||||
|
a.eval();
|
||||||
|
}
|
||||||
|
return pybind11::buffer_info(
|
||||||
|
a.data<void>(),
|
||||||
|
a.itemsize(),
|
||||||
|
buffer_format(a).value_or("B"), // we use "B" because pybind uses a
|
||||||
|
// std::string which can't be null
|
||||||
|
a.ndim(),
|
||||||
|
a.shape(),
|
||||||
|
buffer_strides(a));
|
||||||
|
}
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
// Module
|
// Module
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
@ -647,21 +662,7 @@ void init_array(py::module_& m) {
|
|||||||
.def("__iter__", [](const ArrayPythonIterator& it) { return it; });
|
.def("__iter__", [](const ArrayPythonIterator& it) { return it; });
|
||||||
|
|
||||||
array_class
|
array_class
|
||||||
.def_buffer([](array& a) {
|
.def_buffer([](array& a) -> py::buffer_info { return buffer_info(a); })
|
||||||
// Eval if not already evaled
|
|
||||||
if (!a.is_evaled()) {
|
|
||||||
py::gil_scoped_release nogil;
|
|
||||||
a.eval();
|
|
||||||
}
|
|
||||||
return pybind11::buffer_info(
|
|
||||||
a.data<void>(),
|
|
||||||
a.itemsize(),
|
|
||||||
buffer_format(a).value_or("B"), // we use "B" because pybind uses a
|
|
||||||
// std::string which can't be null
|
|
||||||
a.ndim(),
|
|
||||||
a.shape(),
|
|
||||||
buffer_strides(a));
|
|
||||||
})
|
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
"size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc")
|
"size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc")
|
||||||
.def_property_readonly(
|
.def_property_readonly(
|
||||||
@ -773,6 +774,26 @@ void init_array(py::module_& m) {
|
|||||||
return a.shape(0);
|
return a.shape(0);
|
||||||
})
|
})
|
||||||
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
|
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
|
||||||
|
.def(py::pickle(
|
||||||
|
[](array& a) { // __getstate__
|
||||||
|
if (a.dtype() == bfloat16) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[array.__getstate__] Not supported for bfloat16.");
|
||||||
|
}
|
||||||
|
return py::array(buffer_info(a));
|
||||||
|
},
|
||||||
|
[](py::array npa) { // __setstate__
|
||||||
|
if (not py::isinstance<py::array>(npa)) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[array.__setstate__] Received invalid state.");
|
||||||
|
}
|
||||||
|
return np_array_to_mlx(npa, std::nullopt);
|
||||||
|
}))
|
||||||
|
.def("__copy__", [](const array& self) { return array(self); })
|
||||||
|
.def(
|
||||||
|
"__deepcopy__",
|
||||||
|
[](const array& self, py::dict) { return array(self); },
|
||||||
|
"memo"_a)
|
||||||
.def(
|
.def(
|
||||||
"__add__",
|
"__add__",
|
||||||
[](const array& a, const ScalarOrArray v) {
|
[](const array& a, const ScalarOrArray v) {
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import operator
|
import operator
|
||||||
import pickle
|
import pickle
|
||||||
import unittest
|
import unittest
|
||||||
import weakref
|
import weakref
|
||||||
|
from copy import copy, deepcopy
|
||||||
from itertools import permutations
|
from itertools import permutations
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
@ -658,6 +659,57 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(y.tolist(), [3.0, 4.0])
|
self.assertEqual(y.tolist(), [3.0, 4.0])
|
||||||
self.assertEqual(z.tolist(), [5.0, 6.0])
|
self.assertEqual(z.tolist(), [5.0, 6.0])
|
||||||
|
|
||||||
|
def test_array_pickle(self):
|
||||||
|
dtypes = [
|
||||||
|
mx.int8,
|
||||||
|
mx.int16,
|
||||||
|
mx.int32,
|
||||||
|
mx.int64,
|
||||||
|
mx.uint8,
|
||||||
|
mx.uint16,
|
||||||
|
mx.uint32,
|
||||||
|
mx.uint64,
|
||||||
|
mx.float16,
|
||||||
|
mx.float32,
|
||||||
|
mx.complex64,
|
||||||
|
]
|
||||||
|
|
||||||
|
for dtype in dtypes:
|
||||||
|
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
|
||||||
|
state = pickle.dumps(x)
|
||||||
|
y = pickle.loads(state)
|
||||||
|
self.assertEqualArray(y, x)
|
||||||
|
|
||||||
|
# check if it throws an error when dtype is not supported (bfloat16)
|
||||||
|
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=mx.bfloat16)
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
pickle.dumps(x)
|
||||||
|
|
||||||
|
def test_array_copy(self):
|
||||||
|
dtypes = [
|
||||||
|
mx.int8,
|
||||||
|
mx.int16,
|
||||||
|
mx.int32,
|
||||||
|
mx.int64,
|
||||||
|
mx.uint8,
|
||||||
|
mx.uint16,
|
||||||
|
mx.uint32,
|
||||||
|
mx.uint64,
|
||||||
|
mx.float16,
|
||||||
|
mx.float32,
|
||||||
|
mx.bfloat16,
|
||||||
|
mx.complex64,
|
||||||
|
]
|
||||||
|
|
||||||
|
for copy_function in [copy, deepcopy]:
|
||||||
|
for dtype in dtypes:
|
||||||
|
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
|
||||||
|
y = copy_function(x)
|
||||||
|
self.assertEqualArray(y, x)
|
||||||
|
|
||||||
|
y -= 1
|
||||||
|
self.assertEqualArray(y, x - 1)
|
||||||
|
|
||||||
def test_indexing(self):
|
def test_indexing(self):
|
||||||
# Basic content check, slice indexing
|
# Basic content check, slice indexing
|
||||||
a_npy = np.arange(64, dtype=np.float32)
|
a_npy = np.arange(64, dtype=np.float32)
|
||||||
|
Loading…
Reference in New Issue
Block a user