diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index 1969e1028..6c6d34699 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -42,11 +42,27 @@ class TestLinalg(mlx_tests.MLXTestCase): ) assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) - # Test when no axes and no ords are provided - for keepdims in [True, False]: - out_np = np.linalg.norm(x_np, keepdims=keepdims) - out_mx = mx.linalg.norm(x_mx, keepdims=keepdims) - assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) + # Test only axis provided + for shape in [(3,), (2, 3), (2, 3, 3)]: + x_mx = mx.arange(math.prod(shape)).reshape(shape) + x_np = np.arange(math.prod(shape)).reshape(shape) + + for num_axes in range(1, len(shape)): + for axis in itertools.combinations(range(len(shape)), num_axes): + for keepdims in [True, False]: + out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims) + out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims) + assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) + + # Test only ord provided + for shape in [(3,), (2, 3)]: + x_mx = mx.arange(math.prod(shape)).reshape(shape) + x_np = np.arange(math.prod(shape)).reshape(shape) + for o in [None, 1, -1, float("inf"), -float("inf")]: + for keepdims in [True, False]: + out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims) + out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims) + assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) if __name__ == "__main__":