mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-18 15:11:14 +08:00
parent
146bd69470
commit
8ba3625a40
@ -13,6 +13,9 @@ MLX was developed with contributions from the following individuals:
|
|||||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support
|
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile` and safetensor support
|
||||||
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
|
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer.
|
||||||
|
|
||||||
|
- Luca Arnaboldi: Added `Ceil` and `Floor` ops.
|
||||||
|
Implemented pickling, copy and deepcopy for Python arrays.
|
||||||
|
|
||||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||||
</a>
|
</a>
|
||||||
|
@ -774,6 +774,25 @@ void init_array(py::module_& m) {
|
|||||||
return a.shape(0);
|
return a.shape(0);
|
||||||
})
|
})
|
||||||
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
|
.def("__iter__", [](const array& a) { return ArrayPythonIterator(a); })
|
||||||
|
.def(py::pickle(
|
||||||
|
[](array& a) { // __getstate__
|
||||||
|
return py::make_tuple(
|
||||||
|
dtype_to_array_protocol(a.dtype()), tolist(a));
|
||||||
|
},
|
||||||
|
[](py::tuple t) { // __setstate__
|
||||||
|
if (t.size() != 2 or !py::isinstance<py::str>(t[0]) or
|
||||||
|
!py::isinstance<py::list>(t[1]))
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Invalide state for __setstate__. Expected a tuple of length 2 with a string and a list as elements.");
|
||||||
|
|
||||||
|
return array_from_list(
|
||||||
|
t[1], dtype_from_array_protocol(t[0].cast<std::string>()));
|
||||||
|
}))
|
||||||
|
.def("__copy__", [](const array& self) { return array(self); })
|
||||||
|
.def(
|
||||||
|
"__deepcopy__",
|
||||||
|
[](const array& self, py::dict) { return array(self); },
|
||||||
|
"memo"_a)
|
||||||
.def(
|
.def(
|
||||||
"__add__",
|
"__add__",
|
||||||
[](const array& a, const ScalarOrArray v) {
|
[](const array& a, const ScalarOrArray v) {
|
||||||
|
@ -649,6 +649,58 @@ class TestArray(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(y.tolist(), [3.0, 4.0])
|
self.assertEqual(y.tolist(), [3.0, 4.0])
|
||||||
self.assertEqual(z.tolist(), [5.0, 6.0])
|
self.assertEqual(z.tolist(), [5.0, 6.0])
|
||||||
|
|
||||||
|
def test_array_pickle(self):
|
||||||
|
dtypes = [
|
||||||
|
mx.int8,
|
||||||
|
mx.int16,
|
||||||
|
mx.int32,
|
||||||
|
mx.int64,
|
||||||
|
mx.uint8,
|
||||||
|
mx.uint16,
|
||||||
|
mx.uint32,
|
||||||
|
mx.uint64,
|
||||||
|
# mx.float16,
|
||||||
|
mx.float32,
|
||||||
|
# mx.bfloat16,
|
||||||
|
mx.complex64,
|
||||||
|
]
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
for dtype in dtypes:
|
||||||
|
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
|
||||||
|
state = pickle.dumps(x)
|
||||||
|
y = pickle.loads(state)
|
||||||
|
self.assertEqualArray(x, y)
|
||||||
|
self.assertEqual(x.dtype, y.dtype)
|
||||||
|
|
||||||
|
def test_array_copy(self):
|
||||||
|
dtypes = [
|
||||||
|
mx.int8,
|
||||||
|
mx.int16,
|
||||||
|
mx.int32,
|
||||||
|
mx.int64,
|
||||||
|
mx.uint8,
|
||||||
|
mx.uint16,
|
||||||
|
mx.uint32,
|
||||||
|
mx.uint64,
|
||||||
|
# mx.float16,
|
||||||
|
mx.float32,
|
||||||
|
# mx.bfloat16,
|
||||||
|
mx.complex64,
|
||||||
|
]
|
||||||
|
|
||||||
|
from copy import copy, deepcopy
|
||||||
|
|
||||||
|
for copy_function in [copy, deepcopy]:
|
||||||
|
for dtype in dtypes:
|
||||||
|
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
|
||||||
|
y = copy_function(x)
|
||||||
|
self.assertEqualArray(x, y)
|
||||||
|
|
||||||
|
y -= 1
|
||||||
|
print(x, y)
|
||||||
|
self.assertEqualArray(x - 1, y)
|
||||||
|
|
||||||
def test_indexing(self):
|
def test_indexing(self):
|
||||||
# Basic content check, slice indexing
|
# Basic content check, slice indexing
|
||||||
a_npy = np.arange(64, dtype=np.float32)
|
a_npy = np.arange(64, dtype=np.float32)
|
||||||
|
Loading…
Reference in New Issue
Block a user