diff --git a/python/tests/mlx_tests.py b/python/tests/mlx_tests.py index d9a485885..dfbc835da 100644 --- a/python/tests/mlx_tests.py +++ b/python/tests/mlx_tests.py @@ -2,6 +2,7 @@ import os import unittest +from typing import Any, Callable, List, Tuple import mlx.core as mx import numpy as np @@ -18,13 +19,33 @@ class MLXTestCase(unittest.TestCase): def tearDown(self): mx.set_default_device(self.default) + def assertCmpNumpy( + self, + shape: List[Tuple[int] | Any], + mx_fn: Callable[..., mx.array], + np_fn: Callable[..., np.array], + atol=1e-2, + rtol=1e-2, + dtype=mx.float32, + **kwargs, + ): + assert dtype != mx.bfloat16, "numpy does not support bfloat16" + args = [ + mx.random.normal(s, dtype=dtype) if isinstance(s, Tuple) else s + for s in shape + ] + mx_res = mx_fn(*args, **kwargs) + np_res = np_fn( + *[np.array(a) if isinstance(a, mx.array) else a for a in args], **kwargs + ) + return self.assertEqualArray(mx_res, mx.array(np_res), atol=atol, rtol=rtol) + def assertEqualArray( self, mx_res: mx.array, expected: mx.array, atol=1e-2, rtol=1e-2, - **kwargs, ): assert tuple(mx_res.shape) == tuple( expected.shape diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 65de09634..2e04477eb 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -324,24 +324,18 @@ class TestOps(mlx_tests.MLXTestCase): def test_tri(self): for shape in [[4], [4, 4], [2, 10]]: for diag in [-1, 0, 1, -2]: - self.assertEqualArray( - mx.tri(*shape, k=diag), mx.array(np.tri(*shape, k=diag)) - ) + self.assertCmpNumpy(shape, mx.tri, np.tri, k=diag) def test_tril(self): - mt = mx.random.normal((10, 10)) - nt = np.array(mt) for diag in [-1, 0, 1, -2]: - self.assertEqualArray(mx.tril(mt, diag), mx.array(np.tril(nt, diag))) + self.assertCmpNumpy([(10, 10)], mx.tril, np.tril, k=diag) with self.assertRaises(Exception): mx.tril(mx.zeros((1))) def test_triu(self): - mt = mx.random.normal((10, 10)) - nt = np.array(mt) for diag in [-1, 0, 1, -2]: - self.assertEqualArray(mx.triu(mt, diag), mx.array(np.triu(nt, diag))) + self.assertCmpNumpy([(10, 10)], mx.triu, np.triu, k=diag) with self.assertRaises(Exception): mx.triu(mx.zeros((1))) @@ -1260,20 +1254,17 @@ class TestOps(mlx_tests.MLXTestCase): self.assertTrue(mx.allclose(a_bwd[4:-2, 2:-4], df[0]).item()) def test_where(self): - a = mx.array([[1, 2], [3, 4]]) - out = mx.where(True, a, 1) - out_np = np.where(True, a, 1) - self.assertTrue(np.array_equal(out, out_np)) - - out = mx.where(True, 1, a) - out_np = np.where(True, 1, a) - self.assertTrue(np.array_equal(out, out_np)) - - condition = mx.array([[True, False], [False, True]]) - b = mx.array([5, 6]) - out = mx.where(condition, a, b) - out_np = np.where(condition, a, b) - self.assertTrue(np.array_equal(out, out_np)) + self.assertCmpNumpy([True, mx.array([[1, 2], [3, 4]]), 1], mx.where, np.where) + self.assertCmpNumpy([True, 1, mx.array([[1, 2], [3, 4]])], mx.where, np.where) + self.assertCmpNumpy( + [ + mx.array([[True, False], [False, True]]), + mx.array([[1, 2], [3, 4]]), + mx.array([5, 6]), + ], + mx.where, + np.where, + ) def test_as_strided(self): x_npy = np.random.randn(128).astype(np.float32) @@ -1408,24 +1399,13 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqual((a + b)[0, 0].item(), 2) def test_eye(self): - eye_matrix = mx.eye(3) - np_eye_matrix = np.eye(3) - self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix)) - + self.assertCmpNumpy([3], mx.eye, np.eye) # Test for non-square matrix - eye_matrix = mx.eye(3, 4) - np_eye_matrix = np.eye(3, 4) - self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix)) - + self.assertCmpNumpy([3, 4], mx.eye, np.eye) # Test with positive k parameter - eye_matrix = mx.eye(3, 4, k=1) - np_eye_matrix = np.eye(3, 4, k=1) - self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix)) - + self.assertCmpNumpy([3, 4], mx.eye, np.eye, k=1) # Test with negative k parameter - eye_matrix = mx.eye(5, 6, k=-2) - np_eye_matrix = np.eye(5, 6, k=-2) - self.assertTrue(np.array_equal(eye_matrix, np_eye_matrix)) + self.assertCmpNumpy([5, 6], mx.eye, np.eye, k=-2) def test_stack(self): a = mx.ones((2,)) @@ -1518,50 +1498,47 @@ class TestOps(mlx_tests.MLXTestCase): def test_repeat(self): # Setup data for the tests - data = np.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]]) + data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]]) # Test repeat along axis 0 - repeat_axis_0 = mx.repeat(mx.array(data), 2, axis=0) - expected_axis_0 = np.repeat(data, 2, axis=0) - - self.assertEqualArray(repeat_axis_0, mx.array(expected_axis_0)) - + self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=0) # Test repeat along axis 1 - repeat_axis_1 = mx.repeat(mx.array(data), 2, axis=1) - expected_axis_1 = np.repeat(data, 2, axis=1) - self.assertEqualArray(repeat_axis_1, mx.array(expected_axis_1)) - + self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=1) # Test repeat along the last axis (default) - repeat_axis_2 = mx.repeat(mx.array(data), 2) - expected_axis_2 = np.repeat(data, 2) - self.assertEqualArray(repeat_axis_2, mx.array(expected_axis_2)) - + self.assertCmpNumpy([data, 2], mx.repeat, np.repeat) # Test repeat with a 1D array along axis 0 - data_2 = mx.array([1, 3, 2]) - repeat_2 = mx.repeat(mx.array(data_2), 3, axis=0) - expected_2 = np.repeat(data_2, 3) - self.assertEqualArray(repeat_2, mx.array(expected_2)) - + self.assertCmpNumpy([mx.array([1, 3, 2]), 3], mx.repeat, np.repeat, axis=0) # Test repeat with a 2D array along axis 0 - data_3 = mx.array([[1, 2, 3], [4, 5, 4], [0, 1, 2]]) - repeat_3 = mx.repeat(mx.array(data_3), 2, axis=0) - expected_3 = np.repeat(data_3, 2, axis=0) - self.assertEqualArray(repeat_3, mx.array(expected_3)) + self.assertCmpNumpy( + [mx.array([[1, 2, 3], [4, 5, 4], [0, 1, 2]]), 2], + mx.repeat, + np.repeat, + axis=0, + ) def test_tensordot(self): - x = mx.arange(60.0).reshape(3, 4, 5) - y = mx.arange(24.0).reshape(4, 3, 2) - z = mx.tensordot(x, y, dims=([1, 0], [0, 1])) - self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=([1, 0], [0, 1])))) - x = mx.random.normal((3, 4, 5)) - y = mx.random.normal((4, 5, 6)) - z = mx.tensordot(x, y, dims=2) - self.assertEqualArray(z, mx.array(np.tensordot(x, y, axes=2))) - x = mx.random.normal((3, 5, 4, 6)) - y = mx.random.normal((6, 4, 5, 3)) - z = mx.tensordot(x, y, dims=([2, 1, 3], [1, 2, 0])) - self.assertEqualArray( - z, mx.array(np.tensordot(x, y, axes=([2, 1, 3], [1, 2, 0]))) - ) + for dtype in [mx.float16, mx.float32]: + with self.subTest(dtype=dtype): + self.assertCmpNumpy( + [(3, 4, 5), (4, 3, 2)], + mx.tensordot, + lambda x, y, dims: np.tensordot(x, y, axes=dims), + dtype=dtype, + dims=([1, 0], [0, 1]), + ) + self.assertCmpNumpy( + [(3, 4, 5), (4, 5, 6)], + mx.tensordot, + lambda x, y, dims: np.tensordot(x, y, axes=dims), + dtype=dtype, + dims=2, + ) + self.assertCmpNumpy( + [(3, 5, 4, 6), (6, 4, 5, 3)], + mx.tensordot, + lambda x, y, dims: np.tensordot(x, y, axes=dims), + dtype=dtype, + dims=([2, 1, 3], [1, 2, 0]), + ) if __name__ == "__main__":