Make shape a tuple (#591)

* shape tuple

* also remove simplify from docs

* rebase
This commit is contained in:
Awni Hannun
2024-01-30 13:11:01 -08:00
committed by GitHub
parent d3a9005454
commit 09b9275027
13 changed files with 141 additions and 140 deletions

View File

@@ -94,7 +94,7 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqual(x.ndim, 0)
self.assertEqual(x.itemsize, 4)
self.assertEqual(x.nbytes, 4)
self.assertEqual(x.shape, [])
self.assertEqual(x.shape, ())
self.assertEqual(x.dtype, mx.int32)
self.assertEqual(x.item(), 1)
self.assertTrue(isinstance(x.item(), int))
@@ -116,7 +116,7 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array(1.0)
self.assertEqual(x.size, 1)
self.assertEqual(x.ndim, 0)
self.assertEqual(x.shape, [])
self.assertEqual(x.shape, ())
self.assertEqual(x.dtype, mx.float32)
self.assertEqual(x.item(), 1.0)
self.assertTrue(isinstance(x.item(), float))
@@ -124,14 +124,14 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array(False)
self.assertEqual(x.size, 1)
self.assertEqual(x.ndim, 0)
self.assertEqual(x.shape, [])
self.assertEqual(x.shape, ())
self.assertEqual(x.dtype, mx.bool_)
self.assertEqual(x.item(), False)
self.assertTrue(isinstance(x.item(), bool))
x = mx.array(complex(1, 1))
self.assertEqual(x.ndim, 0)
self.assertEqual(x.shape, [])
self.assertEqual(x.shape, ())
self.assertEqual(x.dtype, mx.complex64)
self.assertEqual(x.item(), complex(1, 1))
self.assertTrue(isinstance(x.item(), complex))
@@ -139,7 +139,7 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array([True, False, True])
self.assertEqual(x.dtype, mx.bool_)
self.assertEqual(x.ndim, 1)
self.assertEqual(x.shape, [3])
self.assertEqual(x.shape, (3,))
self.assertEqual(len(x), 3)
x = mx.array([True, False, True], mx.float32)
@@ -148,7 +148,7 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array([0, 1, 2])
self.assertEqual(x.dtype, mx.int32)
self.assertEqual(x.ndim, 1)
self.assertEqual(x.shape, [3])
self.assertEqual(x.shape, (3,))
x = mx.array([0, 1, 2], mx.float32)
self.assertEqual(x.dtype, mx.float32)
@@ -156,12 +156,12 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array([0.0, 1.0, 2.0])
self.assertEqual(x.dtype, mx.float32)
self.assertEqual(x.ndim, 1)
self.assertEqual(x.shape, [3])
self.assertEqual(x.shape, (3,))
x = mx.array([1j, 1 + 0j])
self.assertEqual(x.dtype, mx.complex64)
self.assertEqual(x.ndim, 1)
self.assertEqual(x.shape, [2])
self.assertEqual(x.shape, (2,))
# From tuple
x = mx.array((1, 2, 3), mx.int32)
@@ -181,17 +181,17 @@ class TestArray(mlx_tests.MLXTestCase):
def test_construction_from_lists(self):
x = mx.array([])
self.assertEqual(x.size, 0)
self.assertEqual(x.shape, [0])
self.assertEqual(x.shape, (0,))
self.assertEqual(x.dtype, mx.float32)
x = mx.array([[], [], []])
self.assertEqual(x.size, 0)
self.assertEqual(x.shape, [3, 0])
self.assertEqual(x.shape, (3, 0))
self.assertEqual(x.dtype, mx.float32)
x = mx.array([[[], []], [[], []], [[], []]])
self.assertEqual(x.size, 0)
self.assertEqual(x.shape, [3, 2, 0])
self.assertEqual(x.shape, (3, 2, 0))
self.assertEqual(x.dtype, mx.float32)
# Check failure cases
@@ -436,19 +436,19 @@ class TestArray(mlx_tests.MLXTestCase):
a = np.array([])
x = mx.array(a)
self.assertEqual(x.size, 0)
self.assertEqual(x.shape, [0])
self.assertEqual(x.shape, (0,))
self.assertEqual(x.dtype, mx.float32)
a = np.array([[], [], []])
x = mx.array(a)
self.assertEqual(x.size, 0)
self.assertEqual(x.shape, [3, 0])
self.assertEqual(x.shape, (3, 0))
self.assertEqual(x.dtype, mx.float32)
a = np.array([[[], []], [[], []], [[], []]])
x = mx.array(a)
self.assertEqual(x.size, 0)
self.assertEqual(x.shape, [3, 2, 0])
self.assertEqual(x.shape, (3, 2, 0))
self.assertEqual(x.dtype, mx.float32)
# Content test
@@ -456,7 +456,7 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array(a)
self.assertEqual(x.dtype, mx.float32)
self.assertEqual(x.ndim, 3)
self.assertEqual(x.shape, [3, 5, 4])
self.assertEqual(x.shape, (3, 5, 4))
y = np.asarray(x)
self.assertTrue(np.allclose(a, y))
@@ -465,7 +465,7 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array(a)
self.assertEqual(x.dtype, mx.int32)
self.assertEqual(x.ndim, 0)
self.assertEqual(x.shape, [])
self.assertEqual(x.shape, ())
self.assertEqual(x.item(), 3)
# mlx to numpy test
@@ -483,7 +483,7 @@ class TestArray(mlx_tests.MLXTestCase):
x = np.array(cvals)
y = mx.array(x)
self.assertEqual(y.dtype, mx.complex64)
self.assertEqual(y.shape, [3])
self.assertEqual(y.shape, (3,))
self.assertEqual(y.tolist(), cvals)
y = mx.array([0j, 1, 1 + 1j])