Make shape a tuple (#591)

* shape tuple

* also remove simplify from docs

* rebase
This commit is contained in:
Awni Hannun 2024-01-30 13:11:01 -08:00 committed by GitHub
parent d3a9005454
commit 09b9275027
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 141 additions and 140 deletions

View File

@ -929,7 +929,7 @@ We see some modest improvements right away!
This operation is now good to be used to build other operations, This operation is now good to be used to build other operations,
in :class:`mlx.nn.Module` calls, and also as a part of graph in :class:`mlx.nn.Module` calls, and also as a part of graph
transformations such as :meth:`grad` and :meth:`simplify`! transformations like :meth:`grad`!
Scripts Scripts
------- -------

View File

@ -14,4 +14,3 @@ Transforms
jvp jvp
vjp vjp
vmap vmap
simplify

View File

@ -20,7 +20,7 @@ Transforming Compute Graphs
Lazy evaluation lets us record a compute graph without actually doing any Lazy evaluation lets us record a compute graph without actually doing any
computations. This is useful for function transformations like :func:`grad` and computations. This is useful for function transformations like :func:`grad` and
:func:`vmap` and graph optimizations like :func:`simplify`. :func:`vmap` and graph optimizations.
Currently, MLX does not compile and rerun compute graphs. They are all Currently, MLX does not compile and rerun compute graphs. They are all
generated dynamically. However, lazy evaluation makes it much easier to generated dynamically. However, lazy evaluation makes it much easier to

View File

@ -81,7 +81,7 @@ class Dropout2d(Module):
# Dropout is applied on the whole channel # Dropout is applied on the whole channel
# 3D input: (1, 1, C) # 3D input: (1, 1, C)
# 4D input: (B, 1, 1, C) # 4D input: (B, 1, 1, C)
mask_shape = x.shape mask_shape = list(x.shape)
mask_shape[-2] = mask_shape[-3] = 1 mask_shape[-2] = mask_shape[-3] = 1
mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape) mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)

View File

@ -70,8 +70,9 @@ def cross_entropy(
targets_as_probs = targets.ndim == logits.ndim targets_as_probs = targets.ndim == logits.ndim
def _drop_dim(shape, axis): def _drop_dim(shape, axis):
shape = list(shape)
shape.pop(axis) shape.pop(axis)
return shape return tuple(shape)
# Check shapes in two cases: targets as class indices and targets as probabilities # Check shapes in two cases: targets as class indices and targets as probabilities
if (targets_as_probs and targets.shape != logits.shape) or ( if (targets_as_probs and targets.shape != logits.shape) or (

View File

@ -675,17 +675,14 @@ void init_array(py::module_& m) {
"nbytes", "nbytes",
&array::nbytes, &array::nbytes,
R"pbdoc(The number of bytes in the array.)pbdoc") R"pbdoc(The number of bytes in the array.)pbdoc")
// TODO, this makes a deep copy of the shape
// implement alternatives to use reference
// https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
.def_property_readonly( .def_property_readonly(
"shape", "shape",
[](const array& a) { return a.shape(); }, [](const array& a) { return py::tuple(py::cast(a.shape())); },
R"pbdoc( R"pbdoc(
The shape of the array as a Python list. The shape of the array as a Python tuple.
Returns: Returns:
list(int): A list containing the sizes of each dimension. tuple(int): A tuple containing the sizes of each dimension.
)pbdoc") )pbdoc")
.def_property_readonly( .def_property_readonly(
"dtype", "dtype",

View File

@ -94,7 +94,7 @@ class TestArray(mlx_tests.MLXTestCase):
self.assertEqual(x.ndim, 0) self.assertEqual(x.ndim, 0)
self.assertEqual(x.itemsize, 4) self.assertEqual(x.itemsize, 4)
self.assertEqual(x.nbytes, 4) self.assertEqual(x.nbytes, 4)
self.assertEqual(x.shape, []) self.assertEqual(x.shape, ())
self.assertEqual(x.dtype, mx.int32) self.assertEqual(x.dtype, mx.int32)
self.assertEqual(x.item(), 1) self.assertEqual(x.item(), 1)
self.assertTrue(isinstance(x.item(), int)) self.assertTrue(isinstance(x.item(), int))
@ -116,7 +116,7 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array(1.0) x = mx.array(1.0)
self.assertEqual(x.size, 1) self.assertEqual(x.size, 1)
self.assertEqual(x.ndim, 0) self.assertEqual(x.ndim, 0)
self.assertEqual(x.shape, []) self.assertEqual(x.shape, ())
self.assertEqual(x.dtype, mx.float32) self.assertEqual(x.dtype, mx.float32)
self.assertEqual(x.item(), 1.0) self.assertEqual(x.item(), 1.0)
self.assertTrue(isinstance(x.item(), float)) self.assertTrue(isinstance(x.item(), float))
@ -124,14 +124,14 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array(False) x = mx.array(False)
self.assertEqual(x.size, 1) self.assertEqual(x.size, 1)
self.assertEqual(x.ndim, 0) self.assertEqual(x.ndim, 0)
self.assertEqual(x.shape, []) self.assertEqual(x.shape, ())
self.assertEqual(x.dtype, mx.bool_) self.assertEqual(x.dtype, mx.bool_)
self.assertEqual(x.item(), False) self.assertEqual(x.item(), False)
self.assertTrue(isinstance(x.item(), bool)) self.assertTrue(isinstance(x.item(), bool))
x = mx.array(complex(1, 1)) x = mx.array(complex(1, 1))
self.assertEqual(x.ndim, 0) self.assertEqual(x.ndim, 0)
self.assertEqual(x.shape, []) self.assertEqual(x.shape, ())
self.assertEqual(x.dtype, mx.complex64) self.assertEqual(x.dtype, mx.complex64)
self.assertEqual(x.item(), complex(1, 1)) self.assertEqual(x.item(), complex(1, 1))
self.assertTrue(isinstance(x.item(), complex)) self.assertTrue(isinstance(x.item(), complex))
@ -139,7 +139,7 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array([True, False, True]) x = mx.array([True, False, True])
self.assertEqual(x.dtype, mx.bool_) self.assertEqual(x.dtype, mx.bool_)
self.assertEqual(x.ndim, 1) self.assertEqual(x.ndim, 1)
self.assertEqual(x.shape, [3]) self.assertEqual(x.shape, (3,))
self.assertEqual(len(x), 3) self.assertEqual(len(x), 3)
x = mx.array([True, False, True], mx.float32) x = mx.array([True, False, True], mx.float32)
@ -148,7 +148,7 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array([0, 1, 2]) x = mx.array([0, 1, 2])
self.assertEqual(x.dtype, mx.int32) self.assertEqual(x.dtype, mx.int32)
self.assertEqual(x.ndim, 1) self.assertEqual(x.ndim, 1)
self.assertEqual(x.shape, [3]) self.assertEqual(x.shape, (3,))
x = mx.array([0, 1, 2], mx.float32) x = mx.array([0, 1, 2], mx.float32)
self.assertEqual(x.dtype, mx.float32) self.assertEqual(x.dtype, mx.float32)
@ -156,12 +156,12 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array([0.0, 1.0, 2.0]) x = mx.array([0.0, 1.0, 2.0])
self.assertEqual(x.dtype, mx.float32) self.assertEqual(x.dtype, mx.float32)
self.assertEqual(x.ndim, 1) self.assertEqual(x.ndim, 1)
self.assertEqual(x.shape, [3]) self.assertEqual(x.shape, (3,))
x = mx.array([1j, 1 + 0j]) x = mx.array([1j, 1 + 0j])
self.assertEqual(x.dtype, mx.complex64) self.assertEqual(x.dtype, mx.complex64)
self.assertEqual(x.ndim, 1) self.assertEqual(x.ndim, 1)
self.assertEqual(x.shape, [2]) self.assertEqual(x.shape, (2,))
# From tuple # From tuple
x = mx.array((1, 2, 3), mx.int32) x = mx.array((1, 2, 3), mx.int32)
@ -181,17 +181,17 @@ class TestArray(mlx_tests.MLXTestCase):
def test_construction_from_lists(self): def test_construction_from_lists(self):
x = mx.array([]) x = mx.array([])
self.assertEqual(x.size, 0) self.assertEqual(x.size, 0)
self.assertEqual(x.shape, [0]) self.assertEqual(x.shape, (0,))
self.assertEqual(x.dtype, mx.float32) self.assertEqual(x.dtype, mx.float32)
x = mx.array([[], [], []]) x = mx.array([[], [], []])
self.assertEqual(x.size, 0) self.assertEqual(x.size, 0)
self.assertEqual(x.shape, [3, 0]) self.assertEqual(x.shape, (3, 0))
self.assertEqual(x.dtype, mx.float32) self.assertEqual(x.dtype, mx.float32)
x = mx.array([[[], []], [[], []], [[], []]]) x = mx.array([[[], []], [[], []], [[], []]])
self.assertEqual(x.size, 0) self.assertEqual(x.size, 0)
self.assertEqual(x.shape, [3, 2, 0]) self.assertEqual(x.shape, (3, 2, 0))
self.assertEqual(x.dtype, mx.float32) self.assertEqual(x.dtype, mx.float32)
# Check failure cases # Check failure cases
@ -436,19 +436,19 @@ class TestArray(mlx_tests.MLXTestCase):
a = np.array([]) a = np.array([])
x = mx.array(a) x = mx.array(a)
self.assertEqual(x.size, 0) self.assertEqual(x.size, 0)
self.assertEqual(x.shape, [0]) self.assertEqual(x.shape, (0,))
self.assertEqual(x.dtype, mx.float32) self.assertEqual(x.dtype, mx.float32)
a = np.array([[], [], []]) a = np.array([[], [], []])
x = mx.array(a) x = mx.array(a)
self.assertEqual(x.size, 0) self.assertEqual(x.size, 0)
self.assertEqual(x.shape, [3, 0]) self.assertEqual(x.shape, (3, 0))
self.assertEqual(x.dtype, mx.float32) self.assertEqual(x.dtype, mx.float32)
a = np.array([[[], []], [[], []], [[], []]]) a = np.array([[[], []], [[], []], [[], []]])
x = mx.array(a) x = mx.array(a)
self.assertEqual(x.size, 0) self.assertEqual(x.size, 0)
self.assertEqual(x.shape, [3, 2, 0]) self.assertEqual(x.shape, (3, 2, 0))
self.assertEqual(x.dtype, mx.float32) self.assertEqual(x.dtype, mx.float32)
# Content test # Content test
@ -456,7 +456,7 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array(a) x = mx.array(a)
self.assertEqual(x.dtype, mx.float32) self.assertEqual(x.dtype, mx.float32)
self.assertEqual(x.ndim, 3) self.assertEqual(x.ndim, 3)
self.assertEqual(x.shape, [3, 5, 4]) self.assertEqual(x.shape, (3, 5, 4))
y = np.asarray(x) y = np.asarray(x)
self.assertTrue(np.allclose(a, y)) self.assertTrue(np.allclose(a, y))
@ -465,7 +465,7 @@ class TestArray(mlx_tests.MLXTestCase):
x = mx.array(a) x = mx.array(a)
self.assertEqual(x.dtype, mx.int32) self.assertEqual(x.dtype, mx.int32)
self.assertEqual(x.ndim, 0) self.assertEqual(x.ndim, 0)
self.assertEqual(x.shape, []) self.assertEqual(x.shape, ())
self.assertEqual(x.item(), 3) self.assertEqual(x.item(), 3)
# mlx to numpy test # mlx to numpy test
@ -483,7 +483,7 @@ class TestArray(mlx_tests.MLXTestCase):
x = np.array(cvals) x = np.array(cvals)
y = mx.array(x) y = mx.array(x)
self.assertEqual(y.dtype, mx.complex64) self.assertEqual(y.dtype, mx.complex64)
self.assertEqual(y.shape, [3]) self.assertEqual(y.shape, (3,))
self.assertEqual(y.tolist(), cvals) self.assertEqual(y.tolist(), cvals)
y = mx.array([0j, 1, 1 + 1j]) y = mx.array([0j, 1, 1 + 1j])

View File

@ -579,5 +579,9 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item()) self.assertTrue(mx.allclose(out_ref[0], out_test[0], atol=1e-4).item())
for r, t in zip(dout_ref, dout_test): for r, t in zip(dout_ref, dout_test):
self.assertListEqual(r.shape, t.shape) self.assertEqual(r.shape, t.shape)
self.assertTrue(mx.allclose(r, t, atol=1e-4).item()) self.assertTrue(mx.allclose(r, t, atol=1e-4).item())
if __name__ == "__main__":
unittest.main()

View File

@ -44,7 +44,7 @@ class TestConv(mlx_tests.MLXTestCase):
c_np = np.convolve(a_np, v_np, mode=mode) c_np = np.convolve(a_np, v_np, mode=mode)
c_mx = mx.convolve(a_mx, v_mx, mode=mode) c_mx = mx.convolve(a_mx, v_mx, mode=mode)
self.assertListEqual(list(c_mx.shape), list(c_np.shape)) self.assertEqual(c_mx.shape, c_np.shape)
self.assertTrue(np.allclose(c_mx, c_np, atol=atol)) self.assertTrue(np.allclose(c_mx, c_np, atol=atol))
@unittest.skipIf(not has_torch, "requires Torch") @unittest.skipIf(not has_torch, "requires Torch")
@ -102,7 +102,7 @@ class TestConv(mlx_tests.MLXTestCase):
) )
out_pt = torch.transpose(out_pt, 2, 1) out_pt = torch.transpose(out_pt, 2, 1)
self.assertListEqual(list(out_pt.shape), out_mx.shape) self.assertEqual(out_pt.shape, out_mx.shape)
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol)) self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=atol))
for dtype in ("float32",): for dtype in ("float32",):
@ -141,7 +141,7 @@ class TestConv(mlx_tests.MLXTestCase):
out_pt = torch.conv1d(in_pt, wt_pt) out_pt = torch.conv1d(in_pt, wt_pt)
out_pt = torch.transpose(out_pt, 2, 1) out_pt = torch.transpose(out_pt, 2, 1)
self.assertListEqual(list(out_pt.shape), out_mx.shape) self.assertEqual(out_pt.shape, out_mx.shape)
self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5)) self.assertTrue(np.allclose(out_pt.numpy(), out_mx, atol=1e-5))
@unittest.skipIf(not has_torch, "requires Torch") @unittest.skipIf(not has_torch, "requires Torch")
@ -228,12 +228,12 @@ class TestConv(mlx_tests.MLXTestCase):
mx_grad_in, mx_grad_wt = outs_mx mx_grad_in, mx_grad_wt = outs_mx
self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape) self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
self.assertListEqual(list(in_mx.shape), mx_grad_in.shape) self.assertEqual(in_mx.shape, mx_grad_in.shape)
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol)) self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape) self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape) self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol)) self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
for dtype in ("float32",): for dtype in ("float32",):
@ -309,7 +309,7 @@ class TestConv(mlx_tests.MLXTestCase):
) )
out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True) out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True)
self.assertListEqual(list(out_pt.shape), list(out_mx.shape)) self.assertEqual(out_pt.shape, out_mx.shape)
self.assertTrue(np.allclose(out_pt, out_mx, atol=atol)) self.assertTrue(np.allclose(out_pt, out_mx, atol=atol))
for dtype in ("float32",): for dtype in ("float32",):
@ -419,12 +419,12 @@ class TestConv(mlx_tests.MLXTestCase):
mx_grad_in, mx_grad_wt = outs_mx mx_grad_in, mx_grad_wt = outs_mx
self.assertListEqual(list(pt_grad_in.shape), mx_grad_in.shape) self.assertEqual(pt_grad_in.shape, mx_grad_in.shape)
self.assertListEqual(list(in_mx.shape), mx_grad_in.shape) self.assertEqual(in_mx.shape, mx_grad_in.shape)
self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol)) self.assertTrue(np.allclose(pt_grad_in, mx_grad_in, atol=atol))
self.assertListEqual(list(pt_grad_wt.shape), mx_grad_wt.shape) self.assertEqual(pt_grad_wt.shape, mx_grad_wt.shape)
self.assertListEqual(list(wt_mx.shape), mx_grad_wt.shape) self.assertEqual(wt_mx.shape, mx_grad_wt.shape)
self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol)) self.assertTrue(np.allclose(pt_grad_wt, mx_grad_wt, atol=atol))
for dtype in ("float32",): for dtype in ("float32",):

View File

@ -13,7 +13,7 @@ class TestInit(mlx_tests.MLXTestCase):
for dtype in [mx.float32, mx.float16]: for dtype in [mx.float32, mx.float16]:
initializer = init.constant(value, dtype) initializer = init.constant(value, dtype)
for shape in [[3], [3, 3], [3, 3, 3]]: for shape in [(3,), (3, 3), (3, 3, 3)]:
result = initializer(mx.array(mx.zeros(shape))) result = initializer(mx.array(mx.zeros(shape)))
with self.subTest(shape=shape): with self.subTest(shape=shape):
self.assertEqual(result.shape, shape) self.assertEqual(result.shape, shape)
@ -24,7 +24,7 @@ class TestInit(mlx_tests.MLXTestCase):
std = 1.0 std = 1.0
for dtype in [mx.float32, mx.float16]: for dtype in [mx.float32, mx.float16]:
initializer = init.normal(mean, std, dtype=dtype) initializer = init.normal(mean, std, dtype=dtype)
for shape in [[3], [3, 3], [3, 3, 3]]: for shape in [(3,), (3, 3), (3, 3, 3)]:
result = initializer(mx.array(np.empty(shape))) result = initializer(mx.array(np.empty(shape)))
with self.subTest(shape=shape): with self.subTest(shape=shape):
self.assertEqual(result.shape, shape) self.assertEqual(result.shape, shape)
@ -36,7 +36,7 @@ class TestInit(mlx_tests.MLXTestCase):
for dtype in [mx.float32, mx.float16]: for dtype in [mx.float32, mx.float16]:
initializer = init.uniform(low, high, dtype) initializer = init.uniform(low, high, dtype)
for shape in [[3], [3, 3], [3, 3, 3]]: for shape in [(3,), (3, 3), (3, 3, 3)]:
result = initializer(mx.array(np.empty(shape))) result = initializer(mx.array(np.empty(shape)))
with self.subTest(shape=shape): with self.subTest(shape=shape):
self.assertEqual(result.shape, shape) self.assertEqual(result.shape, shape)
@ -46,7 +46,7 @@ class TestInit(mlx_tests.MLXTestCase):
def test_identity(self): def test_identity(self):
for dtype in [mx.float32, mx.float16]: for dtype in [mx.float32, mx.float16]:
initializer = init.identity(dtype) initializer = init.identity(dtype)
for shape in [[3], [3, 3], [3, 3, 3]]: for shape in [(3,), (3, 3), (3, 3, 3)]:
result = initializer(mx.zeros((3, 3))) result = initializer(mx.zeros((3, 3)))
self.assertTrue(mx.array_equal(result, mx.eye(3))) self.assertTrue(mx.array_equal(result, mx.eye(3)))
self.assertEqual(result.dtype, dtype) self.assertEqual(result.dtype, dtype)
@ -56,7 +56,7 @@ class TestInit(mlx_tests.MLXTestCase):
def test_glorot_normal(self): def test_glorot_normal(self):
for dtype in [mx.float32, mx.float16]: for dtype in [mx.float32, mx.float16]:
initializer = init.glorot_normal(dtype) initializer = init.glorot_normal(dtype)
for shape in [[3, 3], [3, 3, 3]]: for shape in [(3, 3), (3, 3, 3)]:
result = initializer(mx.array(np.empty(shape))) result = initializer(mx.array(np.empty(shape)))
with self.subTest(shape=shape): with self.subTest(shape=shape):
self.assertEqual(result.shape, shape) self.assertEqual(result.shape, shape)
@ -65,7 +65,7 @@ class TestInit(mlx_tests.MLXTestCase):
def test_glorot_uniform(self): def test_glorot_uniform(self):
for dtype in [mx.float32, mx.float16]: for dtype in [mx.float32, mx.float16]:
initializer = init.glorot_uniform(dtype) initializer = init.glorot_uniform(dtype)
for shape in [[3, 3], [3, 3, 3]]: for shape in [(3, 3), (3, 3, 3)]:
result = initializer(mx.array(np.empty(shape))) result = initializer(mx.array(np.empty(shape)))
with self.subTest(shape=shape): with self.subTest(shape=shape):
self.assertEqual(result.shape, shape) self.assertEqual(result.shape, shape)
@ -74,7 +74,7 @@ class TestInit(mlx_tests.MLXTestCase):
def test_he_normal(self): def test_he_normal(self):
for dtype in [mx.float32, mx.float16]: for dtype in [mx.float32, mx.float16]:
initializer = init.he_normal(dtype) initializer = init.he_normal(dtype)
for shape in [[3, 3], [3, 3, 3]]: for shape in [(3, 3), (3, 3, 3)]:
result = initializer(mx.array(np.empty(shape))) result = initializer(mx.array(np.empty(shape)))
with self.subTest(shape=shape): with self.subTest(shape=shape):
self.assertEqual(result.shape, shape) self.assertEqual(result.shape, shape)
@ -83,7 +83,7 @@ class TestInit(mlx_tests.MLXTestCase):
def test_he_uniform(self): def test_he_uniform(self):
for dtype in [mx.float32, mx.float16]: for dtype in [mx.float32, mx.float16]:
initializer = init.he_uniform(dtype) initializer = init.he_uniform(dtype)
for shape in [[3, 3], [3, 3, 3]]: for shape in [(3, 3), (3, 3, 3)]:
result = initializer(mx.array(np.empty(shape))) result = initializer(mx.array(np.empty(shape)))
with self.subTest(shape=shape): with self.subTest(shape=shape):
self.assertEqual(result.shape, shape) self.assertEqual(result.shape, shape)

View File

@ -136,20 +136,20 @@ class TestLayers(mlx_tests.MLXTestCase):
inputs = mx.zeros((10, 4)) inputs = mx.zeros((10, 4))
layer = nn.Identity() layer = nn.Identity()
outputs = layer(inputs) outputs = layer(inputs)
self.assertEqual(tuple(inputs.shape), tuple(outputs.shape)) self.assertEqual(inputs.shape, outputs.shape)
def test_linear(self): def test_linear(self):
inputs = mx.zeros((10, 4)) inputs = mx.zeros((10, 4))
layer = nn.Linear(input_dims=4, output_dims=8) layer = nn.Linear(input_dims=4, output_dims=8)
outputs = layer(inputs) outputs = layer(inputs)
self.assertEqual(tuple(outputs.shape), (10, 8)) self.assertEqual(outputs.shape, (10, 8))
def test_bilinear(self): def test_bilinear(self):
inputs1 = mx.zeros((10, 2)) inputs1 = mx.zeros((10, 2))
inputs2 = mx.zeros((10, 4)) inputs2 = mx.zeros((10, 4))
layer = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6) layer = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6)
outputs = layer(inputs1, inputs2) outputs = layer(inputs1, inputs2)
self.assertEqual(tuple(outputs.shape), (10, 6)) self.assertEqual(outputs.shape, (10, 6))
def test_group_norm(self): def test_group_norm(self):
x = mx.arange(100, dtype=mx.float32) x = mx.arange(100, dtype=mx.float32)
@ -573,12 +573,12 @@ class TestLayers(mlx_tests.MLXTestCase):
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks) c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks)
c.weight = mx.ones_like(c.weight) c.weight = mx.ones_like(c.weight)
y = c(x) y = c(x)
self.assertEqual(y.shape, [N, L - ks + 1, C_out]) self.assertEqual(y.shape, (N, L - ks + 1, C_out))
self.assertTrue(mx.allclose(y, mx.full(y.shape, ks * C_in, mx.float32))) self.assertTrue(mx.allclose(y, mx.full(y.shape, ks * C_in, mx.float32)))
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, stride=2) c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, stride=2)
y = c(x) y = c(x)
self.assertEqual(y.shape, [N, (L - ks + 1) // 2, C_out]) self.assertEqual(y.shape, (N, (L - ks + 1) // 2, C_out))
self.assertTrue("bias" in c.parameters()) self.assertTrue("bias" in c.parameters())
c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, bias=False) c = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=ks, bias=False)
@ -588,7 +588,7 @@ class TestLayers(mlx_tests.MLXTestCase):
x = mx.ones((4, 8, 8, 3)) x = mx.ones((4, 8, 8, 3))
c = nn.Conv2d(3, 1, 8) c = nn.Conv2d(3, 1, 8)
y = c(x) y = c(x)
self.assertEqual(y.shape, [4, 1, 1, 1]) self.assertEqual(y.shape, (4, 1, 1, 1))
c.weight = mx.ones_like(c.weight) / 8 / 8 / 3 c.weight = mx.ones_like(c.weight) / 8 / 8 / 3
y = c(x) y = c(x)
self.assertTrue(np.allclose(y[:, 0, 0, 0], x.mean(axis=(1, 2, 3)))) self.assertTrue(np.allclose(y[:, 0, 0, 0], x.mean(axis=(1, 2, 3))))
@ -596,13 +596,13 @@ class TestLayers(mlx_tests.MLXTestCase):
# 3x3 conv no padding stride 1 # 3x3 conv no padding stride 1
c = nn.Conv2d(3, 8, 3) c = nn.Conv2d(3, 8, 3)
y = c(x) y = c(x)
self.assertEqual(y.shape, [4, 6, 6, 8]) self.assertEqual(y.shape, (4, 6, 6, 8))
self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4) self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)
# 3x3 conv padding 1 stride 1 # 3x3 conv padding 1 stride 1
c = nn.Conv2d(3, 8, 3, padding=1) c = nn.Conv2d(3, 8, 3, padding=1)
y = c(x) y = c(x)
self.assertEqual(y.shape, [4, 8, 8, 8]) self.assertEqual(y.shape, (4, 8, 8, 8))
self.assertLess(mx.abs(y[:, 1:7, 1:7] - c.weight.sum((1, 2, 3))).max(), 1e-4) self.assertLess(mx.abs(y[:, 1:7, 1:7] - c.weight.sum((1, 2, 3))).max(), 1e-4)
self.assertLess( self.assertLess(
mx.abs(y[:, 0, 0] - c.weight[:, 1:, 1:].sum(axis=(1, 2, 3))).max(), mx.abs(y[:, 0, 0] - c.weight[:, 1:, 1:].sum(axis=(1, 2, 3))).max(),
@ -624,14 +624,14 @@ class TestLayers(mlx_tests.MLXTestCase):
# 3x3 conv no padding stride 2 # 3x3 conv no padding stride 2
c = nn.Conv2d(3, 8, 3, padding=0, stride=2) c = nn.Conv2d(3, 8, 3, padding=0, stride=2)
y = c(x) y = c(x)
self.assertEqual(y.shape, [4, 3, 3, 8]) self.assertEqual(y.shape, (4, 3, 3, 8))
self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4) self.assertLess(mx.abs(y - c.weight.sum((1, 2, 3))).max(), 1e-4)
def test_sequential(self): def test_sequential(self):
x = mx.ones((10, 2)) x = mx.ones((10, 2))
m = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 1)) m = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 1))
y = m(x) y = m(x)
self.assertEqual(y.shape, [10, 1]) self.assertEqual(y.shape, (10, 1))
params = m.parameters() params = m.parameters()
self.assertTrue("layers" in params) self.assertTrue("layers" in params)
self.assertEqual(len(params["layers"]), 3) self.assertEqual(len(params["layers"]), 3)
@ -667,7 +667,7 @@ class TestLayers(mlx_tests.MLXTestCase):
x = mx.arange(10) x = mx.arange(10)
y = m(x) y = m(x)
self.assertEqual(y.shape, [10, 16]) self.assertEqual(y.shape, (10, 16))
similarities = y @ y.T similarities = y @ y.T
self.assertLess( self.assertLess(
mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5 mx.abs(similarities[mx.arange(10), mx.arange(10)] - 1).max(), 1e-5
@ -686,19 +686,19 @@ class TestLayers(mlx_tests.MLXTestCase):
x = mx.array([1.0, -1.0, 0.0]) x = mx.array([1.0, -1.0, 0.0])
y = nn.relu(x) y = nn.relu(x)
self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0]))) self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0])))
self.assertEqual(y.shape, [3]) self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32) self.assertEqual(y.dtype, mx.float32)
def test_leaky_relu(self): def test_leaky_relu(self):
x = mx.array([1.0, -1.0, 0.0]) x = mx.array([1.0, -1.0, 0.0])
y = nn.leaky_relu(x) y = nn.leaky_relu(x)
self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.01, 0.0]))) self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.01, 0.0])))
self.assertEqual(y.shape, [3]) self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32) self.assertEqual(y.dtype, mx.float32)
y = nn.LeakyReLU(negative_slope=0.1)(x) y = nn.LeakyReLU(negative_slope=0.1)(x)
self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.1, 0.0]))) self.assertTrue(mx.array_equal(y, mx.array([1.0, -0.1, 0.0])))
self.assertEqual(y.shape, [3]) self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32) self.assertEqual(y.dtype, mx.float32)
def test_elu(self): def test_elu(self):
@ -707,21 +707,21 @@ class TestLayers(mlx_tests.MLXTestCase):
epsilon = 1e-4 epsilon = 1e-4
expected_y = mx.array([1.0, -0.6321, 0.0]) expected_y = mx.array([1.0, -0.6321, 0.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3]) self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32) self.assertEqual(y.dtype, mx.float32)
y = nn.ELU(alpha=1.1)(x) y = nn.ELU(alpha=1.1)(x)
epsilon = 1e-4 epsilon = 1e-4
expected_y = mx.array([1.0, -0.6953, 0.0]) expected_y = mx.array([1.0, -0.6953, 0.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3]) self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32) self.assertEqual(y.dtype, mx.float32)
def test_relu6(self): def test_relu6(self):
x = mx.array([1.0, -1.0, 0.0, 7.0, -7.0]) x = mx.array([1.0, -1.0, 0.0, 7.0, -7.0])
y = nn.relu6(x) y = nn.relu6(x)
self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0, 6.0, 0.0]))) self.assertTrue(mx.array_equal(y, mx.array([1.0, 0.0, 0.0, 6.0, 0.0])))
self.assertEqual(y.shape, [5]) self.assertEqual(y.shape, (5,))
self.assertEqual(y.dtype, mx.float32) self.assertEqual(y.dtype, mx.float32)
def test_softmax(self): def test_softmax(self):
@ -730,7 +730,7 @@ class TestLayers(mlx_tests.MLXTestCase):
epsilon = 1e-4 epsilon = 1e-4
expected_y = mx.array([0.6652, 0.0900, 0.2447]) expected_y = mx.array([0.6652, 0.0900, 0.2447])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3]) self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32) self.assertEqual(y.dtype, mx.float32)
def test_softplus(self): def test_softplus(self):
@ -739,7 +739,7 @@ class TestLayers(mlx_tests.MLXTestCase):
epsilon = 1e-4 epsilon = 1e-4
expected_y = mx.array([1.3133, 0.3133, 0.6931]) expected_y = mx.array([1.3133, 0.3133, 0.6931])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3]) self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32) self.assertEqual(y.dtype, mx.float32)
def test_softsign(self): def test_softsign(self):
@ -748,7 +748,7 @@ class TestLayers(mlx_tests.MLXTestCase):
epsilon = 1e-4 epsilon = 1e-4
expected_y = mx.array([0.5, -0.5, 0.0]) expected_y = mx.array([0.5, -0.5, 0.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3]) self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32) self.assertEqual(y.dtype, mx.float32)
def test_softshrink(self): def test_softshrink(self):
@ -757,13 +757,13 @@ class TestLayers(mlx_tests.MLXTestCase):
epsilon = 1e-4 epsilon = 1e-4
expected_y = mx.array([0.5, -0.5, 0.0]) expected_y = mx.array([0.5, -0.5, 0.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3]) self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32) self.assertEqual(y.dtype, mx.float32)
y = nn.Softshrink(lambd=0.7)(x) y = nn.Softshrink(lambd=0.7)(x)
expected_y = mx.array([0.3, -0.3, 0.0]) expected_y = mx.array([0.3, -0.3, 0.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3]) self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32) self.assertEqual(y.dtype, mx.float32)
def test_celu(self): def test_celu(self):
@ -772,13 +772,13 @@ class TestLayers(mlx_tests.MLXTestCase):
epsilon = 1e-4 epsilon = 1e-4
expected_y = mx.array([1.0, -0.6321, 0.0]) expected_y = mx.array([1.0, -0.6321, 0.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3]) self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32) self.assertEqual(y.dtype, mx.float32)
y = nn.CELU(alpha=1.1)(x) y = nn.CELU(alpha=1.1)(x)
expected_y = mx.array([1.0, -0.6568, 0.0]) expected_y = mx.array([1.0, -0.6568, 0.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3]) self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32) self.assertEqual(y.dtype, mx.float32)
def test_log_softmax(self): def test_log_softmax(self):
@ -787,7 +787,7 @@ class TestLayers(mlx_tests.MLXTestCase):
epsilon = 1e-4 epsilon = 1e-4
expected_y = mx.array([-2.4076, -1.4076, -0.4076]) expected_y = mx.array([-2.4076, -1.4076, -0.4076])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3]) self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32) self.assertEqual(y.dtype, mx.float32)
def test_log_sigmoid(self): def test_log_sigmoid(self):
@ -796,7 +796,7 @@ class TestLayers(mlx_tests.MLXTestCase):
epsilon = 1e-4 epsilon = 1e-4
expected_y = mx.array([-0.3133, -1.3133, -0.6931]) expected_y = mx.array([-0.3133, -1.3133, -0.6931])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [3]) self.assertEqual(y.shape, (3,))
self.assertEqual(y.dtype, mx.float32) self.assertEqual(y.dtype, mx.float32)
def test_prelu(self): def test_prelu(self):
@ -817,7 +817,7 @@ class TestLayers(mlx_tests.MLXTestCase):
epsilon = 1e-4 epsilon = 1e-4
expected_y = mx.array([0.0, -0.375, 0.0, 1.125, 3.0]) expected_y = mx.array([0.0, -0.375, 0.0, 1.125, 3.0])
self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon)) self.assertTrue(mx.all(mx.abs(y - expected_y) < epsilon))
self.assertEqual(y.shape, [5]) self.assertEqual(y.shape, (5,))
self.assertEqual(y.dtype, mx.float32) self.assertEqual(y.dtype, mx.float32)
def test_glu(self): def test_glu(self):

View File

@ -12,12 +12,12 @@ import numpy as np
class TestOps(mlx_tests.MLXTestCase): class TestOps(mlx_tests.MLXTestCase):
def test_full_ones_zeros(self): def test_full_ones_zeros(self):
x = mx.full(2, 3.0) x = mx.full(2, 3.0)
self.assertEqual(x.shape, [2]) self.assertEqual(x.shape, (2,))
self.assertEqual(x.tolist(), [3.0, 3.0]) self.assertEqual(x.tolist(), [3.0, 3.0])
x = mx.full((2, 3), 2.0) x = mx.full((2, 3), 2.0)
self.assertEqual(x.dtype, mx.float32) self.assertEqual(x.dtype, mx.float32)
self.assertEqual(x.shape, [2, 3]) self.assertEqual(x.shape, (2, 3))
self.assertEqual(x.tolist(), [[2, 2, 2], [2, 2, 2]]) self.assertEqual(x.tolist(), [[2, 2, 2], [2, 2, 2]])
x = mx.full([3, 2], mx.array([False, True])) x = mx.full([3, 2], mx.array([False, True]))
@ -28,11 +28,11 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(x.tolist(), [[2, 3], [2, 3], [2, 3]]) self.assertEqual(x.tolist(), [[2, 3], [2, 3], [2, 3]])
x = mx.zeros(2) x = mx.zeros(2)
self.assertEqual(x.shape, [2]) self.assertEqual(x.shape, (2,))
self.assertEqual(x.tolist(), [0.0, 0.0]) self.assertEqual(x.tolist(), [0.0, 0.0])
x = mx.ones(2) x = mx.ones(2)
self.assertEqual(x.shape, [2]) self.assertEqual(x.shape, (2,))
self.assertEqual(x.tolist(), [1.0, 1.0]) self.assertEqual(x.tolist(), [1.0, 1.0])
for t in [mx.bool_, mx.int32, mx.float32]: for t in [mx.bool_, mx.int32, mx.float32]:
@ -530,10 +530,10 @@ class TestOps(mlx_tests.MLXTestCase):
def test_move_swap_axes(self): def test_move_swap_axes(self):
x = mx.zeros((2, 3, 4)) x = mx.zeros((2, 3, 4))
self.assertEqual(mx.moveaxis(x, 0, 2).shape, [3, 4, 2]) self.assertEqual(mx.moveaxis(x, 0, 2).shape, (3, 4, 2))
self.assertEqual(x.moveaxis(0, 2).shape, [3, 4, 2]) self.assertEqual(x.moveaxis(0, 2).shape, (3, 4, 2))
self.assertEqual(mx.swapaxes(x, 0, 2).shape, [4, 3, 2]) self.assertEqual(mx.swapaxes(x, 0, 2).shape, (4, 3, 2))
self.assertEqual(x.swapaxes(0, 2).shape, [4, 3, 2]) self.assertEqual(x.swapaxes(0, 2).shape, (4, 3, 2))
def test_sum(self): def test_sum(self):
x = mx.array( x = mx.array(
@ -545,7 +545,7 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.sum(x).item(), 9) self.assertEqual(mx.sum(x).item(), 9)
y = mx.sum(x, keepdims=True) y = mx.sum(x, keepdims=True)
self.assertEqual(y, mx.array(9)) self.assertEqual(y, mx.array(9))
self.assertEqual(y.shape, [1, 1]) self.assertEqual(y.shape, (1, 1))
self.assertEqual(mx.sum(x, axis=0).tolist(), [4, 5]) self.assertEqual(mx.sum(x, axis=0).tolist(), [4, 5])
self.assertEqual(mx.sum(x, axis=1).tolist(), [3, 6]) self.assertEqual(mx.sum(x, axis=1).tolist(), [3, 6])
@ -585,7 +585,7 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.prod(x).item(), 18) self.assertEqual(mx.prod(x).item(), 18)
y = mx.prod(x, keepdims=True) y = mx.prod(x, keepdims=True)
self.assertEqual(y, mx.array(18)) self.assertEqual(y, mx.array(18))
self.assertEqual(y.shape, [1, 1]) self.assertEqual(y.shape, (1, 1))
self.assertEqual(mx.prod(x, axis=0).tolist(), [3, 6]) self.assertEqual(mx.prod(x, axis=0).tolist(), [3, 6])
self.assertEqual(mx.prod(x, axis=1).tolist(), [2, 9]) self.assertEqual(mx.prod(x, axis=1).tolist(), [2, 9])
@ -600,11 +600,11 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.min(x).item(), 1) self.assertEqual(mx.min(x).item(), 1)
self.assertEqual(mx.max(x).item(), 4) self.assertEqual(mx.max(x).item(), 4)
y = mx.min(x, keepdims=True) y = mx.min(x, keepdims=True)
self.assertEqual(y.shape, [1, 1]) self.assertEqual(y.shape, (1, 1))
self.assertEqual(y, mx.array(1)) self.assertEqual(y, mx.array(1))
y = mx.max(x, keepdims=True) y = mx.max(x, keepdims=True)
self.assertEqual(y.shape, [1, 1]) self.assertEqual(y.shape, (1, 1))
self.assertEqual(y, mx.array(4)) self.assertEqual(y, mx.array(4))
self.assertEqual(mx.min(x, axis=0).tolist(), [1, 2]) self.assertEqual(mx.min(x, axis=0).tolist(), [1, 2])
@ -670,7 +670,7 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.mean(x).item(), 2.5) self.assertEqual(mx.mean(x).item(), 2.5)
y = mx.mean(x, keepdims=True) y = mx.mean(x, keepdims=True)
self.assertEqual(y, mx.array(2.5)) self.assertEqual(y, mx.array(2.5))
self.assertEqual(y.shape, [1, 1]) self.assertEqual(y.shape, (1, 1))
self.assertEqual(mx.mean(x, axis=0).tolist(), [2, 3]) self.assertEqual(mx.mean(x, axis=0).tolist(), [2, 3])
self.assertEqual(mx.mean(x, axis=1).tolist(), [1.5, 3.5]) self.assertEqual(mx.mean(x, axis=1).tolist(), [1.5, 3.5])
@ -685,7 +685,7 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(mx.var(x).item(), 1.25) self.assertEqual(mx.var(x).item(), 1.25)
y = mx.var(x, keepdims=True) y = mx.var(x, keepdims=True)
self.assertEqual(y, mx.array(1.25)) self.assertEqual(y, mx.array(1.25))
self.assertEqual(y.shape, [1, 1]) self.assertEqual(y.shape, (1, 1))
self.assertEqual(mx.var(x, axis=0).tolist(), [1.0, 1.0]) self.assertEqual(mx.var(x, axis=0).tolist(), [1.0, 1.0])
self.assertEqual(mx.var(x, axis=1).tolist(), [0.25, 0.25]) self.assertEqual(mx.var(x, axis=1).tolist(), [0.25, 0.25])
@ -888,7 +888,7 @@ class TestOps(mlx_tests.MLXTestCase):
a = mx.array([[True, False], [True, True]]) a = mx.array([[True, False], [True, True]])
self.assertFalse(mx.all(a).item()) self.assertFalse(mx.all(a).item())
self.assertEqual(mx.all(a, keepdims=True).shape, [1, 1]) self.assertEqual(mx.all(a, keepdims=True).shape, (1, 1))
self.assertFalse(mx.all(a, axis=[0, 1]).item()) self.assertFalse(mx.all(a, axis=[0, 1]).item())
self.assertEqual(mx.all(a, axis=[0]).tolist(), [True, False]) self.assertEqual(mx.all(a, axis=[0]).tolist(), [True, False])
self.assertEqual(mx.all(a, axis=[1]).tolist(), [False, True]) self.assertEqual(mx.all(a, axis=[1]).tolist(), [False, True])
@ -899,7 +899,7 @@ class TestOps(mlx_tests.MLXTestCase):
a = mx.array([[True, False], [False, False]]) a = mx.array([[True, False], [False, False]])
self.assertTrue(mx.any(a).item()) self.assertTrue(mx.any(a).item())
self.assertEqual(mx.any(a, keepdims=True).shape, [1, 1]) self.assertEqual(mx.any(a, keepdims=True).shape, (1, 1))
self.assertTrue(mx.any(a, axis=[0, 1]).item()) self.assertTrue(mx.any(a, axis=[0, 1]).item())
self.assertEqual(mx.any(a, axis=[0]).tolist(), [True, False]) self.assertEqual(mx.any(a, axis=[0]).tolist(), [True, False])
self.assertEqual(mx.any(a, axis=[1]).tolist(), [True, False]) self.assertEqual(mx.any(a, axis=[1]).tolist(), [True, False])
@ -956,22 +956,22 @@ class TestOps(mlx_tests.MLXTestCase):
a_npy_taken = np.take(a_npy, idx_npy) a_npy_taken = np.take(a_npy, idx_npy)
a_mlx_taken = mx.take(a_mlx, idx_mlx) a_mlx_taken = mx.take(a_mlx, idx_mlx)
self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape) self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist()) self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
a_npy_taken = np.take(a_npy, idx_npy, axis=0) a_npy_taken = np.take(a_npy, idx_npy, axis=0)
a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=0) a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=0)
self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape) self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist()) self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
a_npy_taken = np.take(a_npy, idx_npy, axis=1) a_npy_taken = np.take(a_npy, idx_npy, axis=1)
a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=1) a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=1)
self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape) self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist()) self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
a_npy_taken = np.take(a_npy, idx_npy, axis=2) a_npy_taken = np.take(a_npy, idx_npy, axis=2)
a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=2) a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=2)
self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape) self.assertEqual(a_npy_taken.shape, a_mlx_taken.shape)
self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist()) self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
def test_take_along_axis(self): def test_take_along_axis(self):
@ -1400,13 +1400,13 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6)) self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6))
a = mx.zeros((1, 1, 1)) a = mx.zeros((1, 1, 1))
self.assertEqual(mx.pad(a, 1).shape, [3, 3, 3]) self.assertEqual(mx.pad(a, 1).shape, (3, 3, 3))
self.assertEqual(mx.pad(a, (1,)).shape, [3, 3, 3]) self.assertEqual(mx.pad(a, (1,)).shape, (3, 3, 3))
self.assertEqual(mx.pad(a, [1]).shape, [3, 3, 3]) self.assertEqual(mx.pad(a, [1]).shape, (3, 3, 3))
self.assertEqual(mx.pad(a, (1, 2)).shape, [4, 4, 4]) self.assertEqual(mx.pad(a, (1, 2)).shape, (4, 4, 4))
self.assertEqual(mx.pad(a, [(1, 2)]).shape, [4, 4, 4]) self.assertEqual(mx.pad(a, [(1, 2)]).shape, (4, 4, 4))
self.assertEqual(mx.pad(a, ((1, 2),)).shape, [4, 4, 4]) self.assertEqual(mx.pad(a, ((1, 2),)).shape, (4, 4, 4))
self.assertEqual(mx.pad(a, ((1, 2), (2, 1), (2, 2))).shape, [4, 4, 5]) self.assertEqual(mx.pad(a, ((1, 2), (2, 1), (2, 2))).shape, (4, 4, 5))
# Test grads # Test grads
a_fwd = mx.array(np.random.rand(16, 16).astype(np.float32)) a_fwd = mx.array(np.random.rand(16, 16).astype(np.float32))
@ -1490,19 +1490,19 @@ class TestOps(mlx_tests.MLXTestCase):
def test_squeeze_expand(self): def test_squeeze_expand(self):
a = mx.zeros((2, 1, 2, 1)) a = mx.zeros((2, 1, 2, 1))
self.assertEqual(mx.squeeze(a).shape, [2, 2]) self.assertEqual(mx.squeeze(a).shape, (2, 2))
self.assertEqual(mx.squeeze(a, 1).shape, [2, 2, 1]) self.assertEqual(mx.squeeze(a, 1).shape, (2, 2, 1))
self.assertEqual(mx.squeeze(a, [1, 3]).shape, [2, 2]) self.assertEqual(mx.squeeze(a, [1, 3]).shape, (2, 2))
self.assertEqual(a.squeeze().shape, [2, 2]) self.assertEqual(a.squeeze().shape, (2, 2))
self.assertEqual(a.squeeze(1).shape, [2, 2, 1]) self.assertEqual(a.squeeze(1).shape, (2, 2, 1))
self.assertEqual(a.squeeze([1, 3]).shape, [2, 2]) self.assertEqual(a.squeeze([1, 3]).shape, (2, 2))
a = mx.zeros((2, 2)) a = mx.zeros((2, 2))
self.assertEqual(mx.squeeze(a).shape, [2, 2]) self.assertEqual(mx.squeeze(a).shape, (2, 2))
self.assertEqual(mx.expand_dims(a, 0).shape, [1, 2, 2]) self.assertEqual(mx.expand_dims(a, 0).shape, (1, 2, 2))
self.assertEqual(mx.expand_dims(a, (0, 1)).shape, [1, 1, 2, 2]) self.assertEqual(mx.expand_dims(a, (0, 1)).shape, (1, 1, 2, 2))
self.assertEqual(mx.expand_dims(a, [0, -1]).shape, [1, 2, 2, 1]) self.assertEqual(mx.expand_dims(a, [0, -1]).shape, (1, 2, 2, 1))
def test_sort(self): def test_sort(self):
shape = (3, 4, 5) shape = (3, 4, 5)
@ -1603,12 +1603,12 @@ class TestOps(mlx_tests.MLXTestCase):
def test_flatten(self): def test_flatten(self):
x = mx.zeros([2, 3, 4]) x = mx.zeros([2, 3, 4])
self.assertEqual(mx.flatten(x).shape, [2 * 3 * 4]) self.assertEqual(mx.flatten(x).shape, (2 * 3 * 4,))
self.assertEqual(mx.flatten(x, start_axis=1).shape, [2, 3 * 4]) self.assertEqual(mx.flatten(x, start_axis=1).shape, (2, 3 * 4))
self.assertEqual(mx.flatten(x, end_axis=1).shape, [2 * 3, 4]) self.assertEqual(mx.flatten(x, end_axis=1).shape, (2 * 3, 4))
self.assertEqual(x.flatten().shape, [2 * 3 * 4]) self.assertEqual(x.flatten().shape, (2 * 3 * 4,))
self.assertEqual(x.flatten(start_axis=1).shape, [2, 3 * 4]) self.assertEqual(x.flatten(start_axis=1).shape, (2, 3 * 4))
self.assertEqual(x.flatten(end_axis=1).shape, [2 * 3, 4]) self.assertEqual(x.flatten(end_axis=1).shape, (2 * 3, 4))
def test_clip(self): def test_clip(self):
a = np.array([1, 4, 3, 8, 5], np.int32) a = np.array([1, 4, 3, 8, 5], np.int32)

View File

@ -38,19 +38,19 @@ class TestRandom(mlx_tests.MLXTestCase):
self.assertTrue(mx.array_equal(k2, r2)) self.assertTrue(mx.array_equal(k2, r2))
keys = mx.random.split(key, 10) keys = mx.random.split(key, 10)
self.assertEqual(keys.shape, [10, 2]) self.assertEqual(keys.shape, (10, 2))
def test_uniform(self): def test_uniform(self):
key = mx.random.key(0) key = mx.random.key(0)
a = mx.random.uniform(key=key) a = mx.random.uniform(key=key)
self.assertEqual(a.shape, []) self.assertEqual(a.shape, ())
self.assertEqual(a.dtype, mx.float32) self.assertEqual(a.dtype, mx.float32)
b = mx.random.uniform(key=key) b = mx.random.uniform(key=key)
self.assertEqual(a.item(), b.item()) self.assertEqual(a.item(), b.item())
a = mx.random.uniform(shape=(2, 3)) a = mx.random.uniform(shape=(2, 3))
self.assertEqual(a.shape, [2, 3]) self.assertEqual(a.shape, (2, 3))
a = mx.random.uniform(shape=(1000,), low=-1, high=5) a = mx.random.uniform(shape=(1000,), low=-1, high=5)
self.assertTrue(mx.all((a > -1) < 5).item()) self.assertTrue(mx.all((a > -1) < 5).item())
@ -66,14 +66,14 @@ class TestRandom(mlx_tests.MLXTestCase):
def test_normal(self): def test_normal(self):
key = mx.random.key(0) key = mx.random.key(0)
a = mx.random.normal(key=key) a = mx.random.normal(key=key)
self.assertEqual(a.shape, []) self.assertEqual(a.shape, ())
self.assertEqual(a.dtype, mx.float32) self.assertEqual(a.dtype, mx.float32)
b = mx.random.normal(key=key) b = mx.random.normal(key=key)
self.assertEqual(a.item(), b.item()) self.assertEqual(a.item(), b.item())
a = mx.random.normal(shape=(2, 3)) a = mx.random.normal(shape=(2, 3))
self.assertEqual(a.shape, [2, 3]) self.assertEqual(a.shape, (2, 3))
## Generate in float16 or bfloat16 ## Generate in float16 or bfloat16
for t in [mx.float16, mx.bfloat16]: for t in [mx.float16, mx.bfloat16]:
@ -84,10 +84,10 @@ class TestRandom(mlx_tests.MLXTestCase):
def test_randint(self): def test_randint(self):
a = mx.random.randint(0, 1, []) a = mx.random.randint(0, 1, [])
self.assertEqual(a.shape, []) self.assertEqual(a.shape, ())
self.assertEqual(a.dtype, mx.int32) self.assertEqual(a.dtype, mx.int32)
shape = [88] shape = (88,)
low = mx.array(3) low = mx.array(3)
high = mx.array(15) high = mx.array(15)
@ -100,7 +100,7 @@ class TestRandom(mlx_tests.MLXTestCase):
b = mx.random.randint(low, high, shape, key=key) b = mx.random.randint(low, high, shape, key=key)
self.assertListEqual(a.tolist(), b.tolist()) self.assertListEqual(a.tolist(), b.tolist())
shape = [3, 4] shape = (3, 4)
low = mx.reshape(mx.array([0] * 3), [3, 1]) low = mx.reshape(mx.array([0] * 3), [3, 1])
high = mx.reshape(mx.array([12, 13, 14, 15]), [1, 4]) high = mx.reshape(mx.array([12, 13, 14, 15]), [1, 4])
@ -119,20 +119,20 @@ class TestRandom(mlx_tests.MLXTestCase):
def test_bernoulli(self): def test_bernoulli(self):
a = mx.random.bernoulli() a = mx.random.bernoulli()
self.assertEqual(a.shape, []) self.assertEqual(a.shape, ())
self.assertEqual(a.dtype, mx.bool_) self.assertEqual(a.dtype, mx.bool_)
a = mx.random.bernoulli(mx.array(0.5), [5]) a = mx.random.bernoulli(mx.array(0.5), [5])
self.assertEqual(a.shape, [5]) self.assertEqual(a.shape, (5,))
a = mx.random.bernoulli(mx.array([2.0, -2.0])) a = mx.random.bernoulli(mx.array([2.0, -2.0]))
self.assertEqual(a.tolist(), [True, False]) self.assertEqual(a.tolist(), [True, False])
self.assertEqual(a.shape, [2]) self.assertEqual(a.shape, (2,))
p = mx.array([0.1, 0.2, 0.3]) p = mx.array([0.1, 0.2, 0.3])
mx.reshape(p, [1, 3]) mx.reshape(p, [1, 3])
x = mx.random.bernoulli(p, [4, 3]) x = mx.random.bernoulli(p, [4, 3])
self.assertEqual(x.shape, [4, 3]) self.assertEqual(x.shape, (4, 3))
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
mx.random.bernoulli(p, [2]) # Bad shape mx.random.bernoulli(p, [2]) # Bad shape
@ -153,14 +153,14 @@ class TestRandom(mlx_tests.MLXTestCase):
upper = mx.reshape(mx.array([0.0, 1.0, 2.0]), [3, 1]) upper = mx.reshape(mx.array([0.0, 1.0, 2.0]), [3, 1])
a = mx.random.truncated_normal(lower, upper) a = mx.random.truncated_normal(lower, upper)
self.assertEqual(a.shape, [3, 2]) self.assertEqual(a.shape, (3, 2))
self.assertTrue(mx.all(lower <= a).item() and mx.all(a <= upper).item()) self.assertTrue(mx.all(lower <= a).item() and mx.all(a <= upper).item())
a = mx.random.truncated_normal(2.0, -2.0) a = mx.random.truncated_normal(2.0, -2.0)
self.assertTrue(mx.all(a == 2.0).item()) self.assertTrue(mx.all(a == 2.0).item())
a = mx.random.truncated_normal(-3.0, 3.0, [542, 399]) a = mx.random.truncated_normal(-3.0, 3.0, [542, 399])
self.assertEqual(a.shape, [542, 399]) self.assertEqual(a.shape, (542, 399))
lower = mx.array([-2.0, -1.0]) lower = mx.array([-2.0, -1.0])
higher = mx.array([1.0, 2.0, 3.0]) higher = mx.array([1.0, 2.0, 3.0])
@ -174,7 +174,7 @@ class TestRandom(mlx_tests.MLXTestCase):
def test_gumbel(self): def test_gumbel(self):
samples = mx.random.gumbel(shape=(100, 100)) samples = mx.random.gumbel(shape=(100, 100))
self.assertEqual(samples.shape, [100, 100]) self.assertEqual(samples.shape, (100, 100))
self.assertEqual(samples.dtype, mx.float32) self.assertEqual(samples.dtype, mx.float32)
mean = 0.5772 mean = 0.5772
# Std deviation of the sample mean is small (<0.02), # Std deviation of the sample mean is small (<0.02),
@ -187,23 +187,23 @@ class TestRandom(mlx_tests.MLXTestCase):
def test_categorical(self): def test_categorical(self):
logits = mx.zeros((10, 20)) logits = mx.zeros((10, 20))
self.assertEqual(mx.random.categorical(logits, -1).shape, [10]) self.assertEqual(mx.random.categorical(logits, -1).shape, (10,))
self.assertEqual(mx.random.categorical(logits, 0).shape, [20]) self.assertEqual(mx.random.categorical(logits, 0).shape, (20,))
self.assertEqual(mx.random.categorical(logits, 1).shape, [10]) self.assertEqual(mx.random.categorical(logits, 1).shape, (10,))
out = mx.random.categorical(logits) out = mx.random.categorical(logits)
self.assertEqual(out.shape, [10]) self.assertEqual(out.shape, (10,))
self.assertEqual(out.dtype, mx.uint32) self.assertEqual(out.dtype, mx.uint32)
self.assertTrue(mx.max(out).item() < 20) self.assertTrue(mx.max(out).item() < 20)
out = mx.random.categorical(logits, 0, [5, 20]) out = mx.random.categorical(logits, 0, [5, 20])
self.assertEqual(out.shape, [5, 20]) self.assertEqual(out.shape, (5, 20))
self.assertTrue(mx.max(out).item() < 10) self.assertTrue(mx.max(out).item() < 10)
out = mx.random.categorical(logits, 1, num_samples=7) out = mx.random.categorical(logits, 1, num_samples=7)
self.assertEqual(out.shape, [10, 7]) self.assertEqual(out.shape, (10, 7))
out = mx.random.categorical(logits, 0, num_samples=7) out = mx.random.categorical(logits, 0, num_samples=7)
self.assertEqual(out.shape, [20, 7]) self.assertEqual(out.shape, (20, 7))
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
mx.random.categorical(logits, shape=[10, 5], num_samples=5) mx.random.categorical(logits, shape=[10, 5], num_samples=5)