From cc9b2dc3c2f1bfb05dd5e2230aba9faf4ece7d47 Mon Sep 17 00:00:00 2001 From: Gabrijel Boduljak Date: Sun, 17 Dec 2023 02:55:33 +0100 Subject: [PATCH] implemented vector_norm in cpp added linalg to mlx --- mlx/CMakeLists.txt | 1 + mlx/linalg.cpp | 64 ++++++++++++++ mlx/linalg.h | 45 ++++++++++ mlx/mlx.h | 1 + tests/CMakeLists.txt | 1 + tests/linalg_tests.cpp | 189 +++++++++++++++++++++++++++++++++++++++++ 6 files changed, 301 insertions(+) create mode 100644 mlx/linalg.cpp create mode 100644 mlx/linalg.h create mode 100644 tests/linalg_tests.cpp diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index bd28537f1..e004fc3d9 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -14,6 +14,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h ) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp new file mode 100644 index 000000000..f40447954 --- /dev/null +++ b/mlx/linalg.cpp @@ -0,0 +1,64 @@ +// Copyright © 2023 Apple Inc. + +#include +#include +#include +#include + +#include "mlx/array.h" +#include "mlx/linalg.h" +#include "mlx/ops.h" + +namespace mlx::core::linalg { + +array vector_norm( + const array& a, + const std::variant& ord, + const std::vector& axes, + bool keepdims, + StreamOrDevice s) { + return std::visit( + overloaded{ + [&](double p) { + if (p >= 1) + return power( + sum(power(abs(a, s), array(p), s), axes, keepdims, s), + array(1.0 / p), + s); + else if (p == 0) + return sum( + where(a != 0, array(1), array(0), s), axes, keepdims, s); + else + throw std::invalid_argument( + "[core.linalg.norm] p norm is defined only for p >= 1."); + }, + [&](const std::string& norm_type) { + if (norm_type == "inf") + return max(abs(a, s), axes, keepdims, s); + else if (norm_type == "-inf") + return min(abs(a, s), axes, keepdims, s); + else + throw std::invalid_argument( + "[core.linalg.norm] Unsupported norm type for a vector."); + }}, + ord); +} +array vector_norm( + const array& a, + const std::variant& ord, + bool keepdims, + StreamOrDevice s) { + return vector_norm( + reshape(a, {static_cast(a.size())}), ord, {-1}, keepdims, s); +} +array vector_norm( + const array& a, + const std::vector& axes, + bool keepdims, + StreamOrDevice s) { + return vector_norm(a, 2.0, axes, keepdims, s); +} +array vector_norm(const array& a, bool keepdims, StreamOrDevice s) { + return vector_norm(a, 2.0, keepdims, s); +} +} // namespace mlx::core::linalg \ No newline at end of file diff --git a/mlx/linalg.h b/mlx/linalg.h new file mode 100644 index 000000000..dc7d8d29d --- /dev/null +++ b/mlx/linalg.h @@ -0,0 +1,45 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +#include "array.h" +#include "device.h" +#include "ops.h" +#include "stream.h" +#include "string.h" + +namespace mlx::core::linalg { + +template +struct overloaded : Ts... { + using Ts::operator()...; +}; +template +overloaded(Ts...) -> overloaded; + +/* +Computes a vector norm. + If axes = {}, x will be flattened before the norm is computed. + Otherwise, the norm is computed over axes and the other dimensions are +treated as batch dimensions. +*/ +array vector_norm( + const array& a, + const std::variant& ord = 2.0, + const std::vector& axes = {}, + bool keepdims = false, + StreamOrDevice s = {}); +array vector_norm( + const array& a, + const std::variant& ord = 2.0, + bool keepdims = false, + StreamOrDevice s = {}); +array vector_norm( + const array& a, + const std::vector& axes = {}, + bool keepdims = false, + StreamOrDevice s = {}); +array vector_norm(const array& a, bool keepdims = false, StreamOrDevice s = {}); +} // namespace mlx::core::linalg \ No newline at end of file diff --git a/mlx/mlx.h b/mlx/mlx.h index 102d2dde9..8d785c39f 100644 --- a/mlx/mlx.h +++ b/mlx/mlx.h @@ -6,6 +6,7 @@ #include "mlx/backend/metal/metal.h" #include "mlx/device.h" #include "mlx/fft.h" +#include "mlx/linalg.h" #include "mlx/ops.h" #include "mlx/random.h" #include "mlx/stream.h" diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 0879aa0f6..dbc499205 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -31,6 +31,7 @@ target_sources(tests PRIVATE scheduler_tests.cpp utils_tests.cpp vmap_tests.cpp + linalg_tests.cpp ${METAL_TEST_SOURCES} ) diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp new file mode 100644 index 000000000..6cd74357b --- /dev/null +++ b/tests/linalg_tests.cpp @@ -0,0 +1,189 @@ +// Copyright © 2023 Apple Inc. + +#include "doctest/doctest.h" + +#include +#include +#include "mlx/linalg.h" +#include "mlx/mlx.h" + +using namespace mlx::core; +using namespace mlx::core::linalg; + +TEST_CASE("vector_norm") { + // Test 1-norm on a vector + CHECK( + array_equal(vector_norm(ones({3}), 1.0, false), array(3.0)).item()); + CHECK(array_equal(vector_norm(ones({3}), 1.0, true), array({3.0})) + .item()); + // Test 1-norm on a matrix + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, false), array(36)) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, true), array({36})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, {0, 1}, false), + array(36)) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, {0, 1}, true), + array({36}, {1, 1})) + .item()); + // Over columns + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, {1}, false), + array({3, 12, 21})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, {1}, true), + array({3, 12, 21}, {3, 1})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, {-1}, false), + array({3, 12, 21})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, {-1}, true), + array({3, 12, 21}, {3, 1})) + .item()); + // Over rows + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, {0}, false), + array({9, 12, 15})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, {0}, true), + array({9, 12, 15}, {1, 3})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, {-2}, false), + array({9, 12, 15})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, {-2}, true), + array({9, 12, 15}, {1, 3})) + .item()); + // Test 1-norm on a 3d tensor + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, false), array(153)) + .item()); + CHECK( + array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, true), array({153})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0, 1, 2}, false), + array(153)) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0, 1, 2}, true), + array({153}, {1, 1, 1})) + .item()); + // Over last axis + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {2}, false), + array({3, 12, 21, 30, 39, 48}, {2, 3})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {2}, true), + array({3, 12, 21, 30, 39, 48}, {2, 3, 1})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-1}, false), + array({3, 12, 21, 30, 39, 48}, {2, 3})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-1}, true), + array({3, 12, 21, 30, 39, 48}, {2, 3, 1})) + .item()); + // Over middle axis + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {1}, false), + array({9, 12, 15, 36, 39, 42}, {2, 3})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {1}, true), + array({9, 12, 15, 36, 39, 42}, {2, 1, 3})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-2}, false), + array({9, 12, 15, 36, 39, 42}, {2, 3})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-2}, true), + array({9, 12, 15, 36, 39, 42}, {2, 1, 3})) + .item()); + // Over the first axis + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0}, false), + array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {3, 3})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0}, true), + array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {1, 3, 3})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-3}, false), + array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {3, 3})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-3}, true), + array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {1, 3, 3})) + .item()); + // Test 2-norm on a vector + CHECK(array_equal(vector_norm({3.0, 4.0}, 2.0, false), array(5.0)) + .item()); + CHECK(array_equal(vector_norm({3.0, 4.0}, 2.0, true), array({5.0})) + .item()); + // Test that 2 is default ord + CHECK(array_equal(vector_norm({3.0, 4.0}, false), array(5.0)).item()); + CHECK(array_equal(vector_norm({3.0, 4.0}, true), array({5.0})).item()); + // Test "inf" norm on a matrix + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "inf", false), array(8.0)) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "inf", true), array({8.0})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "inf", {1}, false), + array({2, 5, 8})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "inf", {1}, true), + array({2, 5, 8}, {3, 1})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "inf", {0}, false), + array({6.0, 7.0, 8.0})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "inf", {0}, true), + array({6, 7, 8}, {1, 3})) + .item()); + // Test "-inf" norm on a matrix + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "-inf", false), array(0)) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "-inf", true), array({0})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "-inf", {1}, false), + array({0, 3, 6})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "-inf", {1}, true), + array({0, 3, 6}, {3, 1})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "-inf", {0}, false), + array({0, 1, 2})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "-inf", {0}, true), + array({0, 1, 2}, {1, 3})) + .item()); +} \ No newline at end of file