mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-06 00:20:45 +08:00
Make array conform to the Python Buffer Protocol (#323)
This commit is contained in:
@@ -2,12 +2,20 @@
|
||||
|
||||
import operator
|
||||
import unittest
|
||||
import weakref
|
||||
from itertools import permutations
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
|
||||
has_tf = True
|
||||
except ImportError as e:
|
||||
has_tf = False
|
||||
|
||||
|
||||
class TestVersion(mlx_tests.MLXTestCase):
|
||||
def test_version(self):
|
||||
@@ -1100,7 +1108,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
|
||||
# Check that we get read-only array that does not own the underlying data
|
||||
self.assertFalse(a_np.flags.owndata)
|
||||
self.assertFalse(a_np.flags.writeable)
|
||||
self.assertTrue(a_np.flags.writeable)
|
||||
|
||||
# Check contents
|
||||
self.assertTrue(np.array_equal(np.ones((2, 2), dtype=np.float32), a_np))
|
||||
@@ -1109,6 +1117,157 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
# Check strides
|
||||
self.assertSequenceEqual(b_np.strides, (0, 8, 4))
|
||||
|
||||
def test_np_array_conversion_copies_by_default(self):
|
||||
a_mx = mx.ones((2, 2))
|
||||
a_np = np.array(a_mx)
|
||||
self.assertTrue(a_np.flags.owndata)
|
||||
self.assertTrue(a_np.flags.writeable)
|
||||
|
||||
def test_buffer_protocol(self):
|
||||
dtypes_list = [
|
||||
(mx.bool_, np.bool_, None),
|
||||
(mx.uint8, np.uint8, np.iinfo),
|
||||
(mx.uint16, np.uint16, np.iinfo),
|
||||
(mx.uint32, np.uint32, np.iinfo),
|
||||
(mx.uint64, np.uint64, np.iinfo),
|
||||
(mx.int8, np.int8, np.iinfo),
|
||||
(mx.int16, np.int16, np.iinfo),
|
||||
(mx.int32, np.int32, np.iinfo),
|
||||
(mx.int64, np.int64, np.iinfo),
|
||||
(mx.float16, np.float16, np.finfo),
|
||||
(mx.float32, np.float32, np.finfo),
|
||||
(mx.complex64, np.complex64, np.finfo),
|
||||
]
|
||||
|
||||
for mlx_dtype, np_dtype, info_fn in dtypes_list:
|
||||
a_np = np.random.uniform(low=0, high=100, size=(3, 4)).astype(np_dtype)
|
||||
if info_fn is not None:
|
||||
info = info_fn(np_dtype)
|
||||
a_np[0, 0] = info.min
|
||||
a_np[0, 1] = info.max
|
||||
a_mx = mx.array(a_np)
|
||||
for f in [lambda x: x, lambda x: x.T]:
|
||||
mv_mx = memoryview(f(a_mx))
|
||||
mv_np = memoryview(f(a_np))
|
||||
self.assertEqual(mv_mx.strides, mv_np.strides, f"{mlx_dtype}{np_dtype}")
|
||||
self.assertEqual(mv_mx.shape, mv_np.shape, f"{mlx_dtype}{np_dtype}")
|
||||
# correct buffer format for 8 byte (unsigned) 'long long' is Q/q, see
|
||||
# https://docs.python.org/3.10/library/struct.html#format-characters
|
||||
# numpy returns L/l, as 'long' is equivalent to 'long long' on 64bit machines, so q and l are equivalent
|
||||
# see https://github.com/pybind/pybind11/issues/1908
|
||||
if np_dtype == np.uint64:
|
||||
self.assertEqual(mv_mx.format, "Q", f"{mlx_dtype}{np_dtype}")
|
||||
elif np_dtype == np.int64:
|
||||
self.assertEqual(mv_mx.format, "q", f"{mlx_dtype}{np_dtype}")
|
||||
else:
|
||||
self.assertEqual(
|
||||
mv_mx.format, mv_np.format, f"{mlx_dtype}{np_dtype}"
|
||||
)
|
||||
self.assertFalse(mv_mx.readonly)
|
||||
back_to_npy = np.array(mv_mx, copy=False)
|
||||
self.assertEqualArray(
|
||||
back_to_npy,
|
||||
f(a_np),
|
||||
atol=0,
|
||||
rtol=0,
|
||||
msg=f"{mlx_dtype}{np_dtype}",
|
||||
)
|
||||
|
||||
# extra test for bfloat16, which is not numpy convertible
|
||||
a_mx = mx.random.uniform(low=0, high=100, shape=(3, 4), dtype=mx.bfloat16)
|
||||
mv_mx = memoryview(a_mx)
|
||||
self.assertEqual(mv_mx.strides, (8, 2))
|
||||
self.assertEqual(mv_mx.shape, (3, 4))
|
||||
self.assertEqual(mv_mx.format, "")
|
||||
with self.assertRaises(RuntimeError) as cm:
|
||||
np.array(a_mx)
|
||||
e = cm.exception
|
||||
self.assertTrue("Item size 2 for PEP 3118 buffer format string" in str(e))
|
||||
|
||||
def test_buffer_protocol_ref_counting(self):
|
||||
a = mx.arange(3)
|
||||
wr = weakref.ref(a)
|
||||
self.assertIsNotNone(wr())
|
||||
mv = memoryview(a)
|
||||
a = None
|
||||
self.assertIsNotNone(wr())
|
||||
mv = None
|
||||
self.assertIsNone(wr())
|
||||
|
||||
def test_array_view_ref_counting(self):
|
||||
a = mx.arange(3)
|
||||
wr = weakref.ref(a)
|
||||
self.assertIsNotNone(wr())
|
||||
a_np = np.array(a, copy=False)
|
||||
a = None
|
||||
self.assertIsNotNone(wr())
|
||||
a_np = None
|
||||
self.assertIsNone(wr())
|
||||
|
||||
@unittest.skipIf(not has_tf, "requires TensorFlow")
|
||||
def test_buffer_protocol_tf(self):
|
||||
dtypes_list = [
|
||||
(
|
||||
mx.bool_,
|
||||
tf.bool,
|
||||
np.bool_,
|
||||
),
|
||||
(
|
||||
mx.uint8,
|
||||
tf.uint8,
|
||||
np.uint8,
|
||||
),
|
||||
(
|
||||
mx.uint16,
|
||||
tf.uint16,
|
||||
np.uint16,
|
||||
),
|
||||
(
|
||||
mx.uint32,
|
||||
tf.uint32,
|
||||
np.uint32,
|
||||
),
|
||||
(mx.uint64, tf.uint64, np.uint64),
|
||||
(mx.int8, tf.int8, np.int8),
|
||||
(mx.int16, tf.int16, np.int16),
|
||||
(mx.int32, tf.int32, np.int32),
|
||||
(mx.int64, tf.int64, np.int64),
|
||||
(mx.float16, tf.float16, np.float16),
|
||||
(mx.float32, tf.float32, np.float32),
|
||||
(
|
||||
mx.complex64,
|
||||
tf.complex64,
|
||||
np.complex64,
|
||||
),
|
||||
]
|
||||
|
||||
for mlx_dtype, tf_dtype, np_dtype in dtypes_list:
|
||||
a_np = np.random.uniform(low=0, high=100, size=(3, 4)).astype(np_dtype)
|
||||
a_tf = tf.constant(a_np, dtype=tf_dtype)
|
||||
a_mx = mx.array(a_tf)
|
||||
for f in [
|
||||
lambda x: x,
|
||||
lambda x: tf.transpose(x) if isinstance(x, tf.Tensor) else x.T,
|
||||
]:
|
||||
mv_mx = memoryview(f(a_mx))
|
||||
mv_tf = memoryview(f(a_tf))
|
||||
if (mv_mx.c_contiguous and mv_tf.c_contiguous) or (
|
||||
mv_mx.f_contiguous and mv_tf.f_contiguous
|
||||
):
|
||||
self.assertEqual(
|
||||
mv_mx.strides, mv_tf.strides, f"{mlx_dtype}{tf_dtype}"
|
||||
)
|
||||
self.assertEqual(mv_mx.shape, mv_tf.shape, f"{mlx_dtype}{tf_dtype}")
|
||||
self.assertFalse(mv_mx.readonly)
|
||||
back_to_npy = np.array(mv_mx)
|
||||
self.assertEqualArray(
|
||||
back_to_npy,
|
||||
f(a_tf),
|
||||
atol=0,
|
||||
rtol=0,
|
||||
msg=f"{mlx_dtype}{tf_dtype}",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user