diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index d96dd8a2d..de5ccfcf3 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -44,7 +44,9 @@ void init_linalg(py::module_& parent_module) { return norm( a, "inf", - get_reduce_axes(axis, a.ndim()), + std::holds_alternative(axis) + ? std::vector() + : get_reduce_axes(axis, a.ndim()), keepdims, stream); } @@ -56,15 +58,32 @@ void init_linalg(py::module_& parent_module) { stream); } return norm( - a, p, get_reduce_axes(axis, a.ndim()), keepdims, stream); + a, + p, + std::holds_alternative(axis) + ? std::vector() + : get_reduce_axes(axis, a.ndim()), + keepdims, + stream); }, [&](const std::string& p) { return norm( - a, p, get_reduce_axes(axis, a.ndim()), keepdims, stream); + a, + p, + std::holds_alternative(axis) + ? std::vector() + : get_reduce_axes(axis, a.ndim()), + keepdims, + stream); }, [&](const std::monostate _) { return norm( - a, get_reduce_axes(axis, a.ndim()), keepdims, stream); + a, + std::holds_alternative(axis) + ? std::vector() + : get_reduce_axes(axis, a.ndim()), + keepdims, + stream); }}, ord); }, diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index 6c6d34699..ce1926de0 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -24,6 +24,12 @@ class TestLinalg(mlx_tests.MLXTestCase): ords = vector_ords else: ords = matrix_ords + for keepdims in [True, False]: + # Test axis provided, no ord provided + out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims) + out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims) + assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) + # Test both ord and axis provided for o in ords: for keepdims in [True, False]: if o: @@ -64,6 +70,16 @@ class TestLinalg(mlx_tests.MLXTestCase): out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims) assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) + # Test no ord and no axis provided + for shape in [(3,), (2, 3), (2, 3, 3)]: + x_mx = mx.arange(math.prod(shape)).reshape(shape) + x_np = np.arange(math.prod(shape)).reshape(shape) + for o in [None, 1, -1, float("inf"), -float("inf")]: + for keepdims in [True, False]: + out_np = np.linalg.norm(x_np, keepdims=keepdims) + out_mx = mx.linalg.norm(x_mx, keepdims=keepdims) + assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) + if __name__ == "__main__": unittest.main()