mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-16 22:11:15 +08:00
parent
146bd69470
commit
8ba3625a40
@ -13,6 +13,9 @@ MLX was developed with contributions from the following individuals:
|
||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support
|
||||
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
|
||||
|
||||
- Luca Arnaboldi: Added `Ceil` and `Floor` ops.
|
||||
Implemented pickling, copy and deepcopy for Python arrays.
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||
</a>
|
||||
|
@ -774,6 +774,25 @@ void init_array(py::module_& m) {
|
||||
return a.shape(0);
|
||||
})
|
||||
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
|
||||
.def(py::pickle(
|
||||
[](array& a) { // __getstate__
|
||||
return py::make_tuple(
|
||||
dtype_to_array_protocol(a.dtype()), tolist(a));
|
||||
},
|
||||
[](py::tuple t) { // __setstate__
|
||||
if (t.size() != 2 or !py::isinstance<py::str>(t[0]) or
|
||||
!py::isinstance<py::list>(t[1]))
|
||||
throw std::invalid_argument(
|
||||
"Invalide state for __setstate__. Expected a tuple of length 2 with a string and a list as elements.");
|
||||
|
||||
return array_from_list(
|
||||
t[1], dtype_from_array_protocol(t[0].cast<std::string>()));
|
||||
}))
|
||||
.def("__copy__", [](const array& self) { return array(self); })
|
||||
.def(
|
||||
"__deepcopy__",
|
||||
[](const array& self, py::dict) { return array(self); },
|
||||
"memo"_a)
|
||||
.def(
|
||||
"__add__",
|
||||
[](const array& a, const ScalarOrArray v) {
|
||||
|
@ -649,6 +649,58 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y.tolist(), [3.0, 4.0])
|
||||
self.assertEqual(z.tolist(), [5.0, 6.0])
|
||||
|
||||
def test_array_pickle(self):
|
||||
dtypes = [
|
||||
mx.int8,
|
||||
mx.int16,
|
||||
mx.int32,
|
||||
mx.int64,
|
||||
mx.uint8,
|
||||
mx.uint16,
|
||||
mx.uint32,
|
||||
mx.uint64,
|
||||
# mx.float16,
|
||||
mx.float32,
|
||||
# mx.bfloat16,
|
||||
mx.complex64,
|
||||
]
|
||||
import pickle
|
||||
|
||||
for dtype in dtypes:
|
||||
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
|
||||
state = pickle.dumps(x)
|
||||
y = pickle.loads(state)
|
||||
self.assertEqualArray(x, y)
|
||||
self.assertEqual(x.dtype, y.dtype)
|
||||
|
||||
def test_array_copy(self):
|
||||
dtypes = [
|
||||
mx.int8,
|
||||
mx.int16,
|
||||
mx.int32,
|
||||
mx.int64,
|
||||
mx.uint8,
|
||||
mx.uint16,
|
||||
mx.uint32,
|
||||
mx.uint64,
|
||||
# mx.float16,
|
||||
mx.float32,
|
||||
# mx.bfloat16,
|
||||
mx.complex64,
|
||||
]
|
||||
|
||||
from copy import copy, deepcopy
|
||||
|
||||
for copy_function in [copy, deepcopy]:
|
||||
for dtype in dtypes:
|
||||
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
|
||||
y = copy_function(x)
|
||||
self.assertEqualArray(x, y)
|
||||
|
||||
y -= 1
|
||||
print(x, y)
|
||||
self.assertEqualArray(x - 1, y)
|
||||
|
||||
def test_indexing(self):
|
||||
# Basic content check, slice indexing
|
||||
a_npy = np.arange(64, dtype=np.float32)
|
||||
|
Loading…
Reference in New Issue
Block a user