Make shape a tuple (#591)

* shape tuple

* also remove simplify from docs

* rebase
This commit is contained in:
Awni Hannun 2024-01-30 13:11:01 -08:00 committed by GitHub
parent d3a9005454
commit 09b9275027
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 141 additions and 140 deletions

View File

@ -929,7 +929,7 @@ We see some modest improvements right away!
This operation is now good to be used to build other operations,
in :class:`mlx.nn.Module` calls, and also as a part of graph
transformations such as :meth:`grad` and :meth:`simplify`!
transformations like :meth:`grad`!
Scripts
-------

View File

@ -14,4 +14,3 @@ Transforms
jvp
vjp
vmap
simplify

View File

@ -20,7 +20,7 @@ Transforming Compute Graphs
Lazy evaluation lets us record a compute graph without actually doing any
computations. This is useful for function transformations like :func:`grad` and
:func:`vmap` and graph optimizations like :func:`simplify`.
:func:`vmap` and graph optimizations.
Currently, MLX does not compile and rerun compute graphs. They are all
generated dynamically. However, lazy evaluation makes it much easier to

View File

@ -81,7 +81,7 @@ class Dropout2d(Module):
# Dropout is applied on the whole channel
# 3D input: (1, 1, C)
# 4D input: (B, 1, 1, C)
mask_shape = x.shape
mask_shape = list(x.shape)
mask_shape[-2] = mask_shape[-3] = 1
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)

View File

@ -70,8 +70,9 @@ def cross_entropy(
targets_as_probs = targets.ndim == logits.ndim
def _drop_dim(shape, axis):
shape = list(shape)
shape.pop(axis)
return shape
return tuple(shape)
# Check shapes in two cases: targets as class indices and targets as probabilities
if (targets_as_probs and targets.shape != logits.shape) or (

View File

@ -675,17 +675,14 @@ void init_array(py::module_& m) {
"nbytes",
&array::nbytes,
R"pbdoc(The number of bytes in the array.)pbdoc")
// TODO, this makes a deep copy of the shape
// implement alternatives to use reference
// https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
.def_property_readonly(
"shape",
[](const array& a) { return a.shape(); },
[](const array& a) { return py::tuple(py::cast(a.shape())); },
R"pbdoc(
The shape of the array as a Python tuple.
Returns:
list(int): A list containing the sizes of each dimension.
tuple(int): A tuple containing the sizes of each dimension.
)pbdoc")
.def_property_readonly(
"dtype",

View File

@ -94,7 +94,7 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqual(x.ndim, 0)
self.assertEqual(x.itemsize, 4)
self.assertEqual(x.nbytes, 4)
self.assertEqual(x.shape, [])
self.assertEqual(x.shape, ())
self.assertEqual(x.dtype, mx.int32)
self.assertEqual(x.item(), 1)
self.assertTrue(isinstance(x.item(), int))
@ -116,7 +116,7 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array(1.0)
self.assertEqual(x.size, 1)
self.assertEqual(x.ndim, 0)
self.assertEqual(x.shape, [])
self.assertEqual(x.shape, ())
self.assertEqual(x.dtype, mx.float32)
self.assertEqual(x.item(), 1.0)
self.assertTrue(isinstance(x.item(), float))
@ -124,14 +124,14 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array(False)
self.assertEqual(x.size, 1)
self.assertEqual(x.ndim, 0)
self.assertEqual(x.shape, [])
self.assertEqual(x.shape, ())
self.assertEqual(x.dtype, mx.bool_)
self.assertEqual(x.item(), False)
self.assertTrue(isinstance(x.item(), bool))
x = mx.array(complex(1, 1))
self.assertEqual(x.ndim, 0)
self.assertEqual(x.shape, [])
self.assertEqual(x.shape, ())
self.assertEqual(x.dtype, mx.complex64)
self.assertEqual(x.item(), complex(1, 1))
self.assertTrue(isinstance(x.item(), complex))
@ -139,7 +139,7 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array([True, False, True])
self.assertEqual(x.dtype, mx.bool_)
self.assertEqual(x.ndim, 1)
self.assertEqual(x.shape, [3])
self.assertEqual(x.shape, (3,))
self.assertEqual(len(x), 3)
x = mx.array([True, False, True], mx.float32)
@ -148,7 +148,7 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array([0, 1, 2])
self.assertEqual(x.dtype, mx.int32)
self.assertEqual(x.ndim, 1)
self.assertEqual(x.shape, [3])
self.assertEqual(x.shape, (3,))
x = mx.array([0, 1, 2], mx.float32)
self.assertEqual(x.dtype, mx.float32)
@ -156,12 +156,12 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array([0.0, 1.0, 2.0])
self.assertEqual(x.dtype, mx.float32)
self.assertEqual(x.ndim, 1)
self.assertEqual(x.shape, [3])
self.assertEqual(x.shape, (3,))
x = mx.array([1j, 1 + 0j])
self.assertEqual(x.dtype, mx.complex64)
self.assertEqual(x.ndim, 1)
self.assertEqual(x.shape, [2])
self.assertEqual(x.shape, (2,))
# From tuple
x = mx.array((1, 2, 3), mx.int32)
@ -181,17 +181,17 @@ class TestArray(mlx_tests.MLXTestCase):
def test_construction_from_lists(self):
x = mx.array([])
self.assertEqual(x.size, 0)
self.assertEqual(x.shape, [0])
self.assertEqual(x.shape, (0,))
self.assertEqual(x.dtype, mx.float32)
x = mx.array([[], [], []])
self.assertEqual(x.size, 0)
self.assertEqual(x.shape, [3, 0])
self.assertEqual(x.shape, (3, 0))
self.assertEqual(x.dtype, mx.float32)
x = mx.array([[[], []], [[], []], [[], []]])
self.assertEqual(x.size, 0)
self.assertEqual(x.shape, [3, 2, 0])
self.assertEqual(x.shape, (3, 2, 0))
self.assertEqual(x.dtype, mx.float32)
# Check failure cases
@ -436,19 +436,19 @@ class TestArray(mlx_tests.MLXTestCase):
a = np.array([])
x = mx.array(a)
self.assertEqual(x.size, 0)
self.assertEqual(x.shape, [0])
self.assertEqual(x.shape, (0,))
self.assertEqual(x.dtype, mx.float32)
a = np.array([[], [], []])
x = mx.array(a)
self.assertEqual(x.size, 0)
self.assertEqual(x.shape, [3, 0])
self.assertEqual(x.shape, (3, 0))
self.assertEqual(x.dtype, mx.float32)
a = np.array([[[], []], [[], []], [[], []]])
x = mx.array(a)
self.assertEqual(x.size, 0)
self.assertEqual(x.shape, [3, 2, 0])
self.assertEqual(x.shape, (3, 2, 0))
self.assertEqual(x.dtype, mx.float32)
# Content test
@ -456,7 +456,7 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array(a)
self.assertEqual(x.dtype, mx.float32)
self.assertEqual(x.ndim, 3)
self.assertEqual(x.shape, [3, 5, 4])
self.assertEqual(x.shape, (3, 5, 4))
y = np.asarray(x)
self.assertTrue(np.allclose(a, y))
@ -465,7 +465,7 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array(a)
self.assertEqual(x.dtype, mx.int32)
self.assertEqual(x.ndim, 0)
self.assertEqual(x.shape, [])
self.assertEqual(x.shape, ())
self.assertEqual(x.item(), 3)
# mlx to numpy test
@ -483,7 +483,7 @@ class TestArray(mlx_tests.MLXTestCase):
x = np.array(cvals)
y = mx.array(x)
self.assertEqual(y.dtype, mx.complex64)
self.assertEqual(y.shape, [3])
self.assertEqual(y.shape, (3,))
self.assertEqual(y.tolist(), cvals)
y = mx.array([0j, 1, 1 + 1j])

View File

@ -579,5 +579,9 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
for r, t in zip(dout_ref, dout_test):
self.assertListEqual(r.shape, t.shape)
self.assertEqual(r.shape, t.shape)
self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
if __name__ == "__main__":
unittest.main()

View File

@ -44,7 +44,7 @@ class TestConv(mlx_tests.MLXTestCase):
c_np = np.convolve(a_np, v_np, mode=mode)
c_mx = mx.convolve(a_mx, v_mx, mode=mode)
self.assertListEqual(list(c_mx.shape), list(c_np.shape))
self.assertEqual(c_mx.shape, c_np.shape)
self.assertTrue(np.allclose(c_mx, c_np, atol=atol))
@unittest.skipIf(not has_torch, "requires Torch")
@ -102,7 +102,7 @@ class TestConv(mlx_tests.MLXTestCase):
)
out_pt = torch.transpose(out_pt, 2, 1)
self.assertListEqual(list(out_pt.shape), out_mx.shape)
self.assertEqual(out_pt.shape, out_mx.shape)
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))
for dtype in ("float32",):
@ -141,7 +141,7 @@ class TestConv(mlx_tests.MLXTestCase):
out_pt = torch.conv1d(in_pt, wt_pt)
out_pt = torch.transpose(out_pt, 2, 1)
self.assertListEqual(list(out_pt.shape), out_mx.shape)
self.assertEqual(out_pt.shape, out_mx.shape)
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))
@unittest.skipIf(not has_torch, "requires Torch")
@ -228,12 +228,12 @@ class TestConv(mlx_tests.MLXTestCase):
mx_grad_in, mx_grad_wt = outs_mx
self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
self.assertEqual(in_mx.shape, mx_grad_in.shape)
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
for dtype in ("float32",):
@ -309,7 +309,7 @@ class TestConv(mlx_tests.MLXTestCase):
)
out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
self.assertListEqual(list(out_pt.shape), list(out_mx.shape))
self.assertEqual(out_pt.shape, out_mx.shape)
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
for dtype in ("float32",):
@ -419,12 +419,12 @@ class TestConv(mlx_tests.MLXTestCase):
mx_grad_in, mx_grad_wt = outs_mx
self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape)
self.assertListEqual(list(in_mx.shape), mx_grad_in.shape)
self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
self.assertEqual(in_mx.shape, mx_grad_in.shape)
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape)
self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape)
self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
for dtype in ("float32",):

View File

@ -13,7 +13,7 @@ class TestInit(mlx_tests.MLXTestCase):
for dtype in [mx.float32, mx.float16]:
initializer = init.constant(value, dtype)
for shape in [[3], [3, 3], [3, 3, 3]]:
for shape in [(3,), (3, 3), (3, 3, 3)]:
result = initializer(mx.array(mx.zeros(shape)))
with self.subTest(shape=shape):
self.assertEqual(result.shape, shape)
@ -24,7 +24,7 @@ class TestInit(mlx_tests.MLXTestCase):
std = 1.0
for dtype in [mx.float32, mx.float16]:
initializer = init.normal(mean, std, dtype=dtype)
for shape in [[3], [3, 3], [3, 3, 3]]:
for shape in [(3,), (3, 3), (3, 3, 3)]:
result = initializer(mx.array(np.empty(shape)))
with self.subTest(shape=shape):
self.assertEqual(result.shape, shape)
@ -36,7 +36,7 @@ class TestInit(mlx_tests.MLXTestCase):
for dtype in [mx.float32, mx.float16]:
initializer = init.uniform(low, high, dtype)
for shape in [[3], [3, 3], [3, 3, 3]]:
for shape in [(3,), (3, 3), (3, 3, 3)]:
result = initializer(mx.array(np.empty(shape)))
with self.subTest(shape=shape):
self.assertEqual(result.shape, shape)
@ -46,7 +46,7 @@ class TestInit(mlx_tests.MLXTestCase):
def test_identity(self):
for dtype in [mx.float32, mx.float16]:
initializer = init.identity(dtype)
for shape in [[3], [3, 3], [3, 3, 3]]:
for shape in [(3,), (3, 3), (3, 3, 3)]:
result = initializer(mx.zeros((3, 3)))
self.assertTrue(mx.array_equal(result, mx.eye(3)))
self.assertEqual(result.dtype, dtype)
@ -56,7 +56,7 @@ class TestInit(mlx_tests.MLXTestCase):
def test_glorot_normal(self):
for dtype in [mx.float32, mx.float16]:
initializer = init.glorot_normal(dtype)
for shape in [[3, 3], [3, 3, 3]]:
for shape in [(3, 3), (3, 3, 3)]:
result = initializer(mx.array(np.empty(shape)))
with self.subTest(shape=shape):
self.assertEqual(result.shape, shape)
@ -65,7 +65,7 @@ class TestInit(mlx_tests.MLXTestCase):
def test_glorot_uniform(self):
for dtype in [mx.float32, mx.float16]:
initializer = init.glorot_uniform(dtype)
for shape in [[3, 3], [3, 3, 3]]:
for shape in [(3, 3), (3, 3, 3)]:
result = initializer(mx.array(np.empty(shape)))
with self.subTest(shape=shape):
self.assertEqual(result.shape, shape)
@ -74,7 +74,7 @@ class TestInit(mlx_tests.MLXTestCase):
def test_he_normal(self):
for dtype in [mx.float32, mx.float16]:
initializer = init.he_normal(dtype)
for shape in [[3, 3], [3, 3, 3]]:
for shape in [(3, 3), (3, 3, 3)]:
result = initializer(mx.array(np.empty(shape)))
with self.subTest(shape=shape):
self.assertEqual(result.shape, shape)
@ -83,7 +83,7 @@ class TestInit(mlx_tests.MLXTestCase):
def test_he_uniform(self):
for dtype in [mx.float32, mx.float16]:
initializer = init.he_uniform(dtype)
for shape in [[3, 3], [3, 3, 3]]:
for shape in [(3, 3), (3, 3, 3)]:
result = initializer(mx.array(np.empty(shape)))
with self.subTest(shape=shape):
self.assertEqual(result.shape, shape)

View File

@ -136,20 +136,20 @@ class TestLayers(mlx_tests.MLXTestCase):
inputs = mx.zeros((10, 4))
layer = nn.Identity()
outputs = layer(inputs)
self.assertEqual(tuple(inputs.shape), tuple(outputs.shape))
self.assertEqual(inputs.shape, outputs.shape)
def test_linear(self):
inputs = mx.zeros((10, 4))
layer = nn.Linear(input_dims=4, output_dims=8)
outputs = layer(inputs)
self.assertEqual(tuple(outputs.shape), (10, 8))
self.assertEqual(outputs.shape, (10, 8))
def test_bilinear(self):
inputs1 = mx.zeros((10, 2))
inputs2 = mx.zeros((10, 4))
layer = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6)
outputs = layer(inputs1, inputs2)
self.assertEqual(tuple(outputs.shape), (10, 6))
self.assertEqual(outputs.shape, (10, 6))
def test_group_norm(self):
x = mx.arange(100, dtype=mx.float32)
@ -573,12 +573,12 @@ class TestLayers(mlx_tests.MLXTestCase):
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks)
c.weight = mx.ones_like(c.weight)
y = c(x)
self.assertEqual(y.shape, [N, L - ks + 1, C_out])
self.assertEqual(y.shape, (N, L - ks + 1, C_out))
self.assertTrue(mx.allclose(y, mx.full(y.shape, ks * C_in, mx.float32)))
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, stride=2)
y = c(x)
self.assertEqual(y.shape, [N, (L - ks + 1) // 2, C_out])
self.assertEqual(y.shape, (N, (L - ks + 1) // 2, C_out))
self.assertTrue("bias" in c.parameters())
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, bias=False)
@ -588,7 +588,7 @@ class TestLayers(mlx_tests.MLXTestCase):
x = mx.ones((4, 8, 8, 3))
c = nn.Conv2d(3, 1, 8)
y = c(x)
self.assertEqual(y.shape, [4, 1, 1, 1])
self.assertEqual(y.shape, (4, 1, 1, 1))
c.weight = mx.ones_like(c.weight) / 8 / 8 / 3
y = c(x)
self.assertTrue(np.allclose(y[:, 0, 0, 0], x.mean(axis=(1, 2, 3))))
@ -596,13 +596,13 @@ class TestLayers(mlx_tests.MLXTestCase):
# 3x3 conv no padding stride 1
c = nn.Conv2d(3, 8, 3)
y = c(x)
self.assertEqual(y.shape, [4, 6, 6, 8])
self.assertEqual(y.shape, (4, 6, 6, 8))
self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)
# 3x3 conv padding 1 stride 1
c = nn.Conv2d(3, 8, 3, padding=1)
y = c(x)
self.assertEqual(y.shape, [4, 8, 8, 8])
self.assertEqual(y.shape, (4, 8, 8, 8))
self.assertLess(mx.abs(y[:, 1:7, 1:7] - c.weight.sum((1, 2, 3))).max(), 1e-4)
self.assertLess(
mx.abs(y[:, 0, 0] - c.weight[:, 1:, 1:].sum(axis=(1, 2, 3))).max(),
@ -624,14 +624,14 @@ class TestLayers(mlx_tests.MLXTestCase):
# 3x3 conv no padding stride 2
c = nn.Conv2d(3, 8, 3, padding=0, stride=2)
y = c(x)
self.assertEqual(y.shape, [4, 3, 3, 8])
self.assertEqual(y.shape, (4, 3, 3, 8))
self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)
def test_sequential(self):
x = mx.ones((10, 2))
m = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 1))
y = m(x)
self.assertEqual(y.shape, [10, 1])
self.assertEqual(y.shape, (10, 1))
params = m.parameters()
self.assertTrue("layers" in params)
self.assertEqual(len(params["layers"]), 3)
@ -667,7 +667,7 @@ class TestLayers(mlx_tests.MLXTestCase):
x = mx.arange(10)
y = m(x)
self.assertEqual(y.shape, [10, 16])
self.assertEqual(y.shape, (10, 16))
similarities = y @ y.T
self.assertLess(
mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5
@ -686,19 +686,19 @@ class TestLayers(mlx_tests.MLXTestCase):
x = mx.array([1.0, -1.0, 0.0])
y = nn.relu(x)
self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0])))
self.assertEqual(y.shape, [3])
self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32)
def test_leaky_relu(self):
x = mx.array([1.0, -1.0, 0.0])
y = nn.leaky_relu(x)
self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.01, 0.0])))
self.assertEqual(y.shape, [3])
self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32)
y = nn.LeakyReLU(negative_slope=0.1)(x)
self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.1, 0.0])))
self.assertEqual(y.shape, [3])
self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32)
def test_elu(self):
@ -707,21 +707,21 @@ class TestLayers(mlx_tests.MLXTestCase):
epsilon = 1e-4
expected_y = mx.array([1.0, -0.6321, 0.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3])
self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32)
y = nn.ELU(alpha=1.1)(x)
epsilon = 1e-4
expected_y = mx.array([1.0, -0.6953, 0.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3])
self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32)
def test_relu6(self):
x = mx.array([1.0, -1.0, 0.0, 7.0, -7.0])
y = nn.relu6(x)
self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0, 6.0, 0.0])))
self.assertEqual(y.shape, [5])
self.assertEqual(y.shape, (5,))
self.assertEqual(y.dtype, mx.float32)
def test_softmax(self):
@ -730,7 +730,7 @@ class TestLayers(mlx_tests.MLXTestCase):
epsilon = 1e-4
expected_y = mx.array([0.6652, 0.0900, 0.2447])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3])
self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32)
def test_softplus(self):
@ -739,7 +739,7 @@ class TestLayers(mlx_tests.MLXTestCase):
epsilon = 1e-4
expected_y = mx.array([1.3133, 0.3133, 0.6931])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3])
self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32)
def test_softsign(self):
@ -748,7 +748,7 @@ class TestLayers(mlx_tests.MLXTestCase):
epsilon = 1e-4
expected_y = mx.array([0.5, -0.5, 0.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3])
self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32)
def test_softshrink(self):
@ -757,13 +757,13 @@ class TestLayers(mlx_tests.MLXTestCase):
epsilon = 1e-4
expected_y = mx.array([0.5, -0.5, 0.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3])
self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32)
y = nn.Softshrink(lambd=0.7)(x)
expected_y = mx.array([0.3, -0.3, 0.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3])
self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32)
def test_celu(self):
@ -772,13 +772,13 @@ class TestLayers(mlx_tests.MLXTestCase):
epsilon = 1e-4
expected_y = mx.array([1.0, -0.6321, 0.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3])
self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32)
y = nn.CELU(alpha=1.1)(x)
expected_y = mx.array([1.0, -0.6568, 0.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3])
self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32)
def test_log_softmax(self):
@ -787,7 +787,7 @@ class TestLayers(mlx_tests.MLXTestCase):
epsilon = 1e-4
expected_y = mx.array([-2.4076, -1.4076, -0.4076])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3])
self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32)
def test_log_sigmoid(self):
@ -796,7 +796,7 @@ class TestLayers(mlx_tests.MLXTestCase):
epsilon = 1e-4
expected_y = mx.array([-0.3133, -1.3133, -0.6931])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3])
self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32)
def test_prelu(self):
@ -817,7 +817,7 @@ class TestLayers(mlx_tests.MLXTestCase):
epsilon = 1e-4
expected_y = mx.array([0.0, -0.375, 0.0, 1.125, 3.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [5])
self.assertEqual(y.shape, (5,))
self.assertEqual(y.dtype, mx.float32)
def test_glu(self):

View File

@ -12,12 +12,12 @@ import numpy as np
class TestOps(mlx_tests.MLXTestCase):
def test_full_ones_zeros(self):
x = mx.full(2, 3.0)
self.assertEqual(x.shape, [2])
self.assertEqual(x.shape, (2,))
self.assertEqual(x.tolist(), [3.0, 3.0])
x = mx.full((2, 3), 2.0)
self.assertEqual(x.dtype, mx.float32)
self.assertEqual(x.shape, [2, 3])
self.assertEqual(x.shape, (2, 3))
self.assertEqual(x.tolist(), [[2, 2, 2], [2, 2, 2]])
x = mx.full([3, 2], mx.array([False, True]))
@ -28,11 +28,11 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(x.tolist(), [[2, 3], [2, 3], [2, 3]])
x = mx.zeros(2)
self.assertEqual(x.shape, [2])
self.assertEqual(x.shape, (2,))
self.assertEqual(x.tolist(), [0.0, 0.0])
x = mx.ones(2)
self.assertEqual(x.shape, [2])
self.assertEqual(x.shape, (2,))
self.assertEqual(x.tolist(), [1.0, 1.0])
for t in [mx.bool_, mx.int32, mx.float32]:
@ -530,10 +530,10 @@ class TestOps(mlx_tests.MLXTestCase):
def test_move_swap_axes(self):
x = mx.zeros((2, 3, 4))
self.assertEqual(mx.moveaxis(x, 0, 2).shape, [3, 4, 2])
self.assertEqual(x.moveaxis(0, 2).shape, [3, 4, 2])
self.assertEqual(mx.swapaxes(x, 0, 2).shape, [4, 3, 2])
self.assertEqual(x.swapaxes(0, 2).shape, [4, 3, 2])
self.assertEqual(mx.moveaxis(x, 0, 2).shape, (3, 4, 2))
self.assertEqual(x.moveaxis(0, 2).shape, (3, 4, 2))
self.assertEqual(mx.swapaxes(x, 0, 2).shape, (4, 3, 2))
self.assertEqual(x.swapaxes(0, 2).shape, (4, 3, 2))
def test_sum(self):
x = mx.array(
@ -545,7 +545,7 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.sum(x).item(), 9)
y = mx.sum(x, keepdims=True)
self.assertEqual(y, mx.array(9))
self.assertEqual(y.shape, [1, 1])
self.assertEqual(y.shape, (1, 1))
self.assertEqual(mx.sum(x, axis=0).tolist(), [4, 5])
self.assertEqual(mx.sum(x, axis=1).tolist(), [3, 6])
@ -585,7 +585,7 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.prod(x).item(), 18)
y = mx.prod(x, keepdims=True)
self.assertEqual(y, mx.array(18))
self.assertEqual(y.shape, [1, 1])
self.assertEqual(y.shape, (1, 1))
self.assertEqual(mx.prod(x, axis=0).tolist(), [3, 6])
self.assertEqual(mx.prod(x, axis=1).tolist(), [2, 9])
@ -600,11 +600,11 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.min(x).item(), 1)
self.assertEqual(mx.max(x).item(), 4)
y = mx.min(x, keepdims=True)
self.assertEqual(y.shape, [1, 1])
self.assertEqual(y.shape, (1, 1))
self.assertEqual(y, mx.array(1))
y = mx.max(x, keepdims=True)
self.assertEqual(y.shape, [1, 1])
self.assertEqual(y.shape, (1, 1))
self.assertEqual(y, mx.array(4))
self.assertEqual(mx.min(x, axis=0).tolist(), [1, 2])
@ -670,7 +670,7 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.mean(x).item(), 2.5)
y = mx.mean(x, keepdims=True)
self.assertEqual(y, mx.array(2.5))
self.assertEqual(y.shape, [1, 1])
self.assertEqual(y.shape, (1, 1))
self.assertEqual(mx.mean(x, axis=0).tolist(), [2, 3])
self.assertEqual(mx.mean(x, axis=1).tolist(), [1.5, 3.5])
@ -685,7 +685,7 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.var(x).item(), 1.25)
y = mx.var(x, keepdims=True)
self.assertEqual(y, mx.array(1.25))
self.assertEqual(y.shape, [1, 1])
self.assertEqual(y.shape, (1, 1))
self.assertEqual(mx.var(x, axis=0).tolist(), [1.0, 1.0])
self.assertEqual(mx.var(x, axis=1).tolist(), [0.25, 0.25])
@ -888,7 +888,7 @@ class TestOps(mlx_tests.MLXTestCase):
a = mx.array([[True, False], [True, True]])
self.assertFalse(mx.all(a).item())
self.assertEqual(mx.all(a, keepdims=True).shape, [1, 1])
self.assertEqual(mx.all(a, keepdims=True).shape, (1, 1))
self.assertFalse(mx.all(a, axis=[0, 1]).item())
self.assertEqual(mx.all(a, axis=[0]).tolist(), [True, False])
self.assertEqual(mx.all(a, axis=[1]).tolist(), [False, True])
@ -899,7 +899,7 @@ class TestOps(mlx_tests.MLXTestCase):
a = mx.array([[True, False], [False, False]])
self.assertTrue(mx.any(a).item())
self.assertEqual(mx.any(a, keepdims=True).shape, [1, 1])
self.assertEqual(mx.any(a, keepdims=True).shape, (1, 1))
self.assertTrue(mx.any(a, axis=[0, 1]).item())
self.assertEqual(mx.any(a, axis=[0]).tolist(), [True, False])
self.assertEqual(mx.any(a, axis=[1]).tolist(), [True, False])
@ -956,22 +956,22 @@ class TestOps(mlx_tests.MLXTestCase):
a_npy_taken = np.take(a_npy, idx_npy)
a_mlx_taken = mx.take(a_mlx, idx_mlx)
self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
a_npy_taken = np.take(a_npy, idx_npy, axis=0)
a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=0)
self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
a_npy_taken = np.take(a_npy, idx_npy, axis=1)
a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=1)
self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
a_npy_taken = np.take(a_npy, idx_npy, axis=2)
a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=2)
self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
def test_take_along_axis(self):
@ -1400,13 +1400,13 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6))
a = mx.zeros((1, 1, 1))
self.assertEqual(mx.pad(a, 1).shape, [3, 3, 3])
self.assertEqual(mx.pad(a, (1,)).shape, [3, 3, 3])
self.assertEqual(mx.pad(a, [1]).shape, [3, 3, 3])
self.assertEqual(mx.pad(a, (1, 2)).shape, [4, 4, 4])
self.assertEqual(mx.pad(a, [(1, 2)]).shape, [4, 4, 4])
self.assertEqual(mx.pad(a, ((1, 2),)).shape, [4, 4, 4])
self.assertEqual(mx.pad(a, ((1, 2), (2, 1), (2, 2))).shape, [4, 4, 5])
self.assertEqual(mx.pad(a, 1).shape, (3, 3, 3))
self.assertEqual(mx.pad(a, (1,)).shape, (3, 3, 3))
self.assertEqual(mx.pad(a, [1]).shape, (3, 3, 3))
self.assertEqual(mx.pad(a, (1, 2)).shape, (4, 4, 4))
self.assertEqual(mx.pad(a, [(1, 2)]).shape, (4, 4, 4))
self.assertEqual(mx.pad(a, ((1, 2),)).shape, (4, 4, 4))
self.assertEqual(mx.pad(a, ((1, 2), (2, 1), (2, 2))).shape, (4, 4, 5))
# Test grads
a_fwd = mx.array(np.random.rand(16, 16).astype(np.float32))
@ -1490,19 +1490,19 @@ class TestOps(mlx_tests.MLXTestCase):
def test_squeeze_expand(self):
a = mx.zeros((2, 1, 2, 1))
self.assertEqual(mx.squeeze(a).shape, [2, 2])
self.assertEqual(mx.squeeze(a, 1).shape, [2, 2, 1])
self.assertEqual(mx.squeeze(a, [1, 3]).shape, [2, 2])
self.assertEqual(a.squeeze().shape, [2, 2])
self.assertEqual(a.squeeze(1).shape, [2, 2, 1])
self.assertEqual(a.squeeze([1, 3]).shape, [2, 2])
self.assertEqual(mx.squeeze(a).shape, (2, 2))
self.assertEqual(mx.squeeze(a, 1).shape, (2, 2, 1))
self.assertEqual(mx.squeeze(a, [1, 3]).shape, (2, 2))
self.assertEqual(a.squeeze().shape, (2, 2))
self.assertEqual(a.squeeze(1).shape, (2, 2, 1))
self.assertEqual(a.squeeze([1, 3]).shape, (2, 2))
a = mx.zeros((2, 2))
self.assertEqual(mx.squeeze(a).shape, [2, 2])
self.assertEqual(mx.squeeze(a).shape, (2, 2))
self.assertEqual(mx.expand_dims(a, 0).shape, [1, 2, 2])
self.assertEqual(mx.expand_dims(a, (0, 1)).shape, [1, 1, 2, 2])
self.assertEqual(mx.expand_dims(a, [0, -1]).shape, [1, 2, 2, 1])
self.assertEqual(mx.expand_dims(a, 0).shape, (1, 2, 2))
self.assertEqual(mx.expand_dims(a, (0, 1)).shape, (1, 1, 2, 2))
self.assertEqual(mx.expand_dims(a, [0, -1]).shape, (1, 2, 2, 1))
def test_sort(self):
shape = (3, 4, 5)
@ -1603,12 +1603,12 @@ class TestOps(mlx_tests.MLXTestCase):
def test_flatten(self):
x = mx.zeros([2, 3, 4])
self.assertEqual(mx.flatten(x).shape, [2 * 3 * 4])
self.assertEqual(mx.flatten(x, start_axis=1).shape, [2, 3 * 4])
self.assertEqual(mx.flatten(x, end_axis=1).shape, [2 * 3, 4])
self.assertEqual(x.flatten().shape, [2 * 3 * 4])
self.assertEqual(x.flatten(start_axis=1).shape, [2, 3 * 4])
self.assertEqual(x.flatten(end_axis=1).shape, [2 * 3, 4])
self.assertEqual(mx.flatten(x).shape, (2 * 3 * 4,))
self.assertEqual(mx.flatten(x, start_axis=1).shape, (2, 3 * 4))
self.assertEqual(mx.flatten(x, end_axis=1).shape, (2 * 3, 4))
self.assertEqual(x.flatten().shape, (2 * 3 * 4,))
self.assertEqual(x.flatten(start_axis=1).shape, (2, 3 * 4))
self.assertEqual(x.flatten(end_axis=1).shape, (2 * 3, 4))
def test_clip(self):
a = np.array([1, 4, 3, 8, 5], np.int32)

View File

@ -38,19 +38,19 @@ class TestRandom(mlx_tests.MLXTestCase):
self.assertTrue(mx.array_equal(k2, r2))
keys = mx.random.split(key, 10)
self.assertEqual(keys.shape, [10, 2])
self.assertEqual(keys.shape, (10, 2))
def test_uniform(self):
key = mx.random.key(0)
a = mx.random.uniform(key=key)
self.assertEqual(a.shape, [])
self.assertEqual(a.shape, ())
self.assertEqual(a.dtype, mx.float32)
b = mx.random.uniform(key=key)
self.assertEqual(a.item(), b.item())
a = mx.random.uniform(shape=(2, 3))
self.assertEqual(a.shape, [2, 3])
self.assertEqual(a.shape, (2, 3))
a = mx.random.uniform(shape=(1000,), low=-1, high=5)
self.assertTrue(mx.all((a > -1) < 5).item())
@ -66,14 +66,14 @@ class TestRandom(mlx_tests.MLXTestCase):
def test_normal(self):
key = mx.random.key(0)
a = mx.random.normal(key=key)
self.assertEqual(a.shape, [])
self.assertEqual(a.shape, ())
self.assertEqual(a.dtype, mx.float32)
b = mx.random.normal(key=key)
self.assertEqual(a.item(), b.item())
a = mx.random.normal(shape=(2, 3))
self.assertEqual(a.shape, [2, 3])
self.assertEqual(a.shape, (2, 3))
## Generate in float16 or bfloat16
for t in [mx.float16, mx.bfloat16]:
@ -84,10 +84,10 @@ class TestRandom(mlx_tests.MLXTestCase):
def test_randint(self):
a = mx.random.randint(0, 1, [])
self.assertEqual(a.shape, [])
self.assertEqual(a.shape, ())
self.assertEqual(a.dtype, mx.int32)
shape = [88]
shape = (88,)
low = mx.array(3)
high = mx.array(15)
@ -100,7 +100,7 @@ class TestRandom(mlx_tests.MLXTestCase):
b = mx.random.randint(low, high, shape, key=key)
self.assertListEqual(a.tolist(), b.tolist())
shape = [3, 4]
shape = (3, 4)
low = mx.reshape(mx.array([0] * 3), [3, 1])
high = mx.reshape(mx.array([12, 13, 14, 15]), [1, 4])
@ -119,20 +119,20 @@ class TestRandom(mlx_tests.MLXTestCase):
def test_bernoulli(self):
a = mx.random.bernoulli()
self.assertEqual(a.shape, [])
self.assertEqual(a.shape, ())
self.assertEqual(a.dtype, mx.bool_)
a = mx.random.bernoulli(mx.array(0.5), [5])
self.assertEqual(a.shape, [5])
self.assertEqual(a.shape, (5,))
a = mx.random.bernoulli(mx.array([2.0, -2.0]))
self.assertEqual(a.tolist(), [True, False])
self.assertEqual(a.shape, [2])
self.assertEqual(a.shape, (2,))
p = mx.array([0.1, 0.2, 0.3])
mx.reshape(p, [1, 3])
x = mx.random.bernoulli(p, [4, 3])
self.assertEqual(x.shape, [4, 3])
self.assertEqual(x.shape, (4, 3))
with self.assertRaises(ValueError):
mx.random.bernoulli(p, [2]) # Bad shape
@ -153,14 +153,14 @@ class TestRandom(mlx_tests.MLXTestCase):
upper = mx.reshape(mx.array([0.0, 1.0, 2.0]), [3, 1])
a = mx.random.truncated_normal(lower, upper)
self.assertEqual(a.shape, [3, 2])
self.assertEqual(a.shape, (3, 2))
self.assertTrue(mx.all(lower <= a).item() and mx.all(a <= upper).item())
a = mx.random.truncated_normal(2.0, -2.0)
self.assertTrue(mx.all(a == 2.0).item())
a = mx.random.truncated_normal(-3.0, 3.0, [542, 399])
self.assertEqual(a.shape, [542, 399])
self.assertEqual(a.shape, (542, 399))
lower = mx.array([-2.0, -1.0])
higher = mx.array([1.0, 2.0, 3.0])
@ -174,7 +174,7 @@ class TestRandom(mlx_tests.MLXTestCase):
def test_gumbel(self):
samples = mx.random.gumbel(shape=(100, 100))
self.assertEqual(samples.shape, [100, 100])
self.assertEqual(samples.shape, (100, 100))
self.assertEqual(samples.dtype, mx.float32)
mean = 0.5772
# Std deviation of the sample mean is small (<0.02),
@ -187,23 +187,23 @@ class TestRandom(mlx_tests.MLXTestCase):
def test_categorical(self):
logits = mx.zeros((10, 20))
self.assertEqual(mx.random.categorical(logits, -1).shape, [10])
self.assertEqual(mx.random.categorical(logits, 0).shape, [20])
self.assertEqual(mx.random.categorical(logits, 1).shape, [10])
self.assertEqual(mx.random.categorical(logits, -1).shape, (10,))
self.assertEqual(mx.random.categorical(logits, 0).shape, (20,))
self.assertEqual(mx.random.categorical(logits, 1).shape, (10,))
out = mx.random.categorical(logits)
self.assertEqual(out.shape, [10])
self.assertEqual(out.shape, (10,))
self.assertEqual(out.dtype, mx.uint32)
self.assertTrue(mx.max(out).item() < 20)
out = mx.random.categorical(logits, 0, [5, 20])
self.assertEqual(out.shape, [5, 20])
self.assertEqual(out.shape, (5, 20))
self.assertTrue(mx.max(out).item() < 10)
out = mx.random.categorical(logits, 1, num_samples=7)
self.assertEqual(out.shape, [10, 7])
self.assertEqual(out.shape, (10, 7))
out = mx.random.categorical(logits, 0, num_samples=7)
self.assertEqual(out.shape, [20, 7])
self.assertEqual(out.shape, (20, 7))
with self.assertRaises(ValueError):
mx.random.categorical(logits, shape=[10, 5], num_samples=5)