mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Switch to nanobind (#839)
* mostly builds * most tests pass * fix circle build * add back buffer protocol * includes * fix for py38 * limit to cpu device * include * fix stubs * move signatures for docs * stubgen + docs fix * doc for compiled function, comments
This commit is contained in:
@@ -308,9 +308,9 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y.dtype, mx.bool_)
|
||||
self.assertEqual(y.item(), True)
|
||||
|
||||
# y = mx.array(x, mx.complex64)
|
||||
# self.assertEqual(y.dtype, mx.complex64)
|
||||
# self.assertEqual(y.item(), 3.0+0j)
|
||||
y = mx.array(x, mx.complex64)
|
||||
self.assertEqual(y.dtype, mx.complex64)
|
||||
self.assertEqual(y.item(), 3.0 + 0j)
|
||||
|
||||
def test_array_repr(self):
|
||||
x = mx.array(True)
|
||||
@@ -682,7 +682,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
|
||||
# check if it throws an error when dtype is not supported (bfloat16)
|
||||
x = mx.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=mx.bfloat16)
|
||||
with self.assertRaises(RuntimeError):
|
||||
with self.assertRaises(TypeError):
|
||||
pickle.dumps(x)
|
||||
|
||||
def test_array_copy(self):
|
||||
@@ -711,6 +711,11 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
self.assertEqualArray(y, x - 1)
|
||||
|
||||
def test_indexing(self):
|
||||
# Only ellipsis is a no-op
|
||||
a_mlx = mx.array([1])[...]
|
||||
self.assertEqual(a_mlx.shape, (1,))
|
||||
self.assertEqual(a_mlx.item(), 1)
|
||||
|
||||
# Basic content check, slice indexing
|
||||
a_npy = np.arange(64, dtype=np.float32)
|
||||
a_mlx = mx.array(a_npy)
|
||||
@@ -1360,7 +1365,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
for mlx_dtype, tf_dtype, np_dtype in dtypes_list:
|
||||
a_np = np.random.uniform(low=0, high=100, size=(3, 4)).astype(np_dtype)
|
||||
a_tf = tf.constant(a_np, dtype=tf_dtype)
|
||||
a_mx = mx.array(a_tf)
|
||||
a_mx = mx.array(np.array(a_tf))
|
||||
for f in [
|
||||
lambda x: x,
|
||||
lambda x: tf.transpose(x) if isinstance(x, tf.Tensor) else x.T,
|
||||
|
Reference in New Issue
Block a user