Pickle with NumPy arrays

This commit is contained in:
Luca Arnaboldi
2024-03-04 13:00:18 +01:00
parent c02602a4a1
commit 6b6b4f0a5f
2 changed files with 28 additions and 34 deletions

View File

@@ -1,9 +1,10 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2024 Apple Inc.
import operator
import pickle
import unittest
import weakref
from copy import copy, deepcopy
from itertools import permutations
import mlx.core as mx
@@ -670,17 +671,15 @@ class TestArray(mlx_tests.MLXTestCase):
mx.uint64,
mx.float16,
mx.float32,
mx.bfloat16,
# mx.bfloat16,
mx.complex64,
]
import pickle
for dtype in dtypes:
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
state = pickle.dumps(x)
y = pickle.loads(state)
self.assertEqualArray(x, y)
self.assertEqual(x.dtype, y.dtype)
self.assertEqualArray(y, x)
def test_array_copy(self):
dtypes = [
@@ -698,16 +697,14 @@ class TestArray(mlx_tests.MLXTestCase):
mx.complex64,
]
from copy import copy, deepcopy
for copy_function in [copy, deepcopy]:
for dtype in dtypes:
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
y = copy_function(x)
self.assertEqualArray(x, y)
self.assertEqualArray(y, x)
y -= 1
self.assertEqualArray(x - 1, y)
self.assertEqualArray(y, x - 1)
def test_indexing(self):
# Basic content check, slice indexing