mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-04 15:04:40 +08:00
Make shape a tuple (#591)
* shape tuple * also remove simplify from docs * rebase
This commit is contained in:
@@ -12,12 +12,12 @@ import numpy as np
|
||||
class TestOps(mlx_tests.MLXTestCase):
|
||||
def test_full_ones_zeros(self):
|
||||
x = mx.full(2, 3.0)
|
||||
self.assertEqual(x.shape, [2])
|
||||
self.assertEqual(x.shape, (2,))
|
||||
self.assertEqual(x.tolist(), [3.0, 3.0])
|
||||
|
||||
x = mx.full((2, 3), 2.0)
|
||||
self.assertEqual(x.dtype, mx.float32)
|
||||
self.assertEqual(x.shape, [2, 3])
|
||||
self.assertEqual(x.shape, (2, 3))
|
||||
self.assertEqual(x.tolist(), [[2, 2, 2], [2, 2, 2]])
|
||||
|
||||
x = mx.full([3, 2], mx.array([False, True]))
|
||||
@@ -28,11 +28,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(x.tolist(), [[2, 3], [2, 3], [2, 3]])
|
||||
|
||||
x = mx.zeros(2)
|
||||
self.assertEqual(x.shape, [2])
|
||||
self.assertEqual(x.shape, (2,))
|
||||
self.assertEqual(x.tolist(), [0.0, 0.0])
|
||||
|
||||
x = mx.ones(2)
|
||||
self.assertEqual(x.shape, [2])
|
||||
self.assertEqual(x.shape, (2,))
|
||||
self.assertEqual(x.tolist(), [1.0, 1.0])
|
||||
|
||||
for t in [mx.bool_, mx.int32, mx.float32]:
|
||||
@@ -530,10 +530,10 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
|
||||
def test_move_swap_axes(self):
|
||||
x = mx.zeros((2, 3, 4))
|
||||
self.assertEqual(mx.moveaxis(x, 0, 2).shape, [3, 4, 2])
|
||||
self.assertEqual(x.moveaxis(0, 2).shape, [3, 4, 2])
|
||||
self.assertEqual(mx.swapaxes(x, 0, 2).shape, [4, 3, 2])
|
||||
self.assertEqual(x.swapaxes(0, 2).shape, [4, 3, 2])
|
||||
self.assertEqual(mx.moveaxis(x, 0, 2).shape, (3, 4, 2))
|
||||
self.assertEqual(x.moveaxis(0, 2).shape, (3, 4, 2))
|
||||
self.assertEqual(mx.swapaxes(x, 0, 2).shape, (4, 3, 2))
|
||||
self.assertEqual(x.swapaxes(0, 2).shape, (4, 3, 2))
|
||||
|
||||
def test_sum(self):
|
||||
x = mx.array(
|
||||
@@ -545,7 +545,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx.sum(x).item(), 9)
|
||||
y = mx.sum(x, keepdims=True)
|
||||
self.assertEqual(y, mx.array(9))
|
||||
self.assertEqual(y.shape, [1, 1])
|
||||
self.assertEqual(y.shape, (1, 1))
|
||||
|
||||
self.assertEqual(mx.sum(x, axis=0).tolist(), [4, 5])
|
||||
self.assertEqual(mx.sum(x, axis=1).tolist(), [3, 6])
|
||||
@@ -585,7 +585,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx.prod(x).item(), 18)
|
||||
y = mx.prod(x, keepdims=True)
|
||||
self.assertEqual(y, mx.array(18))
|
||||
self.assertEqual(y.shape, [1, 1])
|
||||
self.assertEqual(y.shape, (1, 1))
|
||||
|
||||
self.assertEqual(mx.prod(x, axis=0).tolist(), [3, 6])
|
||||
self.assertEqual(mx.prod(x, axis=1).tolist(), [2, 9])
|
||||
@@ -600,11 +600,11 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx.min(x).item(), 1)
|
||||
self.assertEqual(mx.max(x).item(), 4)
|
||||
y = mx.min(x, keepdims=True)
|
||||
self.assertEqual(y.shape, [1, 1])
|
||||
self.assertEqual(y.shape, (1, 1))
|
||||
self.assertEqual(y, mx.array(1))
|
||||
|
||||
y = mx.max(x, keepdims=True)
|
||||
self.assertEqual(y.shape, [1, 1])
|
||||
self.assertEqual(y.shape, (1, 1))
|
||||
self.assertEqual(y, mx.array(4))
|
||||
|
||||
self.assertEqual(mx.min(x, axis=0).tolist(), [1, 2])
|
||||
@@ -670,7 +670,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx.mean(x).item(), 2.5)
|
||||
y = mx.mean(x, keepdims=True)
|
||||
self.assertEqual(y, mx.array(2.5))
|
||||
self.assertEqual(y.shape, [1, 1])
|
||||
self.assertEqual(y.shape, (1, 1))
|
||||
|
||||
self.assertEqual(mx.mean(x, axis=0).tolist(), [2, 3])
|
||||
self.assertEqual(mx.mean(x, axis=1).tolist(), [1.5, 3.5])
|
||||
@@ -685,7 +685,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(mx.var(x).item(), 1.25)
|
||||
y = mx.var(x, keepdims=True)
|
||||
self.assertEqual(y, mx.array(1.25))
|
||||
self.assertEqual(y.shape, [1, 1])
|
||||
self.assertEqual(y.shape, (1, 1))
|
||||
|
||||
self.assertEqual(mx.var(x, axis=0).tolist(), [1.0, 1.0])
|
||||
self.assertEqual(mx.var(x, axis=1).tolist(), [0.25, 0.25])
|
||||
@@ -888,7 +888,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
a = mx.array([[True, False], [True, True]])
|
||||
|
||||
self.assertFalse(mx.all(a).item())
|
||||
self.assertEqual(mx.all(a, keepdims=True).shape, [1, 1])
|
||||
self.assertEqual(mx.all(a, keepdims=True).shape, (1, 1))
|
||||
self.assertFalse(mx.all(a, axis=[0, 1]).item())
|
||||
self.assertEqual(mx.all(a, axis=[0]).tolist(), [True, False])
|
||||
self.assertEqual(mx.all(a, axis=[1]).tolist(), [False, True])
|
||||
@@ -899,7 +899,7 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
a = mx.array([[True, False], [False, False]])
|
||||
|
||||
self.assertTrue(mx.any(a).item())
|
||||
self.assertEqual(mx.any(a, keepdims=True).shape, [1, 1])
|
||||
self.assertEqual(mx.any(a, keepdims=True).shape, (1, 1))
|
||||
self.assertTrue(mx.any(a, axis=[0, 1]).item())
|
||||
self.assertEqual(mx.any(a, axis=[0]).tolist(), [True, False])
|
||||
self.assertEqual(mx.any(a, axis=[1]).tolist(), [True, False])
|
||||
@@ -956,22 +956,22 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
|
||||
a_npy_taken = np.take(a_npy, idx_npy)
|
||||
a_mlx_taken = mx.take(a_mlx, idx_mlx)
|
||||
self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
|
||||
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
|
||||
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
|
||||
|
||||
a_npy_taken = np.take(a_npy, idx_npy, axis=0)
|
||||
a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=0)
|
||||
self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
|
||||
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
|
||||
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
|
||||
|
||||
a_npy_taken = np.take(a_npy, idx_npy, axis=1)
|
||||
a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=1)
|
||||
self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
|
||||
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
|
||||
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
|
||||
|
||||
a_npy_taken = np.take(a_npy, idx_npy, axis=2)
|
||||
a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=2)
|
||||
self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
|
||||
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
|
||||
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
|
||||
|
||||
def test_take_along_axis(self):
|
||||
@@ -1400,13 +1400,13 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6))
|
||||
|
||||
a = mx.zeros((1, 1, 1))
|
||||
self.assertEqual(mx.pad(a, 1).shape, [3, 3, 3])
|
||||
self.assertEqual(mx.pad(a, (1,)).shape, [3, 3, 3])
|
||||
self.assertEqual(mx.pad(a, [1]).shape, [3, 3, 3])
|
||||
self.assertEqual(mx.pad(a, (1, 2)).shape, [4, 4, 4])
|
||||
self.assertEqual(mx.pad(a, [(1, 2)]).shape, [4, 4, 4])
|
||||
self.assertEqual(mx.pad(a, ((1, 2),)).shape, [4, 4, 4])
|
||||
self.assertEqual(mx.pad(a, ((1, 2), (2, 1), (2, 2))).shape, [4, 4, 5])
|
||||
self.assertEqual(mx.pad(a, 1).shape, (3, 3, 3))
|
||||
self.assertEqual(mx.pad(a, (1,)).shape, (3, 3, 3))
|
||||
self.assertEqual(mx.pad(a, [1]).shape, (3, 3, 3))
|
||||
self.assertEqual(mx.pad(a, (1, 2)).shape, (4, 4, 4))
|
||||
self.assertEqual(mx.pad(a, [(1, 2)]).shape, (4, 4, 4))
|
||||
self.assertEqual(mx.pad(a, ((1, 2),)).shape, (4, 4, 4))
|
||||
self.assertEqual(mx.pad(a, ((1, 2), (2, 1), (2, 2))).shape, (4, 4, 5))
|
||||
|
||||
# Test grads
|
||||
a_fwd = mx.array(np.random.rand(16, 16).astype(np.float32))
|
||||
@@ -1490,19 +1490,19 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
|
||||
def test_squeeze_expand(self):
|
||||
a = mx.zeros((2, 1, 2, 1))
|
||||
self.assertEqual(mx.squeeze(a).shape, [2, 2])
|
||||
self.assertEqual(mx.squeeze(a, 1).shape, [2, 2, 1])
|
||||
self.assertEqual(mx.squeeze(a, [1, 3]).shape, [2, 2])
|
||||
self.assertEqual(a.squeeze().shape, [2, 2])
|
||||
self.assertEqual(a.squeeze(1).shape, [2, 2, 1])
|
||||
self.assertEqual(a.squeeze([1, 3]).shape, [2, 2])
|
||||
self.assertEqual(mx.squeeze(a).shape, (2, 2))
|
||||
self.assertEqual(mx.squeeze(a, 1).shape, (2, 2, 1))
|
||||
self.assertEqual(mx.squeeze(a, [1, 3]).shape, (2, 2))
|
||||
self.assertEqual(a.squeeze().shape, (2, 2))
|
||||
self.assertEqual(a.squeeze(1).shape, (2, 2, 1))
|
||||
self.assertEqual(a.squeeze([1, 3]).shape, (2, 2))
|
||||
|
||||
a = mx.zeros((2, 2))
|
||||
self.assertEqual(mx.squeeze(a).shape, [2, 2])
|
||||
self.assertEqual(mx.squeeze(a).shape, (2, 2))
|
||||
|
||||
self.assertEqual(mx.expand_dims(a, 0).shape, [1, 2, 2])
|
||||
self.assertEqual(mx.expand_dims(a, (0, 1)).shape, [1, 1, 2, 2])
|
||||
self.assertEqual(mx.expand_dims(a, [0, -1]).shape, [1, 2, 2, 1])
|
||||
self.assertEqual(mx.expand_dims(a, 0).shape, (1, 2, 2))
|
||||
self.assertEqual(mx.expand_dims(a, (0, 1)).shape, (1, 1, 2, 2))
|
||||
self.assertEqual(mx.expand_dims(a, [0, -1]).shape, (1, 2, 2, 1))
|
||||
|
||||
def test_sort(self):
|
||||
shape = (3, 4, 5)
|
||||
@@ -1603,12 +1603,12 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
|
||||
def test_flatten(self):
|
||||
x = mx.zeros([2, 3, 4])
|
||||
self.assertEqual(mx.flatten(x).shape, [2 * 3 * 4])
|
||||
self.assertEqual(mx.flatten(x, start_axis=1).shape, [2, 3 * 4])
|
||||
self.assertEqual(mx.flatten(x, end_axis=1).shape, [2 * 3, 4])
|
||||
self.assertEqual(x.flatten().shape, [2 * 3 * 4])
|
||||
self.assertEqual(x.flatten(start_axis=1).shape, [2, 3 * 4])
|
||||
self.assertEqual(x.flatten(end_axis=1).shape, [2 * 3, 4])
|
||||
self.assertEqual(mx.flatten(x).shape, (2 * 3 * 4,))
|
||||
self.assertEqual(mx.flatten(x, start_axis=1).shape, (2, 3 * 4))
|
||||
self.assertEqual(mx.flatten(x, end_axis=1).shape, (2 * 3, 4))
|
||||
self.assertEqual(x.flatten().shape, (2 * 3 * 4,))
|
||||
self.assertEqual(x.flatten(start_axis=1).shape, (2, 3 * 4))
|
||||
self.assertEqual(x.flatten(end_axis=1).shape, (2 * 3, 4))
|
||||
|
||||
def test_clip(self):
|
||||
a = np.array([1, 4, 3, 8, 5], np.int32)
|
||||
|
Reference in New Issue
Block a user