mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-27 08:18:30 +08:00
Support pickling array for bfloat16 (#2586)
* add bfloat16 pickling * Improvements * improve --------- Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
This commit is contained in:
@@ -532,7 +532,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(str(x), expected)
|
||||
|
||||
x = mx.array([[1, 2], [1, 2], [1, 2]])
|
||||
expected = "array([[1, 2],\n" " [1, 2],\n" " [1, 2]], dtype=int32)"
|
||||
expected = "array([[1, 2],\n [1, 2],\n [1, 2]], dtype=int32)"
|
||||
self.assertEqual(str(x), expected)
|
||||
|
||||
x = mx.array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]])
|
||||
@@ -886,6 +886,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
mx.uint64,
|
||||
mx.float16,
|
||||
mx.float32,
|
||||
mx.bfloat16,
|
||||
mx.complex64,
|
||||
]
|
||||
|
||||
@@ -895,11 +896,6 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
y = pickle.loads(state)
|
||||
self.assertEqualArray(y, x)
|
||||
|
||||
# check if it throws an error when dtype is not supported (bfloat16)
|
||||
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=mx.bfloat16)
|
||||
with self.assertRaises(TypeError):
|
||||
pickle.dumps(x)
|
||||
|
||||
def test_array_copy(self):
|
||||
dtypes = [
|
||||
mx.int8,
|
||||
|
Reference in New Issue
Block a user