From 09b9275027f34864702005c3d9d5a1bd9daf2c21 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 30 Jan 2024 13:11:01 -0800
Subject: [PATCH] Make shape a tuple (#591)

* shape tuple

* also remove simplify from docs

* rebase
---
 docs/src/dev/extensions.rst        |  2 +-
 docs/src/python/transforms.rst     |  1 -
 docs/src/usage/lazy_evaluation.rst |  2 +-
 python/mlx/nn/layers/dropout.py    |  2 +-
 python/mlx/nn/losses.py            |  3 +-
 python/src/array.cpp               |  7 +--
 python/tests/test_array.py         | 34 ++++++------
 python/tests/test_blas.py          |  6 ++-
 python/tests/test_conv.py          | 24 ++++-----
 python/tests/test_init.py          | 16 +++---
 python/tests/test_nn.py            | 54 +++++++++----------
 python/tests/test_ops.py           | 86 +++++++++++++++---------------
 python/tests/test_random.py        | 44 +++++++--------
 13 files changed, 141 insertions(+), 140 deletions(-)

diff --git a/docs/src/dev/extensions.rst b/docs/src/dev/extensions.rst
index 0a134e7f5..a7880e396 100644
--- a/docs/src/dev/extensions.rst
+++ b/docs/src/dev/extensions.rst
@@ -929,7 +929,7 @@ We see some modest improvements right away!
 
 This operation is now good to be used to build other operations, in
 :class:`mlx.nn.Module` calls, and also as a part of graph
-transformations such as :meth:`grad` and :meth:`simplify`!
+transformations like :meth:`grad`!
 
 Scripts
 -------
diff --git a/docs/src/python/transforms.rst b/docs/src/python/transforms.rst
index fa6d1d701..cc8d681d5 100644
--- a/docs/src/python/transforms.rst
+++ b/docs/src/python/transforms.rst
@@ -14,4 +14,3 @@ Transforms
    jvp
    vjp
    vmap
-   simplify
diff --git a/docs/src/usage/lazy_evaluation.rst b/docs/src/usage/lazy_evaluation.rst
index 4f14ceeed..e41fcbe0b 100644
--- a/docs/src/usage/lazy_evaluation.rst
+++ b/docs/src/usage/lazy_evaluation.rst
@@ -20,7 +20,7 @@ Transforming Compute Graphs
 
 Lazy evaluation lets us record a compute graph without actually doing any
 computations. This is useful for function transformations like :func:`grad` and
-:func:`vmap` and graph optimizations like :func:`simplify`.
+:func:`vmap` and graph optimizations.
 
 Currently, MLX does not compile and rerun compute graphs. They are all
 generated dynamically. However, lazy evaluation makes it much easier to
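
The user-visible effect of the changes below: `array.shape` now evaluates to a
Python tuple instead of a list. A minimal sketch of the new behavior
(illustrative only, using the `import mlx.core as mx` convention from the tests
in this patch):

    import mlx.core as mx

    x = mx.zeros((2, 3))
    assert x.shape == (2, 3)   # tuple equality; `x.shape == [2, 3]` is now False
    assert x.shape[-1] == 3    # indexing works as before
    seen = {x.shape: 1}        # tuples are hashable, so shapes can key dicts/sets
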
diff --git a/python/mlx/nn/layers/dropout.py b/python/mlx/nn/layers/dropout.py
index 18b9b03a6..7008547c0 100644
--- a/python/mlx/nn/layers/dropout.py
+++ b/python/mlx/nn/layers/dropout.py
@@ -81,7 +81,7 @@ class Dropout2d(Module):
         # Dropout is applied on the whole channel
         # 3D input: (1, 1, C)
         # 4D input: (B, 1, 1, C)
-        mask_shape = x.shape
+        mask_shape = list(x.shape)
         mask_shape[-2] = mask_shape[-3] = 1
 
         mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index d68092f4a..ae64ab3ac 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -70,8 +70,9 @@ def cross_entropy(
     targets_as_probs = targets.ndim == logits.ndim
 
     def _drop_dim(shape, axis):
+        shape = list(shape)
         shape.pop(axis)
-        return shape
+        return tuple(shape)
 
     # Check shapes in two cases: targets as class indices and targets as probabilities
     if (targets_as_probs and targets.shape != logits.shape) or (
diff --git a/python/src/array.cpp b/python/src/array.cpp
index acb4f8edc..9e2f1e923 100644
--- a/python/src/array.cpp
+++ b/python/src/array.cpp
@@ -675,17 +675,14 @@ void init_array(py::module_& m) {
           "nbytes",
           &array::nbytes,
           R"pbdoc(The number of bytes in the array.)pbdoc")
-      // TODO, this makes a deep copy of the shape
-      // implement alternatives to use reference
-      // https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
       .def_property_readonly(
           "shape",
-          [](const array& a) { return a.shape(); },
+          [](const array& a) { return py::tuple(py::cast(a.shape())); },
           R"pbdoc(
             The shape of the array as a Python tuple.
 
             Returns:
-              list(int): A list containing the sizes of each dimension.
+              tuple(int): A tuple containing the sizes of each dimension.
           )pbdoc")
       .def_property_readonly(
           "dtype",
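
The dropout and loss fixes above illustrate the migration pattern for code that
mutated `shape` in place: copy the tuple to a list, edit it, and pass the
result back (shape arguments still accept lists). A hedged sketch of the same
pattern in user code:

    import mlx.core as mx

    x = mx.zeros((4, 8, 8, 3))
    mask_shape = list(x.shape)           # tuples are immutable; copy to a list
    mask_shape[-2] = mask_shape[-3] = 1
    mask = mx.random.bernoulli(p=0.9, shape=mask_shape)
    assert mask.shape == (4, 1, 1, 3)
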
diff --git a/python/tests/test_array.py b/python/tests/test_array.py
index 593dde361..507675d6e 100644
--- a/python/tests/test_array.py
+++ b/python/tests/test_array.py
@@ -94,7 +94,7 @@ class TestArray(mlx_tests.MLXTestCase):
         self.assertEqual(x.ndim, 0)
         self.assertEqual(x.itemsize, 4)
         self.assertEqual(x.nbytes, 4)
-        self.assertEqual(x.shape, [])
+        self.assertEqual(x.shape, ())
         self.assertEqual(x.dtype, mx.int32)
         self.assertEqual(x.item(), 1)
         self.assertTrue(isinstance(x.item(), int))
@@ -116,7 +116,7 @@ class TestArray(mlx_tests.MLXTestCase):
         x = mx.array(1.0)
         self.assertEqual(x.size, 1)
         self.assertEqual(x.ndim, 0)
-        self.assertEqual(x.shape, [])
+        self.assertEqual(x.shape, ())
         self.assertEqual(x.dtype, mx.float32)
         self.assertEqual(x.item(), 1.0)
         self.assertTrue(isinstance(x.item(), float))
@@ -124,14 +124,14 @@ class TestArray(mlx_tests.MLXTestCase):
         x = mx.array(False)
         self.assertEqual(x.size, 1)
         self.assertEqual(x.ndim, 0)
-        self.assertEqual(x.shape, [])
+        self.assertEqual(x.shape, ())
         self.assertEqual(x.dtype, mx.bool_)
         self.assertEqual(x.item(), False)
         self.assertTrue(isinstance(x.item(), bool))
 
         x = mx.array(complex(1, 1))
         self.assertEqual(x.ndim, 0)
-        self.assertEqual(x.shape, [])
+        self.assertEqual(x.shape, ())
         self.assertEqual(x.dtype, mx.complex64)
         self.assertEqual(x.item(), complex(1, 1))
         self.assertTrue(isinstance(x.item(), complex))
@@ -139,7 +139,7 @@ class TestArray(mlx_tests.MLXTestCase):
         x = mx.array([True, False, True])
         self.assertEqual(x.dtype, mx.bool_)
         self.assertEqual(x.ndim, 1)
-        self.assertEqual(x.shape, [3])
+        self.assertEqual(x.shape, (3,))
         self.assertEqual(len(x), 3)
 
         x = mx.array([True, False, True], mx.float32)
         self.assertEqual(x.dtype, mx.float32)
 
         x = mx.array([0, 1, 2])
         self.assertEqual(x.dtype, mx.int32)
         self.assertEqual(x.ndim, 1)
-        self.assertEqual(x.shape, [3])
+        self.assertEqual(x.shape, (3,))
 
         x = mx.array([0, 1, 2], mx.float32)
         self.assertEqual(x.dtype, mx.float32)
@@ -156,12 +156,12 @@ class TestArray(mlx_tests.MLXTestCase):
         x = mx.array([0.0, 1.0, 2.0])
         self.assertEqual(x.dtype, mx.float32)
         self.assertEqual(x.ndim, 1)
-        self.assertEqual(x.shape, [3])
+        self.assertEqual(x.shape, (3,))
 
         x = mx.array([1j, 1 + 0j])
         self.assertEqual(x.dtype, mx.complex64)
         self.assertEqual(x.ndim, 1)
-        self.assertEqual(x.shape, [2])
+        self.assertEqual(x.shape, (2,))
 
         # From tuple
         x = mx.array((1, 2, 3), mx.int32)
@@ -181,17 +181,17 @@ class TestArray(mlx_tests.MLXTestCase):
     def test_construction_from_lists(self):
         x = mx.array([])
         self.assertEqual(x.size, 0)
-        self.assertEqual(x.shape, [0])
+        self.assertEqual(x.shape, (0,))
         self.assertEqual(x.dtype, mx.float32)
 
         x = mx.array([[], [], []])
         self.assertEqual(x.size, 0)
-        self.assertEqual(x.shape, [3, 0])
+        self.assertEqual(x.shape, (3, 0))
         self.assertEqual(x.dtype, mx.float32)
 
         x = mx.array([[[], []], [[], []], [[], []]])
         self.assertEqual(x.size, 0)
-        self.assertEqual(x.shape, [3, 2, 0])
+        self.assertEqual(x.shape, (3, 2, 0))
         self.assertEqual(x.dtype, mx.float32)
 
         # Check failure cases
@@ -436,19 +436,19 @@ class TestArray(mlx_tests.MLXTestCase):
         a = np.array([])
         x = mx.array(a)
         self.assertEqual(x.size, 0)
-        self.assertEqual(x.shape, [0])
+        self.assertEqual(x.shape, (0,))
         self.assertEqual(x.dtype, mx.float32)
 
         a = np.array([[], [], []])
         x = mx.array(a)
         self.assertEqual(x.size, 0)
-        self.assertEqual(x.shape, [3, 0])
+        self.assertEqual(x.shape, (3, 0))
         self.assertEqual(x.dtype, mx.float32)
 
         a = np.array([[[], []], [[], []], [[], []]])
         x = mx.array(a)
         self.assertEqual(x.size, 0)
-        self.assertEqual(x.shape, [3, 2, 0])
+        self.assertEqual(x.shape, (3, 2, 0))
         self.assertEqual(x.dtype, mx.float32)
 
         # Content test
         x = mx.array(a)
         self.assertEqual(x.dtype, mx.float32)
         self.assertEqual(x.ndim, 3)
-        self.assertEqual(x.shape, [3, 5, 4])
+        self.assertEqual(x.shape, (3, 5, 4))
         y = np.asarray(x)
         self.assertTrue(np.allclose(a, y))
@@ -465,7 +465,7 @@
         x = mx.array(a)
         self.assertEqual(x.dtype, mx.int32)
         self.assertEqual(x.ndim, 0)
-        self.assertEqual(x.shape, [])
+        self.assertEqual(x.shape, ())
         self.assertEqual(x.item(), 3)
 
         # mlx to numpy test
@@ -483,7 +483,7 @@ class TestArray(mlx_tests.MLXTestCase):
         x = np.array(cvals)
         y = mx.array(x)
         self.assertEqual(y.dtype, mx.complex64)
-        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.shape, (3,))
         self.assertEqual(y.tolist(), cvals)
 
         y = mx.array([0j, 1, 1 + 1j])
diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py
index d1ac0a3a1..b7a24caf2 100644
--- a/python/tests/test_blas.py
+++ b/python/tests/test_blas.py
@@ -579,5 +579,9 @@ class TestBlas(mlx_tests.MLXTestCase):
             self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
 
             for r, t in zip(dout_ref, dout_test):
-                self.assertListEqual(r.shape, t.shape)
+                self.assertEqual(r.shape, t.shape)
                 self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
+
+
+if __name__ == "__main__":
+    unittest.main()
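
Because NumPy also reports shapes as tuples, the `assertListEqual(list(...), ...)`
scaffolding in the surrounding test changes collapses to a plain `assertEqual`
(note that `assertListEqual` insists on actual lists, so it would now fail on
tuple shapes). A small illustration, assuming NumPy is installed alongside MLX:

    import mlx.core as mx
    import numpy as np

    a = np.zeros((3, 5, 4), dtype=np.float32)
    x = mx.array(a)
    assert x.shape == a.shape  # both sides are tuples; no conversion needed
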
diff --git a/python/tests/test_conv.py b/python/tests/test_conv.py
index 1501334e9..4ccee863c 100644
--- a/python/tests/test_conv.py
+++ b/python/tests/test_conv.py
@@ -44,7 +44,7 @@ class TestConv(mlx_tests.MLXTestCase):
 
                 c_np = np.convolve(a_np, v_np, mode=mode)
                 c_mx = mx.convolve(a_mx, v_mx, mode=mode)
-                self.assertListEqual(list(c_mx.shape), list(c_np.shape))
+                self.assertEqual(c_mx.shape, c_np.shape)
                 self.assertTrue(np.allclose(c_mx, c_np, atol=atol))
 
     @unittest.skipIf(not has_torch, "requires Torch")
@@ -102,7 +102,7 @@ class TestConv(mlx_tests.MLXTestCase):
                 )
                 out_pt = torch.transpose(out_pt, 2, 1)
 
-                self.assertListEqual(list(out_pt.shape), out_mx.shape)
+                self.assertEqual(out_pt.shape, out_mx.shape)
                 self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))
 
         for dtype in ("float32",):
@@ -141,7 +141,7 @@ class TestConv(mlx_tests.MLXTestCase):
                 out_pt = torch.conv1d(in_pt, wt_pt)
                 out_pt = torch.transpose(out_pt, 2, 1)
 
-                self.assertListEqual(list(out_pt.shape), out_mx.shape)
+                self.assertEqual(out_pt.shape, out_mx.shape)
                 self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))
 
     @unittest.skipIf(not has_torch, "requires Torch")
@@ -228,12 +228,12 @@ class TestConv(mlx_tests.MLXTestCase):
 
                 mx_grad_in, mx_grad_wt = outs_mx
 
-                self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
-                self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
+                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
+                self.assertEqual(in_mx.shape, mx_grad_in.shape)
                 self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
 
-                self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
-                self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
+                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
+                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
                 self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
 
         for dtype in ("float32",):
@@ -309,7 +309,7 @@ class TestConv(mlx_tests.MLXTestCase):
                 )
                 out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
 
-                self.assertListEqual(list(out_pt.shape), list(out_mx.shape))
+                self.assertEqual(out_pt.shape, out_mx.shape)
                 self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
 
         for dtype in ("float32",):
@@ -419,12 +419,12 @@ class TestConv(mlx_tests.MLXTestCase):
 
                 mx_grad_in, mx_grad_wt = outs_mx
 
-                self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
-                self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
+                self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
+                self.assertEqual(in_mx.shape, mx_grad_in.shape)
                 self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
 
-                self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
-                self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
+                self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
+                self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
                 self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
 
         for dtype in ("float32",):
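
Expected shapes in the tests that follow are written as tuples as well,
including the one-dimensional cases, where the trailing comma matters: `(3,)`
is a one-element tuple while `(3)` is just the integer 3. For example:

    import mlx.core as mx

    assert mx.zeros((10, 4)).shape == (10, 4)
    assert mx.zeros(3).shape == (3,)   # note the comma: (3) would be the int 3
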
diff --git a/python/tests/test_init.py b/python/tests/test_init.py
index 06211a14e..3cc63e03d 100644
--- a/python/tests/test_init.py
+++ b/python/tests/test_init.py
@@ -13,7 +13,7 @@ class TestInit(mlx_tests.MLXTestCase):
         for dtype in [mx.float32, mx.float16]:
             initializer = init.constant(value, dtype)
 
-            for shape in [[3], [3, 3], [3, 3, 3]]:
+            for shape in [(3,), (3, 3), (3, 3, 3)]:
                 result = initializer(mx.array(mx.zeros(shape)))
                 with self.subTest(shape=shape):
                     self.assertEqual(result.shape, shape)
@@ -24,7 +24,7 @@ class TestInit(mlx_tests.MLXTestCase):
         std = 1.0
         for dtype in [mx.float32, mx.float16]:
             initializer = init.normal(mean, std, dtype=dtype)
-            for shape in [[3], [3, 3], [3, 3, 3]]:
+            for shape in [(3,), (3, 3), (3, 3, 3)]:
                 result = initializer(mx.array(np.empty(shape)))
                 with self.subTest(shape=shape):
                     self.assertEqual(result.shape, shape)
@@ -36,7 +36,7 @@ class TestInit(mlx_tests.MLXTestCase):
         for dtype in [mx.float32, mx.float16]:
             initializer = init.uniform(low, high, dtype)
 
-            for shape in [[3], [3, 3], [3, 3, 3]]:
+            for shape in [(3,), (3, 3), (3, 3, 3)]:
                 result = initializer(mx.array(np.empty(shape)))
                 with self.subTest(shape=shape):
                     self.assertEqual(result.shape, shape)
@@ -46,7 +46,7 @@ class TestInit(mlx_tests.MLXTestCase):
     def test_identity(self):
        for dtype in [mx.float32, mx.float16]:
             initializer = init.identity(dtype)
-            for shape in [[3], [3, 3], [3, 3, 3]]:
+            for shape in [(3,), (3, 3), (3, 3, 3)]:
                 result = initializer(mx.zeros((3, 3)))
                 self.assertTrue(mx.array_equal(result, mx.eye(3)))
                 self.assertEqual(result.dtype, dtype)
@@ -56,7 +56,7 @@ class TestInit(mlx_tests.MLXTestCase):
     def test_glorot_normal(self):
         for dtype in [mx.float32, mx.float16]:
             initializer = init.glorot_normal(dtype)
-            for shape in [[3, 3], [3, 3, 3]]:
+            for shape in [(3, 3), (3, 3, 3)]:
                 result = initializer(mx.array(np.empty(shape)))
                 with self.subTest(shape=shape):
                     self.assertEqual(result.shape, shape)
@@ -65,7 +65,7 @@ class TestInit(mlx_tests.MLXTestCase):
     def test_glorot_uniform(self):
         for dtype in [mx.float32, mx.float16]:
             initializer = init.glorot_uniform(dtype)
-            for shape in [[3, 3], [3, 3, 3]]:
+            for shape in [(3, 3), (3, 3, 3)]:
                 result = initializer(mx.array(np.empty(shape)))
                 with self.subTest(shape=shape):
                     self.assertEqual(result.shape, shape)
@@ -74,7 +74,7 @@ class TestInit(mlx_tests.MLXTestCase):
     def test_he_normal(self):
         for dtype in [mx.float32, mx.float16]:
             initializer = init.he_normal(dtype)
-            for shape in [[3, 3], [3, 3, 3]]:
+            for shape in [(3, 3), (3, 3, 3)]:
                 result = initializer(mx.array(np.empty(shape)))
                 with self.subTest(shape=shape):
                     self.assertEqual(result.shape, shape)
@@ -83,7 +83,7 @@ class TestInit(mlx_tests.MLXTestCase):
     def test_he_uniform(self):
         for dtype in [mx.float32, mx.float16]:
             initializer = init.he_uniform(dtype)
-            for shape in [[3, 3], [3, 3, 3]]:
+            for shape in [(3, 3), (3, 3, 3)]:
                 result = initializer(mx.array(np.empty(shape)))
                 with self.subTest(shape=shape):
                     self.assertEqual(result.shape, shape)
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 6d65e19bd..7749e159a 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -136,20 +136,20 @@ class TestLayers(mlx_tests.MLXTestCase):
         inputs = mx.zeros((10, 4))
         layer = nn.Identity()
         outputs = layer(inputs)
-        self.assertEqual(tuple(inputs.shape), tuple(outputs.shape))
+        self.assertEqual(inputs.shape, outputs.shape)
 
     def test_linear(self):
         inputs = mx.zeros((10, 4))
         layer = nn.Linear(input_dims=4, output_dims=8)
         outputs = layer(inputs)
-        self.assertEqual(tuple(outputs.shape), (10, 8))
+        self.assertEqual(outputs.shape, (10, 8))
 
     def test_bilinear(self):
         inputs1 = mx.zeros((10, 2))
         inputs2 = mx.zeros((10, 4))
         layer = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6)
         outputs = layer(inputs1, inputs2)
-        self.assertEqual(tuple(outputs.shape), (10, 6))
+        self.assertEqual(outputs.shape, (10, 6))
 
     def test_group_norm(self):
         x = mx.arange(100, dtype=mx.float32)
@@ -573,12 +573,12 @@ class TestLayers(mlx_tests.MLXTestCase):
         c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks)
         c.weight = mx.ones_like(c.weight)
         y = c(x)
-        self.assertEqual(y.shape, [N, L - ks + 1, C_out])
+        self.assertEqual(y.shape, (N, L - ks + 1, C_out))
         self.assertTrue(mx.allclose(y, mx.full(y.shape, ks * C_in, mx.float32)))
 
         c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, stride=2)
         y = c(x)
-        self.assertEqual(y.shape, [N, (L - ks + 1) // 2, C_out])
+        self.assertEqual(y.shape, (N, (L - ks + 1) // 2, C_out))
         self.assertTrue("bias" in c.parameters())
 
         c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, bias=False)
@@ -588,7 +588,7 @@ class TestLayers(mlx_tests.MLXTestCase):
         x = mx.ones((4, 8, 8, 3))
         c = nn.Conv2d(3, 1, 8)
         y = c(x)
-        self.assertEqual(y.shape, [4, 1, 1, 1])
+        self.assertEqual(y.shape, (4, 1, 1, 1))
         c.weight = mx.ones_like(c.weight) / 8 / 8 / 3
         y = c(x)
         self.assertTrue(np.allclose(y[:, 0, 0, 0], x.mean(axis=(1, 2, 3))))
@@ -596,13 +596,13 @@ class TestLayers(mlx_tests.MLXTestCase):
         # 3x3 conv no padding stride 1
         c = nn.Conv2d(3, 8, 3)
         y = c(x)
-        self.assertEqual(y.shape, [4, 6, 6, 8])
+        self.assertEqual(y.shape, (4, 6, 6, 8))
         self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)
 
         # 3x3 conv padding 1 stride 1
         c = nn.Conv2d(3, 8, 3, padding=1)
         y = c(x)
-        self.assertEqual(y.shape, [4, 8, 8, 8])
+        self.assertEqual(y.shape, (4, 8, 8, 8))
         self.assertLess(mx.abs(y[:, 1:7, 1:7] - c.weight.sum((1, 2, 3))).max(), 1e-4)
         self.assertLess(
             mx.abs(y[:, 0, 0] - c.weight[:, 1:, 1:].sum(axis=(1, 2, 3))).max(),
@@ -624,14 +624,14 @@ class TestLayers(mlx_tests.MLXTestCase):
         # 3x3 conv no padding stride 2
         c = nn.Conv2d(3, 8, 3, padding=0, stride=2)
         y = c(x)
-        self.assertEqual(y.shape, [4, 3, 3, 8])
+        self.assertEqual(y.shape, (4, 3, 3, 8))
         self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)
 
     def test_sequential(self):
         x = mx.ones((10, 2))
         m = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 1))
         y = m(x)
-        self.assertEqual(y.shape, [10, 1])
+        self.assertEqual(y.shape, (10, 1))
         params = m.parameters()
         self.assertTrue("layers" in params)
         self.assertEqual(len(params["layers"]), 3)
@@ -667,7 +667,7 @@ class TestLayers(mlx_tests.MLXTestCase):
 
         x = mx.arange(10)
         y = m(x)
-        self.assertEqual(y.shape, [10, 16])
+        self.assertEqual(y.shape, (10, 16))
         similarities = y @ y.T
         self.assertLess(
             mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5
@@ -686,19 +686,19 @@ class TestLayers(mlx_tests.MLXTestCase):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.relu(x)
         self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0])))
-        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
     def test_leaky_relu(self):
         x = mx.array([1.0, -1.0, 0.0])
         y = nn.leaky_relu(x)
         self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.01, 0.0])))
-        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
         y = nn.LeakyReLU(negative_slope=0.1)(x)
         self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.1, 0.0])))
-        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
     def test_elu(self):
@@ -707,21 +707,21 @@ class TestLayers(mlx_tests.MLXTestCase):
         epsilon = 1e-4
         expected_y = mx.array([1.0, -0.6321, 0.0])
         self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
-        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
         y = nn.ELU(alpha=1.1)(x)
         epsilon = 1e-4
         expected_y = mx.array([1.0, -0.6953, 0.0])
         self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
-        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
     def test_relu6(self):
         x = mx.array([1.0, -1.0, 0.0, 7.0, -7.0])
         y = nn.relu6(x)
         self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0, 6.0, 0.0])))
-        self.assertEqual(y.shape, [5])
+        self.assertEqual(y.shape, (5,))
         self.assertEqual(y.dtype, mx.float32)
 
     def test_softmax(self):
@@ -730,7 +730,7 @@ class TestLayers(mlx_tests.MLXTestCase):
         epsilon = 1e-4
         expected_y = mx.array([0.6652, 0.0900, 0.2447])
         self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
-        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
     def test_softplus(self):
@@ -739,7 +739,7 @@ class TestLayers(mlx_tests.MLXTestCase):
         epsilon = 1e-4
         expected_y = mx.array([1.3133, 0.3133, 0.6931])
         self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
-        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
     def test_softsign(self):
@@ -748,7 +748,7 @@ class TestLayers(mlx_tests.MLXTestCase):
         epsilon = 1e-4
         expected_y = mx.array([0.5, -0.5, 0.0])
         self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
-        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
     def test_softshrink(self):
@@ -757,13 +757,13 @@ class TestLayers(mlx_tests.MLXTestCase):
         epsilon = 1e-4
         expected_y = mx.array([0.5, -0.5, 0.0])
         self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
-        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
         y = nn.Softshrink(lambd=0.7)(x)
         expected_y = mx.array([0.3, -0.3, 0.0])
         self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
-        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
     def test_celu(self):
@@ -772,13 +772,13 @@ class TestLayers(mlx_tests.MLXTestCase):
         epsilon = 1e-4
         expected_y = mx.array([1.0, -0.6321, 0.0])
         self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
-        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
         y = nn.CELU(alpha=1.1)(x)
         expected_y = mx.array([1.0, -0.6568, 0.0])
         self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
-        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
     def test_log_softmax(self):
@@ -787,7 +787,7 @@ class TestLayers(mlx_tests.MLXTestCase):
         epsilon = 1e-4
         expected_y = mx.array([-2.4076, -1.4076, -0.4076])
         self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
-        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
     def test_log_sigmoid(self):
@@ -796,7 +796,7 @@ class TestLayers(mlx_tests.MLXTestCase):
         epsilon = 1e-4
         expected_y = mx.array([-0.3133, -1.3133, -0.6931])
         self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
-        self.assertEqual(y.shape, [3])
+        self.assertEqual(y.shape, (3,))
         self.assertEqual(y.dtype, mx.float32)
 
     def test_prelu(self):
@@ -817,7 +817,7 @@ class TestLayers(mlx_tests.MLXTestCase):
         epsilon = 1e-4
         expected_y = mx.array([0.0, -0.375, 0.0, 1.125, 3.0])
         self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
-        self.assertEqual(y.shape, [5])
+        self.assertEqual(y.shape, (5,))
         self.assertEqual(y.dtype, mx.float32)
 
     def test_glu(self):
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index c82d9b5c5..64152b537 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -12,12 +12,12 @@ import numpy as np
 class TestOps(mlx_tests.MLXTestCase):
     def test_full_ones_zeros(self):
         x = mx.full(2, 3.0)
-        self.assertEqual(x.shape, [2])
+        self.assertEqual(x.shape, (2,))
         self.assertEqual(x.tolist(), [3.0, 3.0])
 
         x = mx.full((2, 3), 2.0)
         self.assertEqual(x.dtype, mx.float32)
-        self.assertEqual(x.shape, [2, 3])
+        self.assertEqual(x.shape, (2, 3))
         self.assertEqual(x.tolist(), [[2, 2, 2], [2, 2, 2]])
 
         x = mx.full([3, 2], mx.array([False, True]))
@@ -28,11 +28,11 @@
         self.assertEqual(x.tolist(), [[2, 3], [2, 3], [2, 3]])
 
         x = mx.zeros(2)
-        self.assertEqual(x.shape, [2])
+        self.assertEqual(x.shape, (2,))
         self.assertEqual(x.tolist(), [0.0, 0.0])
 
         x = mx.ones(2)
-        self.assertEqual(x.shape, [2])
+        self.assertEqual(x.shape, (2,))
         self.assertEqual(x.tolist(), [1.0, 1.0])
 
         for t in [mx.bool_, mx.int32, mx.float32]:
@@ -530,10 +530,10 @@ class TestOps(mlx_tests.MLXTestCase):
 
     def test_move_swap_axes(self):
         x = mx.zeros((2, 3, 4))
-        self.assertEqual(mx.moveaxis(x, 0, 2).shape, [3, 4, 2])
-        self.assertEqual(x.moveaxis(0, 2).shape, [3, 4, 2])
-        self.assertEqual(mx.swapaxes(x, 0, 2).shape, [4, 3, 2])
-        self.assertEqual(x.swapaxes(0, 2).shape, [4, 3, 2])
+        self.assertEqual(mx.moveaxis(x, 0, 2).shape, (3, 4, 2))
+        self.assertEqual(x.moveaxis(0, 2).shape, (3, 4, 2))
+        self.assertEqual(mx.swapaxes(x, 0, 2).shape, (4, 3, 2))
+        self.assertEqual(x.swapaxes(0, 2).shape, (4, 3, 2))
 
     def test_sum(self):
         x = mx.array(
@@ -545,7 +545,7 @@ class TestOps(mlx_tests.MLXTestCase):
         self.assertEqual(mx.sum(x).item(), 9)
         y = mx.sum(x, keepdims=True)
         self.assertEqual(y, mx.array(9))
-        self.assertEqual(y.shape, [1, 1])
+        self.assertEqual(y.shape, (1, 1))
 
         self.assertEqual(mx.sum(x, axis=0).tolist(), [4, 5])
         self.assertEqual(mx.sum(x, axis=1).tolist(), [3, 6])
@@ -585,7 +585,7 @@ class TestOps(mlx_tests.MLXTestCase):
         self.assertEqual(mx.prod(x).item(), 18)
         y = mx.prod(x, keepdims=True)
         self.assertEqual(y, mx.array(18))
-        self.assertEqual(y.shape, [1, 1])
+        self.assertEqual(y.shape, (1, 1))
 
         self.assertEqual(mx.prod(x, axis=0).tolist(), [3, 6])
         self.assertEqual(mx.prod(x, axis=1).tolist(), [2, 9])
@@ -600,11 +600,11 @@ class TestOps(mlx_tests.MLXTestCase):
         self.assertEqual(mx.min(x).item(), 1)
         self.assertEqual(mx.max(x).item(), 4)
         y = mx.min(x, keepdims=True)
-        self.assertEqual(y.shape, [1, 1])
+        self.assertEqual(y.shape, (1, 1))
         self.assertEqual(y, mx.array(1))
 
         y = mx.max(x, keepdims=True)
-        self.assertEqual(y.shape, [1, 1])
+        self.assertEqual(y.shape, (1, 1))
         self.assertEqual(y, mx.array(4))
 
         self.assertEqual(mx.min(x, axis=0).tolist(), [1, 2])
@@ -670,7 +670,7 @@ class TestOps(mlx_tests.MLXTestCase):
         self.assertEqual(mx.mean(x).item(), 2.5)
         y = mx.mean(x, keepdims=True)
         self.assertEqual(y, mx.array(2.5))
-        self.assertEqual(y.shape, [1, 1])
+        self.assertEqual(y.shape, (1, 1))
 
         self.assertEqual(mx.mean(x, axis=0).tolist(), [2, 3])
         self.assertEqual(mx.mean(x, axis=1).tolist(), [1.5, 3.5])
@@ -685,7 +685,7 @@ class TestOps(mlx_tests.MLXTestCase):
         self.assertEqual(mx.var(x).item(), 1.25)
         y = mx.var(x, keepdims=True)
         self.assertEqual(y, mx.array(1.25))
-        self.assertEqual(y.shape, [1, 1])
+        self.assertEqual(y.shape, (1, 1))
 
         self.assertEqual(mx.var(x, axis=0).tolist(), [1.0, 1.0])
         self.assertEqual(mx.var(x, axis=1).tolist(), [0.25, 0.25])
@@ -888,7 +888,7 @@ class TestOps(mlx_tests.MLXTestCase):
 
         a = mx.array([[True, False], [True, True]])
         self.assertFalse(mx.all(a).item())
-        self.assertEqual(mx.all(a, keepdims=True).shape, [1, 1])
+        self.assertEqual(mx.all(a, keepdims=True).shape, (1, 1))
         self.assertFalse(mx.all(a, axis=[0, 1]).item())
         self.assertEqual(mx.all(a, axis=[0]).tolist(), [True, False])
         self.assertEqual(mx.all(a, axis=[1]).tolist(), [False, True])
@@ -899,7 +899,7 @@ class TestOps(mlx_tests.MLXTestCase):
 
         a = mx.array([[True, False], [False, False]])
         self.assertTrue(mx.any(a).item())
-        self.assertEqual(mx.any(a, keepdims=True).shape, [1, 1])
+        self.assertEqual(mx.any(a, keepdims=True).shape, (1, 1))
         self.assertTrue(mx.any(a, axis=[0, 1]).item())
         self.assertEqual(mx.any(a, axis=[0]).tolist(), [True, False])
         self.assertEqual(mx.any(a, axis=[1]).tolist(), [True, False])
@@ -956,22 +956,22 @@ class TestOps(mlx_tests.MLXTestCase):
 
             a_npy_taken = np.take(a_npy, idx_npy)
             a_mlx_taken = mx.take(a_mlx, idx_mlx)
-            self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
+            self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
             self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
 
             a_npy_taken = np.take(a_npy, idx_npy, axis=0)
             a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=0)
-            self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
+            self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
             self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
 
             a_npy_taken = np.take(a_npy, idx_npy, axis=1)
             a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=1)
-            self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
+            self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
             self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
 
             a_npy_taken = np.take(a_npy, idx_npy, axis=2)
             a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=2)
-            self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
+            self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
             self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
 
     def test_take_along_axis(self):
@@ -1400,13 +1400,13 @@ class TestOps(mlx_tests.MLXTestCase):
         self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6))
 
         a = mx.zeros((1, 1, 1))
-        self.assertEqual(mx.pad(a, 1).shape, [3, 3, 3])
-        self.assertEqual(mx.pad(a, (1,)).shape, [3, 3, 3])
-        self.assertEqual(mx.pad(a, [1]).shape, [3, 3, 3])
-        self.assertEqual(mx.pad(a, (1, 2)).shape, [4, 4, 4])
-        self.assertEqual(mx.pad(a, [(1, 2)]).shape, [4, 4, 4])
-        self.assertEqual(mx.pad(a, ((1, 2),)).shape, [4, 4, 4])
-        self.assertEqual(mx.pad(a, ((1, 2), (2, 1), (2, 2))).shape, [4, 4, 5])
+        self.assertEqual(mx.pad(a, 1).shape, (3, 3, 3))
+        self.assertEqual(mx.pad(a, (1,)).shape, (3, 3, 3))
+        self.assertEqual(mx.pad(a, [1]).shape, (3, 3, 3))
+        self.assertEqual(mx.pad(a, (1, 2)).shape, (4, 4, 4))
+        self.assertEqual(mx.pad(a, [(1, 2)]).shape, (4, 4, 4))
+        self.assertEqual(mx.pad(a, ((1, 2),)).shape, (4, 4, 4))
+        self.assertEqual(mx.pad(a, ((1, 2), (2, 1), (2, 2))).shape, (4, 4, 5))
 
         # Test grads
         a_fwd = mx.array(np.random.rand(16, 16).astype(np.float32))
@@ -1490,19 +1490,19 @@ class TestOps(mlx_tests.MLXTestCase):
 
     def test_squeeze_expand(self):
         a = mx.zeros((2, 1, 2, 1))
-        self.assertEqual(mx.squeeze(a).shape, [2, 2])
-        self.assertEqual(mx.squeeze(a, 1).shape, [2, 2, 1])
-        self.assertEqual(mx.squeeze(a, [1, 3]).shape, [2, 2])
-        self.assertEqual(a.squeeze().shape, [2, 2])
-        self.assertEqual(a.squeeze(1).shape, [2, 2, 1])
-        self.assertEqual(a.squeeze([1, 3]).shape, [2, 2])
+        self.assertEqual(mx.squeeze(a).shape, (2, 2))
+        self.assertEqual(mx.squeeze(a, 1).shape, (2, 2, 1))
+        self.assertEqual(mx.squeeze(a, [1, 3]).shape, (2, 2))
+        self.assertEqual(a.squeeze().shape, (2, 2))
+        self.assertEqual(a.squeeze(1).shape, (2, 2, 1))
+        self.assertEqual(a.squeeze([1, 3]).shape, (2, 2))
 
         a = mx.zeros((2, 2))
-        self.assertEqual(mx.squeeze(a).shape, [2, 2])
+        self.assertEqual(mx.squeeze(a).shape, (2, 2))
 
-        self.assertEqual(mx.expand_dims(a, 0).shape, [1, 2, 2])
-        self.assertEqual(mx.expand_dims(a, (0, 1)).shape, [1, 1, 2, 2])
-        self.assertEqual(mx.expand_dims(a, [0, -1]).shape, [1, 2, 2, 1])
+        self.assertEqual(mx.expand_dims(a, 0).shape, (1, 2, 2))
+        self.assertEqual(mx.expand_dims(a, (0, 1)).shape, (1, 1, 2, 2))
+        self.assertEqual(mx.expand_dims(a, [0, -1]).shape, (1, 2, 2, 1))
 
     def test_sort(self):
         shape = (3, 4, 5)
@@ -1603,12 +1603,12 @@ class TestOps(mlx_tests.MLXTestCase):
 
     def test_flatten(self):
         x = mx.zeros([2, 3, 4])
-        self.assertEqual(mx.flatten(x).shape, [2 * 3 * 4])
-        self.assertEqual(mx.flatten(x, start_axis=1).shape, [2, 3 * 4])
-        self.assertEqual(mx.flatten(x, end_axis=1).shape, [2 * 3, 4])
-        self.assertEqual(x.flatten().shape, [2 * 3 * 4])
-        self.assertEqual(x.flatten(start_axis=1).shape, [2, 3 * 4])
-        self.assertEqual(x.flatten(end_axis=1).shape, [2 * 3, 4])
+        self.assertEqual(mx.flatten(x).shape, (2 * 3 * 4,))
+        self.assertEqual(mx.flatten(x, start_axis=1).shape, (2, 3 * 4))
+        self.assertEqual(mx.flatten(x, end_axis=1).shape, (2 * 3, 4))
+        self.assertEqual(x.flatten().shape, (2 * 3 * 4,))
+        self.assertEqual(x.flatten(start_axis=1).shape, (2, 3 * 4))
+        self.assertEqual(x.flatten(end_axis=1).shape, (2 * 3, 4))
 
     def test_clip(self):
         a = np.array([1, 4, 3, 8, 5], np.int32)
diff --git a/python/tests/test_random.py b/python/tests/test_random.py
index c4ca7f62a..0a06c3496 100644
--- a/python/tests/test_random.py
+++ b/python/tests/test_random.py
@@ -38,19 +38,19 @@ class TestRandom(mlx_tests.MLXTestCase):
         self.assertTrue(mx.array_equal(k2, r2))
 
         keys = mx.random.split(key, 10)
-        self.assertEqual(keys.shape, [10, 2])
+        self.assertEqual(keys.shape, (10, 2))
 
     def test_uniform(self):
         key = mx.random.key(0)
         a = mx.random.uniform(key=key)
-        self.assertEqual(a.shape, [])
+        self.assertEqual(a.shape, ())
         self.assertEqual(a.dtype, mx.float32)
 
         b = mx.random.uniform(key=key)
         self.assertEqual(a.item(), b.item())
 
         a = mx.random.uniform(shape=(2, 3))
-        self.assertEqual(a.shape, [2, 3])
+        self.assertEqual(a.shape, (2, 3))
 
         a = mx.random.uniform(shape=(1000,), low=-1, high=5)
         self.assertTrue(mx.all((a > -1) < 5).item())
@@ -66,14 +66,14 @@ class TestRandom(mlx_tests.MLXTestCase):
     def test_normal(self):
         key = mx.random.key(0)
         a = mx.random.normal(key=key)
-        self.assertEqual(a.shape, [])
+        self.assertEqual(a.shape, ())
         self.assertEqual(a.dtype, mx.float32)
 
         b = mx.random.normal(key=key)
         self.assertEqual(a.item(), b.item())
 
         a = mx.random.normal(shape=(2, 3))
-        self.assertEqual(a.shape, [2, 3])
+        self.assertEqual(a.shape, (2, 3))
 
         ## Generate in float16 or bfloat16
         for t in [mx.float16, mx.bfloat16]:
@@ -84,10 +84,10 @@ class TestRandom(mlx_tests.MLXTestCase):
 
     def test_randint(self):
         a = mx.random.randint(0, 1, [])
-        self.assertEqual(a.shape, [])
+        self.assertEqual(a.shape, ())
         self.assertEqual(a.dtype, mx.int32)
 
-        shape = [88]
+        shape = (88,)
         low = mx.array(3)
         high = mx.array(15)
 
@@ -100,7 +100,7 @@ class TestRandom(mlx_tests.MLXTestCase):
         b = mx.random.randint(low, high, shape, key=key)
         self.assertListEqual(a.tolist(), b.tolist())
 
-        shape = [3, 4]
+        shape = (3, 4)
         low = mx.reshape(mx.array([0] * 3), [3, 1])
         high = mx.reshape(mx.array([12, 13, 14, 15]), [1, 4])
 
@@ -119,20 +119,20 @@ class TestRandom(mlx_tests.MLXTestCase):
 
     def test_bernoulli(self):
         a = mx.random.bernoulli()
-        self.assertEqual(a.shape, [])
+        self.assertEqual(a.shape, ())
         self.assertEqual(a.dtype, mx.bool_)
 
         a = mx.random.bernoulli(mx.array(0.5), [5])
-        self.assertEqual(a.shape, [5])
+        self.assertEqual(a.shape, (5,))
 
         a = mx.random.bernoulli(mx.array([2.0, -2.0]))
         self.assertEqual(a.tolist(), [True, False])
-        self.assertEqual(a.shape, [2])
+        self.assertEqual(a.shape, (2,))
 
         p = mx.array([0.1, 0.2, 0.3])
         mx.reshape(p, [1, 3])
         x = mx.random.bernoulli(p, [4, 3])
-        self.assertEqual(x.shape, [4, 3])
+        self.assertEqual(x.shape, (4, 3))
 
         with self.assertRaises(ValueError):
             mx.random.bernoulli(p, [2])  # Bad shape
@@ -153,14 +153,14 @@ class TestRandom(mlx_tests.MLXTestCase):
         upper = mx.reshape(mx.array([0.0, 1.0, 2.0]), [3, 1])
 
         a = mx.random.truncated_normal(lower, upper)
-        self.assertEqual(a.shape, [3, 2])
+        self.assertEqual(a.shape, (3, 2))
         self.assertTrue(mx.all(lower <= a).item() and mx.all(a <= upper).item())
 
         a = mx.random.truncated_normal(2.0, -2.0)
         self.assertTrue(mx.all(a == 2.0).item())
 
         a = mx.random.truncated_normal(-3.0, 3.0, [542, 399])
-        self.assertEqual(a.shape, [542, 399])
+        self.assertEqual(a.shape, (542, 399))
 
         lower = mx.array([-2.0, -1.0])
         higher = mx.array([1.0, 2.0, 3.0])
@@ -174,7 +174,7 @@ class TestRandom(mlx_tests.MLXTestCase):
 
     def test_gumbel(self):
         samples = mx.random.gumbel(shape=(100, 100))
-        self.assertEqual(samples.shape, [100, 100])
+        self.assertEqual(samples.shape, (100, 100))
         self.assertEqual(samples.dtype, mx.float32)
         mean = 0.5772
         # Std deviation of the sample mean is small (<0.02),
@@ -187,23 +187,23 @@ class TestRandom(mlx_tests.MLXTestCase):
 
     def test_categorical(self):
         logits = mx.zeros((10, 20))
-        self.assertEqual(mx.random.categorical(logits, -1).shape, [10])
-        self.assertEqual(mx.random.categorical(logits, 0).shape, [20])
-        self.assertEqual(mx.random.categorical(logits, 1).shape, [10])
+        self.assertEqual(mx.random.categorical(logits, -1).shape, (10,))
+        self.assertEqual(mx.random.categorical(logits, 0).shape, (20,))
+        self.assertEqual(mx.random.categorical(logits, 1).shape, (10,))
 
         out = mx.random.categorical(logits)
-        self.assertEqual(out.shape, [10])
+        self.assertEqual(out.shape, (10,))
         self.assertEqual(out.dtype, mx.uint32)
         self.assertTrue(mx.max(out).item() < 20)
 
         out = mx.random.categorical(logits, 0, [5, 20])
-        self.assertEqual(out.shape, [5, 20])
+        self.assertEqual(out.shape, (5, 20))
         self.assertTrue(mx.max(out).item() < 10)
 
         out = mx.random.categorical(logits, 1, num_samples=7)
-        self.assertEqual(out.shape, [10, 7])
+        self.assertEqual(out.shape, (10, 7))
         out = mx.random.categorical(logits, 0, num_samples=7)
-        self.assertEqual(out.shape, [20, 7])
+        self.assertEqual(out.shape, (20, 7))
 
         with self.assertRaises(ValueError):
             mx.random.categorical(logits, shape=[10, 5], num_samples=5)
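
A final consequence visible in the random tests: scalar draws report the empty
tuple `()` as their shape, while shape arguments remain permissive (the tests
above still pass Python lists such as `[5]` and `[542, 399]` to the samplers).
A short sketch:

    import mlx.core as mx

    a = mx.random.uniform()
    assert a.shape == ()                 # scalar arrays have an empty-tuple shape

    b = mx.random.normal(shape=[2, 3])   # a list is still accepted on input...
    assert b.shape == (2, 3)             # ...but the reported shape is a tuple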