mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:29:35 +08:00
Make shape a tuple (#591)
* shape tuple * also remove simplify from docs * rebase
This commit is contained in:
@@ -94,7 +94,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(x.ndim, 0)
|
||||
self.assertEqual(x.itemsize, 4)
|
||||
self.assertEqual(x.nbytes, 4)
|
||||
self.assertEqual(x.shape, [])
|
||||
self.assertEqual(x.shape, ())
|
||||
self.assertEqual(x.dtype, mx.int32)
|
||||
self.assertEqual(x.item(), 1)
|
||||
self.assertTrue(isinstance(x.item(), int))
|
||||
@@ -116,7 +116,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = mx.array(1.0)
|
||||
self.assertEqual(x.size, 1)
|
||||
self.assertEqual(x.ndim, 0)
|
||||
self.assertEqual(x.shape, [])
|
||||
self.assertEqual(x.shape, ())
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
self.assertEqual(x.item(), 1.0)
|
||||
self.assertTrue(isinstance(x.item(), float))
|
||||
@@ -124,14 +124,14 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = mx.array(False)
|
||||
self.assertEqual(x.size, 1)
|
||||
self.assertEqual(x.ndim, 0)
|
||||
self.assertEqual(x.shape, [])
|
||||
self.assertEqual(x.shape, ())
|
||||
self.assertEqual(x.dtype, mx.bool_)
|
||||
self.assertEqual(x.item(), False)
|
||||
self.assertTrue(isinstance(x.item(), bool))
|
||||
|
||||
x = mx.array(complex(1, 1))
|
||||
self.assertEqual(x.ndim, 0)
|
||||
self.assertEqual(x.shape, [])
|
||||
self.assertEqual(x.shape, ())
|
||||
self.assertEqual(x.dtype, mx.complex64)
|
||||
self.assertEqual(x.item(), complex(1, 1))
|
||||
self.assertTrue(isinstance(x.item(), complex))
|
||||
@@ -139,7 +139,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = mx.array([True, False, True])
|
||||
self.assertEqual(x.dtype, mx.bool_)
|
||||
self.assertEqual(x.ndim, 1)
|
||||
self.assertEqual(x.shape, [3])
|
||||
self.assertEqual(x.shape, (3,))
|
||||
self.assertEqual(len(x), 3)
|
||||
|
||||
x = mx.array([True, False, True], mx.float32)
|
||||
@@ -148,7 +148,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = mx.array([0, 1, 2])
|
||||
self.assertEqual(x.dtype, mx.int32)
|
||||
self.assertEqual(x.ndim, 1)
|
||||
self.assertEqual(x.shape, [3])
|
||||
self.assertEqual(x.shape, (3,))
|
||||
|
||||
x = mx.array([0, 1, 2], mx.float32)
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
@@ -156,12 +156,12 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = mx.array([0.0, 1.0, 2.0])
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
self.assertEqual(x.ndim, 1)
|
||||
self.assertEqual(x.shape, [3])
|
||||
self.assertEqual(x.shape, (3,))
|
||||
|
||||
x = mx.array([1j, 1 + 0j])
|
||||
self.assertEqual(x.dtype, mx.complex64)
|
||||
self.assertEqual(x.ndim, 1)
|
||||
self.assertEqual(x.shape, [2])
|
||||
self.assertEqual(x.shape, (2,))
|
||||
|
||||
# From tuple
|
||||
x = mx.array((1, 2, 3), mx.int32)
|
||||
@@ -181,17 +181,17 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
def test_construction_from_lists(self):
|
||||
x = mx.array([])
|
||||
self.assertEqual(x.size, 0)
|
||||
self.assertEqual(x.shape, [0])
|
||||
self.assertEqual(x.shape, (0,))
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
|
||||
x = mx.array([[], [], []])
|
||||
self.assertEqual(x.size, 0)
|
||||
self.assertEqual(x.shape, [3, 0])
|
||||
self.assertEqual(x.shape, (3, 0))
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
|
||||
x = mx.array([[[], []], [[], []], [[], []]])
|
||||
self.assertEqual(x.size, 0)
|
||||
self.assertEqual(x.shape, [3, 2, 0])
|
||||
self.assertEqual(x.shape, (3, 2, 0))
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
|
||||
# Check failure cases
|
||||
@@ -436,19 +436,19 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
a = np.array([])
|
||||
x = mx.array(a)
|
||||
self.assertEqual(x.size, 0)
|
||||
self.assertEqual(x.shape, [0])
|
||||
self.assertEqual(x.shape, (0,))
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
|
||||
a = np.array([[], [], []])
|
||||
x = mx.array(a)
|
||||
self.assertEqual(x.size, 0)
|
||||
self.assertEqual(x.shape, [3, 0])
|
||||
self.assertEqual(x.shape, (3, 0))
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
|
||||
a = np.array([[[], []], [[], []], [[], []]])
|
||||
x = mx.array(a)
|
||||
self.assertEqual(x.size, 0)
|
||||
self.assertEqual(x.shape, [3, 2, 0])
|
||||
self.assertEqual(x.shape, (3, 2, 0))
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
|
||||
# Content test
|
||||
@@ -456,7 +456,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = mx.array(a)
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
self.assertEqual(x.ndim, 3)
|
||||
self.assertEqual(x.shape, [3, 5, 4])
|
||||
self.assertEqual(x.shape, (3, 5, 4))
|
||||
|
||||
y = np.asarray(x)
|
||||
self.assertTrue(np.allclose(a, y))
|
||||
@@ -465,7 +465,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = mx.array(a)
|
||||
self.assertEqual(x.dtype, mx.int32)
|
||||
self.assertEqual(x.ndim, 0)
|
||||
self.assertEqual(x.shape, [])
|
||||
self.assertEqual(x.shape, ())
|
||||
self.assertEqual(x.item(), 3)
|
||||
|
||||
# mlx to numpy test
|
||||
@@ -483,7 +483,7 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
x = np.array(cvals)
|
||||
y = mx.array(x)
|
||||
self.assertEqual(y.dtype, mx.complex64)
|
||||
self.assertEqual(y.shape, [3])
|
||||
self.assertEqual(y.shape, (3,))
|
||||
self.assertEqual(y.tolist(), cvals)
|
||||
|
||||
y = mx.array([0j, 1, 1 + 1j])
|
||||
|
Reference in New Issue
Block a user