Support pickling array for bfloat16 (#2586)

* add bfloat16 pickling

* Improvements

* improve

---------

Co-authored-by: Chen-Chen Yeh <ge96noj@mytum.de>
This commit is contained in:
Daniel Yeh
2025-09-23 05:12:15 +02:00
committed by GitHub
parent bf01ad9367
commit fbbf3b9b3e
4 changed files with 36 additions and 12 deletions

View File

@@ -532,7 +532,7 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqual(str(x), expected)
x = mx.array([[1, 2], [1, 2], [1, 2]])
expected = "array([[1, 2],\n" " [1, 2],\n" " [1, 2]], dtype=int32)"
expected = "array([[1, 2],\n [1, 2],\n [1, 2]], dtype=int32)"
self.assertEqual(str(x), expected)
x = mx.array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]])
@@ -886,6 +886,7 @@ class TestArray(mlx_tests.MLXTestCase):
mx.uint64,
mx.float16,
mx.float32,
mx.bfloat16,
mx.complex64,
]
@@ -895,11 +896,6 @@ class TestArray(mlx_tests.MLXTestCase):
y = pickle.loads(state)
self.assertEqualArray(y, x)
# check if it throws an error when dtype is not supported (bfloat16)
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=mx.bfloat16)
with self.assertRaises(TypeError):
pickle.dumps(x)
def test_array_copy(self):
dtypes = [
mx.int8,