mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
* Implemented pickling and copy for Python arrays(#300 & #367)
* Fixing typos
* Pickle with NumPy arrays
* Pickle: workaround for bfloat16
* Revert "Pickle: workaround for bfloat16"
This reverts commit 25afe6bc09
.
* Added an error when pickling bfloat16
* Update python/tests/test_array.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/tests/test_array.py
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/array.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* Update python/src/array.cpp
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* clang-format applied
---------
Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import operator
|
||||
import pickle
|
||||
import unittest
|
||||
import weakref
|
||||
from copy import copy, deepcopy
|
||||
from itertools import permutations
|
||||
|
||||
import mlx.core as mx
|
||||
@@ -658,6 +659,57 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y.tolist(), [3.0, 4.0])
|
||||
self.assertEqual(z.tolist(), [5.0, 6.0])
|
||||
|
||||
def test_array_pickle(self):
|
||||
dtypes = [
|
||||
mx.int8,
|
||||
mx.int16,
|
||||
mx.int32,
|
||||
mx.int64,
|
||||
mx.uint8,
|
||||
mx.uint16,
|
||||
mx.uint32,
|
||||
mx.uint64,
|
||||
mx.float16,
|
||||
mx.float32,
|
||||
mx.complex64,
|
||||
]
|
||||
|
||||
for dtype in dtypes:
|
||||
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
|
||||
state = pickle.dumps(x)
|
||||
y = pickle.loads(state)
|
||||
self.assertEqualArray(y, x)
|
||||
|
||||
# check if it throws an error when dtype is not supported (bfloat16)
|
||||
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=mx.bfloat16)
|
||||
with self.assertRaises(RuntimeError):
|
||||
pickle.dumps(x)
|
||||
|
||||
def test_array_copy(self):
|
||||
dtypes = [
|
||||
mx.int8,
|
||||
mx.int16,
|
||||
mx.int32,
|
||||
mx.int64,
|
||||
mx.uint8,
|
||||
mx.uint16,
|
||||
mx.uint32,
|
||||
mx.uint64,
|
||||
mx.float16,
|
||||
mx.float32,
|
||||
mx.bfloat16,
|
||||
mx.complex64,
|
||||
]
|
||||
|
||||
for copy_function in [copy, deepcopy]:
|
||||
for dtype in dtypes:
|
||||
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtype)
|
||||
y = copy_function(x)
|
||||
self.assertEqualArray(y, x)
|
||||
|
||||
y -= 1
|
||||
self.assertEqualArray(y, x - 1)
|
||||
|
||||
def test_indexing(self):
|
||||
# Basic content check, slice indexing
|
||||
a_npy = np.arange(64, dtype=np.float32)
|
||||
|
Reference in New Issue
Block a user