Make shape a tuple (#591)

* shape tuple

* also remove simplify from docs

* rebase
This commit is contained in:
Awni Hannun
2024-01-30 13:11:01 -08:00
committed by GitHub
parent d3a9005454
commit 09b9275027
13 changed files with 141 additions and 140 deletions

View File

@@ -12,12 +12,12 @@ import numpy as np
class TestOps(mlx_tests.MLXTestCase):
def test_full_ones_zeros(self):
x = mx.full(2, 3.0)
self.assertEqual(x.shape, [2])
self.assertEqual(x.shape, (2,))
self.assertEqual(x.tolist(), [3.0, 3.0])
x = mx.full((2, 3), 2.0)
self.assertEqual(x.dtype, mx.float32)
self.assertEqual(x.shape, [2, 3])
self.assertEqual(x.shape, (2, 3))
self.assertEqual(x.tolist(), [[2, 2, 2], [2, 2, 2]])
x = mx.full([3, 2], mx.array([False, True]))
@@ -28,11 +28,11 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(x.tolist(), [[2, 3], [2, 3], [2, 3]])
x = mx.zeros(2)
self.assertEqual(x.shape, [2])
self.assertEqual(x.shape, (2,))
self.assertEqual(x.tolist(), [0.0, 0.0])
x = mx.ones(2)
self.assertEqual(x.shape, [2])
self.assertEqual(x.shape, (2,))
self.assertEqual(x.tolist(), [1.0, 1.0])
for t in [mx.bool_, mx.int32, mx.float32]:
@@ -530,10 +530,10 @@ class TestOps(mlx_tests.MLXTestCase):
def test_move_swap_axes(self):
x = mx.zeros((2, 3, 4))
self.assertEqual(mx.moveaxis(x, 0, 2).shape, [3, 4, 2])
self.assertEqual(x.moveaxis(0, 2).shape, [3, 4, 2])
self.assertEqual(mx.swapaxes(x, 0, 2).shape, [4, 3, 2])
self.assertEqual(x.swapaxes(0, 2).shape, [4, 3, 2])
self.assertEqual(mx.moveaxis(x, 0, 2).shape, (3, 4, 2))
self.assertEqual(x.moveaxis(0, 2).shape, (3, 4, 2))
self.assertEqual(mx.swapaxes(x, 0, 2).shape, (4, 3, 2))
self.assertEqual(x.swapaxes(0, 2).shape, (4, 3, 2))
def test_sum(self):
x = mx.array(
@@ -545,7 +545,7 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.sum(x).item(), 9)
y = mx.sum(x, keepdims=True)
self.assertEqual(y, mx.array(9))
self.assertEqual(y.shape, [1, 1])
self.assertEqual(y.shape, (1, 1))
self.assertEqual(mx.sum(x, axis=0).tolist(), [4, 5])
self.assertEqual(mx.sum(x, axis=1).tolist(), [3, 6])
@@ -585,7 +585,7 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.prod(x).item(), 18)
y = mx.prod(x, keepdims=True)
self.assertEqual(y, mx.array(18))
self.assertEqual(y.shape, [1, 1])
self.assertEqual(y.shape, (1, 1))
self.assertEqual(mx.prod(x, axis=0).tolist(), [3, 6])
self.assertEqual(mx.prod(x, axis=1).tolist(), [2, 9])
@@ -600,11 +600,11 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.min(x).item(), 1)
self.assertEqual(mx.max(x).item(), 4)
y = mx.min(x, keepdims=True)
self.assertEqual(y.shape, [1, 1])
self.assertEqual(y.shape, (1, 1))
self.assertEqual(y, mx.array(1))
y = mx.max(x, keepdims=True)
self.assertEqual(y.shape, [1, 1])
self.assertEqual(y.shape, (1, 1))
self.assertEqual(y, mx.array(4))
self.assertEqual(mx.min(x, axis=0).tolist(), [1, 2])
@@ -670,7 +670,7 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.mean(x).item(), 2.5)
y = mx.mean(x, keepdims=True)
self.assertEqual(y, mx.array(2.5))
self.assertEqual(y.shape, [1, 1])
self.assertEqual(y.shape, (1, 1))
self.assertEqual(mx.mean(x, axis=0).tolist(), [2, 3])
self.assertEqual(mx.mean(x, axis=1).tolist(), [1.5, 3.5])
@@ -685,7 +685,7 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.var(x).item(), 1.25)
y = mx.var(x, keepdims=True)
self.assertEqual(y, mx.array(1.25))
self.assertEqual(y.shape, [1, 1])
self.assertEqual(y.shape, (1, 1))
self.assertEqual(mx.var(x, axis=0).tolist(), [1.0, 1.0])
self.assertEqual(mx.var(x, axis=1).tolist(), [0.25, 0.25])
@@ -888,7 +888,7 @@ class TestOps(mlx_tests.MLXTestCase):
a = mx.array([[True, False], [True, True]])
self.assertFalse(mx.all(a).item())
self.assertEqual(mx.all(a, keepdims=True).shape, [1, 1])
self.assertEqual(mx.all(a, keepdims=True).shape, (1, 1))
self.assertFalse(mx.all(a, axis=[0, 1]).item())
self.assertEqual(mx.all(a, axis=[0]).tolist(), [True, False])
self.assertEqual(mx.all(a, axis=[1]).tolist(), [False, True])
@@ -899,7 +899,7 @@ class TestOps(mlx_tests.MLXTestCase):
a = mx.array([[True, False], [False, False]])
self.assertTrue(mx.any(a).item())
self.assertEqual(mx.any(a, keepdims=True).shape, [1, 1])
self.assertEqual(mx.any(a, keepdims=True).shape, (1, 1))
self.assertTrue(mx.any(a, axis=[0, 1]).item())
self.assertEqual(mx.any(a, axis=[0]).tolist(), [True, False])
self.assertEqual(mx.any(a, axis=[1]).tolist(), [True, False])
@@ -956,22 +956,22 @@ class TestOps(mlx_tests.MLXTestCase):
a_npy_taken = np.take(a_npy, idx_npy)
a_mlx_taken = mx.take(a_mlx, idx_mlx)
self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
a_npy_taken = np.take(a_npy, idx_npy, axis=0)
a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=0)
self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
a_npy_taken = np.take(a_npy, idx_npy, axis=1)
a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=1)
self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
a_npy_taken = np.take(a_npy, idx_npy, axis=2)
a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=2)
self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
def test_take_along_axis(self):
@@ -1400,13 +1400,13 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6))
a = mx.zeros((1, 1, 1))
self.assertEqual(mx.pad(a, 1).shape, [3, 3, 3])
self.assertEqual(mx.pad(a, (1,)).shape, [3, 3, 3])
self.assertEqual(mx.pad(a, [1]).shape, [3, 3, 3])
self.assertEqual(mx.pad(a, (1, 2)).shape, [4, 4, 4])
self.assertEqual(mx.pad(a, [(1, 2)]).shape, [4, 4, 4])
self.assertEqual(mx.pad(a, ((1, 2),)).shape, [4, 4, 4])
self.assertEqual(mx.pad(a, ((1, 2), (2, 1), (2, 2))).shape, [4, 4, 5])
self.assertEqual(mx.pad(a, 1).shape, (3, 3, 3))
self.assertEqual(mx.pad(a, (1,)).shape, (3, 3, 3))
self.assertEqual(mx.pad(a, [1]).shape, (3, 3, 3))
self.assertEqual(mx.pad(a, (1, 2)).shape, (4, 4, 4))
self.assertEqual(mx.pad(a, [(1, 2)]).shape, (4, 4, 4))
self.assertEqual(mx.pad(a, ((1, 2),)).shape, (4, 4, 4))
self.assertEqual(mx.pad(a, ((1, 2), (2, 1), (2, 2))).shape, (4, 4, 5))
# Test grads
a_fwd = mx.array(np.random.rand(16, 16).astype(np.float32))
@@ -1490,19 +1490,19 @@ class TestOps(mlx_tests.MLXTestCase):
def test_squeeze_expand(self):
a = mx.zeros((2, 1, 2, 1))
self.assertEqual(mx.squeeze(a).shape, [2, 2])
self.assertEqual(mx.squeeze(a, 1).shape, [2, 2, 1])
self.assertEqual(mx.squeeze(a, [1, 3]).shape, [2, 2])
self.assertEqual(a.squeeze().shape, [2, 2])
self.assertEqual(a.squeeze(1).shape, [2, 2, 1])
self.assertEqual(a.squeeze([1, 3]).shape, [2, 2])
self.assertEqual(mx.squeeze(a).shape, (2, 2))
self.assertEqual(mx.squeeze(a, 1).shape, (2, 2, 1))
self.assertEqual(mx.squeeze(a, [1, 3]).shape, (2, 2))
self.assertEqual(a.squeeze().shape, (2, 2))
self.assertEqual(a.squeeze(1).shape, (2, 2, 1))
self.assertEqual(a.squeeze([1, 3]).shape, (2, 2))
a = mx.zeros((2, 2))
self.assertEqual(mx.squeeze(a).shape, [2, 2])
self.assertEqual(mx.squeeze(a).shape, (2, 2))
self.assertEqual(mx.expand_dims(a, 0).shape, [1, 2, 2])
self.assertEqual(mx.expand_dims(a, (0, 1)).shape, [1, 1, 2, 2])
self.assertEqual(mx.expand_dims(a, [0, -1]).shape, [1, 2, 2, 1])
self.assertEqual(mx.expand_dims(a, 0).shape, (1, 2, 2))
self.assertEqual(mx.expand_dims(a, (0, 1)).shape, (1, 1, 2, 2))
self.assertEqual(mx.expand_dims(a, [0, -1]).shape, (1, 2, 2, 1))
def test_sort(self):
shape = (3, 4, 5)
@@ -1603,12 +1603,12 @@ class TestOps(mlx_tests.MLXTestCase):
def test_flatten(self):
x = mx.zeros([2, 3, 4])
self.assertEqual(mx.flatten(x).shape, [2 * 3 * 4])
self.assertEqual(mx.flatten(x, start_axis=1).shape, [2, 3 * 4])
self.assertEqual(mx.flatten(x, end_axis=1).shape, [2 * 3, 4])
self.assertEqual(x.flatten().shape, [2 * 3 * 4])
self.assertEqual(x.flatten(start_axis=1).shape, [2, 3 * 4])
self.assertEqual(x.flatten(end_axis=1).shape, [2 * 3, 4])
self.assertEqual(mx.flatten(x).shape, (2 * 3 * 4,))
self.assertEqual(mx.flatten(x, start_axis=1).shape, (2, 3 * 4))
self.assertEqual(mx.flatten(x, end_axis=1).shape, (2 * 3, 4))
self.assertEqual(x.flatten().shape, (2 * 3 * 4,))
self.assertEqual(x.flatten(start_axis=1).shape, (2, 3 * 4))
self.assertEqual(x.flatten(end_axis=1).shape, (2 * 3, 4))
def test_clip(self):
a = np.array([1, 4, 3, 8, 5], np.int32)