Implementation of pickle, copy and deepcopy for Python arrays (#300 & #367). (#713)

* Implemented pickling and copy for Python arrays(#300 & #367)

* Fixing typos

* Pickle with NumPy arrays

* Pickle: workaround for bfloat16

* Revert "Pickle: workaround for bfloat16"

This reverts commit 25afe6bc09.

* Added an error when pickling bfloat16

* Update python/tests/test_array.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/tests/test_array.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/array.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* Update python/src/array.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* clang-format applied

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Luca Arnaboldi
2024-03-06 17:02:41 +01:00
committed by GitHub
parent e39bebe13e
commit cbefd9129e
3 changed files with 90 additions and 16 deletions

View File

@@ -1,9 +1,10 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import operator
import pickle
import unittest
import weakref
from copy import copy, deepcopy
from itertools import permutations
import mlx.core as mx
@@ -658,6 +659,57 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqual(y.tolist(), [3.0, 4.0])
self.assertEqual(z.tolist(), [5.0, 6.0])
def test_array_pickle(self):
dtypes = [
mx.int8,
mx.int16,
mx.int32,
mx.int64,
mx.uint8,
mx.uint16,
mx.uint32,
mx.uint64,
mx.float16,
mx.float32,
mx.complex64,
]
for dtype in dtypes:
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
state = pickle.dumps(x)
y = pickle.loads(state)
self.assertEqualArray(y, x)
# check if it throws an error when dtype is not supported (bfloat16)
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=mx.bfloat16)
with self.assertRaises(RuntimeError):
pickle.dumps(x)
def test_array_copy(self):
dtypes = [
mx.int8,
mx.int16,
mx.int32,
mx.int64,
mx.uint8,
mx.uint16,
mx.uint32,
mx.uint64,
mx.float16,
mx.float32,
mx.bfloat16,
mx.complex64,
]
for copy_function in [copy, deepcopy]:
for dtype in dtypes:
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
y = copy_function(x)
self.assertEqualArray(y, x)
y -= 1
self.assertEqualArray(y, x - 1)
def test_indexing(self):
# Basic content check, slice indexing
a_npy = np.arange(64, dtype=np.float32)