Implemented pickling and copy for Python arrays(#300 & #367)

This commit is contained in:
Luca Arnaboldi
2024-02-20 09:25:07 +01:00
parent 146bd69470
commit 8ba3625a40
3 changed files with 74 additions and 0 deletions

View File

@@ -649,6 +649,58 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqual(y.tolist(), [3.0, 4.0])
self.assertEqual(z.tolist(), [5.0, 6.0])
def test_array_pickle(self):
dtypes = [
mx.int8,
mx.int16,
mx.int32,
mx.int64,
mx.uint8,
mx.uint16,
mx.uint32,
mx.uint64,
# mx.float16,
mx.float32,
# mx.bfloat16,
mx.complex64,
]
import pickle
for dtype in dtypes:
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
state = pickle.dumps(x)
y = pickle.loads(state)
self.assertEqualArray(x, y)
self.assertEqual(x.dtype, y.dtype)
def test_array_copy(self):
dtypes = [
mx.int8,
mx.int16,
mx.int32,
mx.int64,
mx.uint8,
mx.uint16,
mx.uint32,
mx.uint64,
# mx.float16,
mx.float32,
# mx.bfloat16,
mx.complex64,
]
from copy import copy, deepcopy
for copy_function in [copy, deepcopy]:
for dtype in dtypes:
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
y = copy_function(x)
self.assertEqualArray(x, y)
y -= 1
print(x, y)
self.assertEqualArray(x - 1, y)
def test_indexing(self):
# Basic content check, slice indexing
a_npy = np.arange(64, dtype=np.float32)