implemented vector_norm in cpp

added linalg to mlx
This commit is contained in:
Gabrijel Boduljak 2023-12-17 02:55:33 +01:00 committed by Awni Hannun
parent 447bc089b9
commit cc9b2dc3c2
6 changed files with 301 additions and 0 deletions

View File

@ -14,6 +14,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
) )

64
mlx/linalg.cpp Normal file
View File

@ -0,0 +1,64 @@
// Copyright © 2023 Apple Inc.
#include <iostream>
#include <numeric>
#include <set>
#include <variant>
#include "mlx/array.h"
#include "mlx/linalg.h"
#include "mlx/ops.h"
namespace mlx::core::linalg {
array vector_norm(
const array& a,
const std::variant<double, std::string>& ord,
const std::vector<int>& axes,
bool keepdims,
StreamOrDevice s) {
return std::visit(
overloaded{
[&](double p) {
if (p >= 1)
return power(
sum(power(abs(a, s), array(p), s), axes, keepdims, s),
array(1.0 / p),
s);
else if (p == 0)
return sum(
where(a != 0, array(1), array(0), s), axes, keepdims, s);
else
throw std::invalid_argument(
"[core.linalg.norm] p norm is defined only for p >= 1.");
},
[&](const std::string& norm_type) {
if (norm_type == "inf")
return max(abs(a, s), axes, keepdims, s);
else if (norm_type == "-inf")
return min(abs(a, s), axes, keepdims, s);
else
throw std::invalid_argument(
"[core.linalg.norm] Unsupported norm type for a vector.");
}},
ord);
}
array vector_norm(
const array& a,
const std::variant<double, std::string>& ord,
bool keepdims,
StreamOrDevice s) {
return vector_norm(
reshape(a, {static_cast<int>(a.size())}), ord, {-1}, keepdims, s);
}
array vector_norm(
const array& a,
const std::vector<int>& axes,
bool keepdims,
StreamOrDevice s) {
return vector_norm(a, 2.0, axes, keepdims, s);
}
array vector_norm(const array& a, bool keepdims, StreamOrDevice s) {
return vector_norm(a, 2.0, keepdims, s);
}
} // namespace mlx::core::linalg

45
mlx/linalg.h Normal file
View File

@ -0,0 +1,45 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <variant>
#include "array.h"
#include "device.h"
#include "ops.h"
#include "stream.h"
#include "string.h"
namespace mlx::core::linalg {
template <class... Ts>
struct overloaded : Ts... {
using Ts::operator()...;
};
template <class... Ts>
overloaded(Ts...) -> overloaded<Ts...>;
/*
Computes a vector norm.
If axes = {}, x will be flattened before the norm is computed.
Otherwise, the norm is computed over axes and the other dimensions are
treated as batch dimensions.
*/
array vector_norm(
const array& a,
const std::variant<double, std::string>& ord = 2.0,
const std::vector<int>& axes = {},
bool keepdims = false,
StreamOrDevice s = {});
array vector_norm(
const array& a,
const std::variant<double, std::string>& ord = 2.0,
bool keepdims = false,
StreamOrDevice s = {});
array vector_norm(
const array& a,
const std::vector<int>& axes = {},
bool keepdims = false,
StreamOrDevice s = {});
array vector_norm(const array& a, bool keepdims = false, StreamOrDevice s = {});
} // namespace mlx::core::linalg

View File

@ -6,6 +6,7 @@
#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal.h"
#include "mlx/device.h" #include "mlx/device.h"
#include "mlx/fft.h" #include "mlx/fft.h"
#include "mlx/linalg.h"
#include "mlx/ops.h" #include "mlx/ops.h"
#include "mlx/random.h" #include "mlx/random.h"
#include "mlx/stream.h" #include "mlx/stream.h"

View File

@ -31,6 +31,7 @@ target_sources(tests PRIVATE
scheduler_tests.cpp scheduler_tests.cpp
utils_tests.cpp utils_tests.cpp
vmap_tests.cpp vmap_tests.cpp
linalg_tests.cpp
${METAL_TEST_SOURCES} ${METAL_TEST_SOURCES}
) )

189
tests/linalg_tests.cpp Normal file
View File

@ -0,0 +1,189 @@
// Copyright © 2023 Apple Inc.
#include "doctest/doctest.h"
#include <cmath>
#include <iostream>
#include "mlx/linalg.h"
#include "mlx/mlx.h"
using namespace mlx::core;
using namespace mlx::core::linalg;
TEST_CASE("vector_norm") {
// Test 1-norm on a vector
CHECK(
array_equal(vector_norm(ones({3}), 1.0, false), array(3.0)).item<bool>());
CHECK(array_equal(vector_norm(ones({3}), 1.0, true), array({3.0}))
.item<bool>());
// Test 1-norm on a matrix
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), 1.0, false), array(36))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), 1.0, true), array({36}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), 1.0, {0, 1}, false),
array(36))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), 1.0, {0, 1}, true),
array({36}, {1, 1}))
.item<bool>());
// Over columns
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), 1.0, {1}, false),
array({3, 12, 21}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), 1.0, {1}, true),
array({3, 12, 21}, {3, 1}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), 1.0, {-1}, false),
array({3, 12, 21}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), 1.0, {-1}, true),
array({3, 12, 21}, {3, 1}))
.item<bool>());
// Over rows
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), 1.0, {0}, false),
array({9, 12, 15}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), 1.0, {0}, true),
array({9, 12, 15}, {1, 3}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), 1.0, {-2}, false),
array({9, 12, 15}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), 1.0, {-2}, true),
array({9, 12, 15}, {1, 3}))
.item<bool>());
// Test 1-norm on a 3d tensor
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, false), array(153))
.item<bool>());
CHECK(
array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, true), array({153}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0, 1, 2}, false),
array(153))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0, 1, 2}, true),
array({153}, {1, 1, 1}))
.item<bool>());
// Over last axis
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {2}, false),
array({3, 12, 21, 30, 39, 48}, {2, 3}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {2}, true),
array({3, 12, 21, 30, 39, 48}, {2, 3, 1}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-1}, false),
array({3, 12, 21, 30, 39, 48}, {2, 3}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-1}, true),
array({3, 12, 21, 30, 39, 48}, {2, 3, 1}))
.item<bool>());
// Over middle axis
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {1}, false),
array({9, 12, 15, 36, 39, 42}, {2, 3}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {1}, true),
array({9, 12, 15, 36, 39, 42}, {2, 1, 3}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-2}, false),
array({9, 12, 15, 36, 39, 42}, {2, 3}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-2}, true),
array({9, 12, 15, 36, 39, 42}, {2, 1, 3}))
.item<bool>());
// Over the first axis
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0}, false),
array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {3, 3}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0}, true),
array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {1, 3, 3}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-3}, false),
array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {3, 3}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-3}, true),
array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {1, 3, 3}))
.item<bool>());
// Test 2-norm on a vector
CHECK(array_equal(vector_norm({3.0, 4.0}, 2.0, false), array(5.0))
.item<bool>());
CHECK(array_equal(vector_norm({3.0, 4.0}, 2.0, true), array({5.0}))
.item<bool>());
// Test that 2 is default ord
CHECK(array_equal(vector_norm({3.0, 4.0}, false), array(5.0)).item<bool>());
CHECK(array_equal(vector_norm({3.0, 4.0}, true), array({5.0})).item<bool>());
// Test "inf" norm on a matrix
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "inf", false), array(8.0))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "inf", true), array({8.0}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "inf", {1}, false),
array({2, 5, 8}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "inf", {1}, true),
array({2, 5, 8}, {3, 1}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "inf", {0}, false),
array({6.0, 7.0, 8.0}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "inf", {0}, true),
array({6, 7, 8}, {1, 3}))
.item<bool>());
// Test "-inf" norm on a matrix
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "-inf", false), array(0))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "-inf", true), array({0}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "-inf", {1}, false),
array({0, 3, 6}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "-inf", {1}, true),
array({0, 3, 6}, {3, 1}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "-inf", {0}, false),
array({0, 1, 2}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "-inf", {0}, true),
array({0, 1, 2}, {1, 3}))
.item<bool>());
}