diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index f40447954..b49713afa 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -3,62 +3,39 @@ #include #include #include +#include #include +#include #include "mlx/array.h" #include "mlx/linalg.h" #include "mlx/ops.h" +#include "utils.h" namespace mlx::core::linalg { -array vector_norm( +inline std::vector get_shape_reducing_over_all_dims(int num_axes) { + std::vector shape(num_axes); + std::iota(shape.begin(), shape.end(), 0); + return shape; +} + +array norm( const array& a, - const std::variant& ord, - const std::vector& axes, + const std::vector& axis, bool keepdims, StreamOrDevice s) { - return std::visit( - overloaded{ - [&](double p) { - if (p >= 1) - return power( - sum(power(abs(a, s), array(p), s), axes, keepdims, s), - array(1.0 / p), - s); - else if (p == 0) - return sum( - where(a != 0, array(1), array(0), s), axes, keepdims, s); - else - throw std::invalid_argument( - "[core.linalg.norm] p norm is defined only for p >= 1."); - }, - [&](const std::string& norm_type) { - if (norm_type == "inf") - return max(abs(a, s), axes, keepdims, s); - else if (norm_type == "-inf") - return min(abs(a, s), axes, keepdims, s); - else - throw std::invalid_argument( - "[core.linalg.norm] Unsupported norm type for a vector."); - }}, - ord); -} -array vector_norm( - const array& a, - const std::variant& ord, - bool keepdims, - StreamOrDevice s) { - return vector_norm( - reshape(a, {static_cast(a.size())}), ord, {-1}, keepdims, s); -} -array vector_norm( - const array& a, - const std::vector& axes, - bool keepdims, - StreamOrDevice s) { - return vector_norm(a, 2.0, axes, keepdims, s); -} -array vector_norm(const array& a, bool keepdims, StreamOrDevice s) { - return vector_norm(a, 2.0, keepdims, s); + auto num_axes = axis.size(); + + if (num_axes == 0 || num_axes == 1 || num_axes == 2) + return sqrt(sum( + abs(a, s) * abs(a, s), + num_axes ? axis : get_shape_reducing_over_all_dims(a.shape().size()), + keepdims, + s)); + + std::stringstream error_stream; + error_stream << "Invalid axis values" << axis; + throw std::invalid_argument(error_stream.str()); } } // namespace mlx::core::linalg \ No newline at end of file diff --git a/mlx/linalg.h b/mlx/linalg.h index dc7d8d29d..fa9658bbb 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -11,35 +11,9 @@ #include "string.h" namespace mlx::core::linalg { - -template -struct overloaded : Ts... { - using Ts::operator()...; -}; -template -overloaded(Ts...) -> overloaded; - -/* -Computes a vector norm. - If axes = {}, x will be flattened before the norm is computed. - Otherwise, the norm is computed over axes and the other dimensions are -treated as batch dimensions. -*/ -array vector_norm( +array norm( const array& a, - const std::variant& ord = 2.0, - const std::vector& axes = {}, + const std::vector& axis = {}, bool keepdims = false, StreamOrDevice s = {}); -array vector_norm( - const array& a, - const std::variant& ord = 2.0, - bool keepdims = false, - StreamOrDevice s = {}); -array vector_norm( - const array& a, - const std::vector& axes = {}, - bool keepdims = false, - StreamOrDevice s = {}); -array vector_norm(const array& a, bool keepdims = false, StreamOrDevice s = {}); } // namespace mlx::core::linalg \ No newline at end of file diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp index 6cd74357b..0b6b31801 100644 --- a/tests/linalg_tests.cpp +++ b/tests/linalg_tests.cpp @@ -10,180 +10,88 @@ using namespace mlx::core; using namespace mlx::core::linalg; -TEST_CASE("vector_norm") { - // Test 1-norm on a vector - CHECK( - array_equal(vector_norm(ones({3}), 1.0, false), array(3.0)).item()); - CHECK(array_equal(vector_norm(ones({3}), 1.0, true), array({3.0})) - .item()); - // Test 1-norm on a matrix +TEST_CASE("[mlx.core.linalg.norm] no ord") { + array arr_one_d({1, 2, 3}); + array arr_two_d = reshape(arange(9), {3, 3}); + array arr_three_d = reshape(arange(18), {2, 3, 3}); + + CHECK(array_equal(norm(arr_one_d), array(sqrt(1 + 4 + 9))).item()); + CHECK(array_equal(norm(arr_one_d, {0}), array(sqrt(1 + 4 + 9))).item()); CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, false), array(36)) + norm(arr_two_d), + array(sqrt( + 0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8))) .item()); CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, true), array({36})) + norm(arr_two_d, {0}), + array( + {sqrt(0 + 3 * 3 + 6 * 6), + sqrt(1 + 4 * 4 + 7 * 7), + sqrt(2 * 2 + 5 * 5 + 8 * 8)})) .item()); CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, {0, 1}, false), - array(36)) + norm(arr_two_d, {1}), + array( + {sqrt(0 + 1 + 2 * 2), + sqrt(3 * 3 + 4 * 4 + 5 * 5), + sqrt(6 * 6 + 7 * 7 + 8 * 8)})) .item()); CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, {0, 1}, true), - array({36}, {1, 1})) - .item()); - // Over columns - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, {1}, false), - array({3, 12, 21})) + norm(arr_two_d, {0, 1}), + array(sqrt( + 0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8))) .item()); CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, {1}, true), - array({3, 12, 21}, {3, 1})) + norm(arr_three_d, {2}), + array( + { + sqrt(0 + 1 + 2 * 2), + sqrt(3 * 3 + 4 * 4 + 5 * 5), + sqrt(6 * 6 + 7 * 7 + 8 * 8), + sqrt(9 * 9 + 10 * 10 + 11 * 11), + sqrt(12 * 12 + 13 * 13 + 14 * 14), + sqrt(15 * 15 + 16 * 16 + 17 * 17), + }, + {2, 3})) .item()); CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, {-1}, false), - array({3, 12, 21})) + norm(arr_three_d, {1}), + array( + { + sqrt(0 + 3 * 3 + 6 * 6), + sqrt(1 + 4 * 4 + 7 * 7), + sqrt(2 * 2 + 5 * 5 + 8 * 8), + sqrt(9 * 9 + 12 * 12 + 15 * 15), + sqrt(10 * 10 + 13 * 13 + 16 * 16), + sqrt(11 * 11 + 14 * 14 + 17 * 17), + }, + {2, 3})) .item()); CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, {-1}, true), - array({3, 12, 21}, {3, 1})) - .item()); - // Over rows - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, {0}, false), - array({9, 12, 15})) + norm(arr_three_d, {0}), + array( + { + sqrt(0 + 9 * 9), + sqrt(1 + 10 * 10), + sqrt(2 * 2 + 11 * 11), + sqrt(3 * 3 + 12 * 12), + sqrt(4 * 4 + 13 * 13), + sqrt(5 * 5 + 14 * 14), + sqrt(6 * 6 + 15 * 15), + sqrt(7 * 7 + 16 * 16), + sqrt(8 * 8 + 17 * 17), + }, + {3, 3})) .item()); CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, {0}, true), - array({9, 12, 15}, {1, 3})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, {-2}, false), - array({9, 12, 15})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, {-2}, true), - array({9, 12, 15}, {1, 3})) - .item()); - // Test 1-norm on a 3d tensor - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, false), array(153)) - .item()); - CHECK( - array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, true), array({153})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0, 1, 2}, false), - array(153)) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0, 1, 2}, true), - array({153}, {1, 1, 1})) - .item()); - // Over last axis - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {2}, false), - array({3, 12, 21, 30, 39, 48}, {2, 3})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {2}, true), - array({3, 12, 21, 30, 39, 48}, {2, 3, 1})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-1}, false), - array({3, 12, 21, 30, 39, 48}, {2, 3})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-1}, true), - array({3, 12, 21, 30, 39, 48}, {2, 3, 1})) - .item()); - // Over middle axis - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {1}, false), - array({9, 12, 15, 36, 39, 42}, {2, 3})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {1}, true), - array({9, 12, 15, 36, 39, 42}, {2, 1, 3})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-2}, false), - array({9, 12, 15, 36, 39, 42}, {2, 3})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-2}, true), - array({9, 12, 15, 36, 39, 42}, {2, 1, 3})) - .item()); - // Over the first axis - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0}, false), - array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {3, 3})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0}, true), - array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {1, 3, 3})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-3}, false), - array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {3, 3})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-3}, true), - array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {1, 3, 3})) - .item()); - // Test 2-norm on a vector - CHECK(array_equal(vector_norm({3.0, 4.0}, 2.0, false), array(5.0)) - .item()); - CHECK(array_equal(vector_norm({3.0, 4.0}, 2.0, true), array({5.0})) - .item()); - // Test that 2 is default ord - CHECK(array_equal(vector_norm({3.0, 4.0}, false), array(5.0)).item()); - CHECK(array_equal(vector_norm({3.0, 4.0}, true), array({5.0})).item()); - // Test "inf" norm on a matrix - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "inf", false), array(8.0)) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "inf", true), array({8.0})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "inf", {1}, false), - array({2, 5, 8})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "inf", {1}, true), - array({2, 5, 8}, {3, 1})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "inf", {0}, false), - array({6.0, 7.0, 8.0})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "inf", {0}, true), - array({6, 7, 8}, {1, 3})) - .item()); - // Test "-inf" norm on a matrix - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "-inf", false), array(0)) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "-inf", true), array({0})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "-inf", {1}, false), - array({0, 3, 6})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "-inf", {1}, true), - array({0, 3, 6}, {3, 1})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "-inf", {0}, false), - array({0, 1, 2})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "-inf", {0}, true), - array({0, 1, 2}, {1, 3})) + norm(arr_three_d, {1, 2}), + array( + {sqrt( + 0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + + 8 * 8), + sqrt( + 9 * 9 + 10 * 10 + 11 * 11 + 12 * 12 + 13 * 13 + 14 * 14 + + 15 * 15 + 16 * 16 + 17 * 17)}, + {2})) .item()); } \ No newline at end of file