diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index c2728c738..00cb81dc4 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -1,6 +1,7 @@ // Copyright © 2023 Apple Inc. +#include #include #include #include @@ -68,6 +69,12 @@ void init_linalg(py::module_& parent_module) { const double ord, const bool keepdims, const StreamOrDevice stream) { + if (std::isinf((float)ord) || std::isinf(ord)) + if (ord > 0) + return norm(a, "inf", {}, keepdims, stream); + else + return norm(a, "-inf", {}, keepdims, stream); + return norm(a, ord, {}, keepdims, stream); }, "a"_a, @@ -82,6 +89,12 @@ void init_linalg(py::module_& parent_module) { const int axis, const bool keepdims, const StreamOrDevice stream) { + if (std::isinf((float)ord) || std::isinf(ord)) + if (ord > 0) + return norm(a, "inf", {axis}, keepdims, stream); + else + return norm(a, "-inf", {axis}, keepdims, stream); + return norm(a, ord, {axis}, keepdims, stream); }, "a"_a, @@ -97,6 +110,12 @@ void init_linalg(py::module_& parent_module) { const std::vector& axis, const bool keepdims, const StreamOrDevice stream) { + if (std::isinf((float)ord) || std::isinf(ord)) + if (ord > 0) + return norm(a, "inf", axis, keepdims, stream); + else + return norm(a, "-inf", axis, keepdims, stream); + return norm(a, ord, axis, keepdims, stream); }, "a"_a, diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index 26e9587c5..1969e1028 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. import itertools +import math import unittest import mlx.core as mx @@ -10,31 +11,42 @@ import numpy as np class TestLinalg(mlx_tests.MLXTestCase): def test_norm(self): - def check_mx_np(a_mx, a_np): - self.assertTrue(np.allclose(a_np, a_mx, atol=1e-5, rtol=1e-6)) + vector_ords = [None, 0.5, 0, 1, 2, 3, -1, 1, float("inf"), -float("inf")] + matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")] - x_mx = mx.arange(18).reshape((2, 3, 3)) - x_np = np.arange(18).reshape((2, 3, 3)) + for shape in [(3,), (2, 3), (2, 3, 3)]: + x_mx = mx.arange(math.prod(shape)).reshape(shape) + x_np = np.arange(math.prod(shape)).reshape(shape) + # Test when at least one axis is provided + for num_axes in range(1, len(shape)): + for axis in itertools.combinations(range(len(shape)), num_axes): + if num_axes == 1: + ords = vector_ords + else: + ords = matrix_ords + for o in ords: + for keepdims in [True, False]: + if o: + out_np = np.linalg.norm( + x_np, ord=o, axis=axis, keepdims=keepdims + ) + out_mx = mx.linalg.norm( + x_mx, ord=o, axis=axis, keepdims=keepdims + ) + else: + out_np = np.linalg.norm( + x_np, axis=axis, keepdims=keepdims + ) + out_mx = mx.linalg.norm( + x_mx, axis=axis, keepdims=keepdims + ) + assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) - for num_axes in range(1, 3): - for axis in itertools.combinations(range(3), num_axes): - if num_axes == 1: - ords = [None, 0.5, 0, 1, 2, 3, -1, 1] - else: - ords = [None, "fro", -1, 1] - for o in ords: - for keepdims in [True, False]: - if o: - out_np = np.linalg.norm( - x_np, ord=o, axis=axis, keepdims=keepdims - ) - out_mx = mx.linalg.norm( - x_mx, ord=o, axis=axis, keepdims=keepdims - ) - else: - out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims) - out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims) - assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) + # Test when no axes and no ords are provided + for keepdims in [True, False]: + out_np = np.linalg.norm(x_np, keepdims=keepdims) + out_mx = mx.linalg.norm(x_mx, keepdims=keepdims) + assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) if __name__ == "__main__":