Switch to nanobind (#839)

* mostly builds

* most tests pass

* fix circle build

* add back buffer protocol

* includes

* fix for py38

* limit to cpu device

* include

* fix stubs

* move signatures for docs

* stubgen + docs fix

* doc for compiled function, comments
This commit is contained in:
Awni Hannun
2024-03-18 20:12:25 -07:00
committed by GitHub
parent d39ed54f8e
commit 9a8ee00246
34 changed files with 2343 additions and 2344 deletions

View File

@@ -308,9 +308,9 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqual(y.dtype, mx.bool_)
self.assertEqual(y.item(), True)
# y = mx.array(x, mx.complex64)
# self.assertEqual(y.dtype, mx.complex64)
# self.assertEqual(y.item(), 3.0+0j)
y = mx.array(x, mx.complex64)
self.assertEqual(y.dtype, mx.complex64)
self.assertEqual(y.item(), 3.0 + 0j)
def test_array_repr(self):
x = mx.array(True)
@@ -682,7 +682,7 @@ class TestArray(mlx_tests.MLXTestCase):
# check if it throws an error when dtype is not supported (bfloat16)
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=mx.bfloat16)
with self.assertRaises(RuntimeError):
with self.assertRaises(TypeError):
pickle.dumps(x)
def test_array_copy(self):
@@ -711,6 +711,11 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqualArray(y, x - 1)
def test_indexing(self):
# Only ellipsis is a no-op
a_mlx = mx.array([1])[...]
self.assertEqual(a_mlx.shape, (1,))
self.assertEqual(a_mlx.item(), 1)
# Basic content check, slice indexing
a_npy = np.arange(64, dtype=np.float32)
a_mlx = mx.array(a_npy)
@@ -1360,7 +1365,7 @@ class TestArray(mlx_tests.MLXTestCase):
for mlx_dtype, tf_dtype, np_dtype in dtypes_list:
a_np = np.random.uniform(low=0, high=100, size=(3, 4)).astype(np_dtype)
a_tf = tf.constant(a_np, dtype=tf_dtype)
a_mx = mx.array(a_tf)
a_mx = mx.array(np.array(a_tf))
for f in [
lambda x: x,
lambda x: tf.transpose(x) if isinstance(x, tf.Tensor) else x.T,