mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
implemented vector_norm in cpp
added linalg to mlx
This commit is contained in:
parent
447bc089b9
commit
cc9b2dc3c2
@ -14,6 +14,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
|
||||
)
|
||||
|
||||
|
64
mlx/linalg.cpp
Normal file
64
mlx/linalg.cpp
Normal file
@ -0,0 +1,64 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
#include <variant>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/ops.h"
|
||||
|
||||
namespace mlx::core::linalg {
|
||||
|
||||
array vector_norm(
|
||||
const array& a,
|
||||
const std::variant<double, std::string>& ord,
|
||||
const std::vector<int>& axes,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
return std::visit(
|
||||
overloaded{
|
||||
[&](double p) {
|
||||
if (p >= 1)
|
||||
return power(
|
||||
sum(power(abs(a, s), array(p), s), axes, keepdims, s),
|
||||
array(1.0 / p),
|
||||
s);
|
||||
else if (p == 0)
|
||||
return sum(
|
||||
where(a != 0, array(1), array(0), s), axes, keepdims, s);
|
||||
else
|
||||
throw std::invalid_argument(
|
||||
"[core.linalg.norm] p norm is defined only for p >= 1.");
|
||||
},
|
||||
[&](const std::string& norm_type) {
|
||||
if (norm_type == "inf")
|
||||
return max(abs(a, s), axes, keepdims, s);
|
||||
else if (norm_type == "-inf")
|
||||
return min(abs(a, s), axes, keepdims, s);
|
||||
else
|
||||
throw std::invalid_argument(
|
||||
"[core.linalg.norm] Unsupported norm type for a vector.");
|
||||
}},
|
||||
ord);
|
||||
}
|
||||
array vector_norm(
|
||||
const array& a,
|
||||
const std::variant<double, std::string>& ord,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
return vector_norm(
|
||||
reshape(a, {static_cast<int>(a.size())}), ord, {-1}, keepdims, s);
|
||||
}
|
||||
array vector_norm(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
return vector_norm(a, 2.0, axes, keepdims, s);
|
||||
}
|
||||
array vector_norm(const array& a, bool keepdims, StreamOrDevice s) {
|
||||
return vector_norm(a, 2.0, keepdims, s);
|
||||
}
|
||||
} // namespace mlx::core::linalg
|
45
mlx/linalg.h
Normal file
45
mlx/linalg.h
Normal file
@ -0,0 +1,45 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <variant>
|
||||
|
||||
#include "array.h"
|
||||
#include "device.h"
|
||||
#include "ops.h"
|
||||
#include "stream.h"
|
||||
#include "string.h"
|
||||
|
||||
namespace mlx::core::linalg {
|
||||
|
||||
template <class... Ts>
|
||||
struct overloaded : Ts... {
|
||||
using Ts::operator()...;
|
||||
};
|
||||
template <class... Ts>
|
||||
overloaded(Ts...) -> overloaded<Ts...>;
|
||||
|
||||
/*
|
||||
Computes a vector norm.
|
||||
If axes = {}, x will be flattened before the norm is computed.
|
||||
Otherwise, the norm is computed over axes and the other dimensions are
|
||||
treated as batch dimensions.
|
||||
*/
|
||||
array vector_norm(
|
||||
const array& a,
|
||||
const std::variant<double, std::string>& ord = 2.0,
|
||||
const std::vector<int>& axes = {},
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {});
|
||||
array vector_norm(
|
||||
const array& a,
|
||||
const std::variant<double, std::string>& ord = 2.0,
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {});
|
||||
array vector_norm(
|
||||
const array& a,
|
||||
const std::vector<int>& axes = {},
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {});
|
||||
array vector_norm(const array& a, bool keepdims = false, StreamOrDevice s = {});
|
||||
} // namespace mlx::core::linalg
|
@ -6,6 +6,7 @@
|
||||
#include "mlx/backend/metal/metal.h"
|
||||
#include "mlx/device.h"
|
||||
#include "mlx/fft.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/random.h"
|
||||
#include "mlx/stream.h"
|
||||
|
@ -31,6 +31,7 @@ target_sources(tests PRIVATE
|
||||
scheduler_tests.cpp
|
||||
utils_tests.cpp
|
||||
vmap_tests.cpp
|
||||
linalg_tests.cpp
|
||||
${METAL_TEST_SOURCES}
|
||||
)
|
||||
|
||||
|
189
tests/linalg_tests.cpp
Normal file
189
tests/linalg_tests.cpp
Normal file
@ -0,0 +1,189 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include "doctest/doctest.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::linalg;
|
||||
|
||||
TEST_CASE("vector_norm") {
|
||||
// Test 1-norm on a vector
|
||||
CHECK(
|
||||
array_equal(vector_norm(ones({3}), 1.0, false), array(3.0)).item<bool>());
|
||||
CHECK(array_equal(vector_norm(ones({3}), 1.0, true), array({3.0}))
|
||||
.item<bool>());
|
||||
// Test 1-norm on a matrix
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, false), array(36))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, true), array({36}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {0, 1}, false),
|
||||
array(36))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {0, 1}, true),
|
||||
array({36}, {1, 1}))
|
||||
.item<bool>());
|
||||
// Over columns
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {1}, false),
|
||||
array({3, 12, 21}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {1}, true),
|
||||
array({3, 12, 21}, {3, 1}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {-1}, false),
|
||||
array({3, 12, 21}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {-1}, true),
|
||||
array({3, 12, 21}, {3, 1}))
|
||||
.item<bool>());
|
||||
// Over rows
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {0}, false),
|
||||
array({9, 12, 15}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {0}, true),
|
||||
array({9, 12, 15}, {1, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {-2}, false),
|
||||
array({9, 12, 15}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {-2}, true),
|
||||
array({9, 12, 15}, {1, 3}))
|
||||
.item<bool>());
|
||||
// Test 1-norm on a 3d tensor
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, false), array(153))
|
||||
.item<bool>());
|
||||
CHECK(
|
||||
array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, true), array({153}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0, 1, 2}, false),
|
||||
array(153))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0, 1, 2}, true),
|
||||
array({153}, {1, 1, 1}))
|
||||
.item<bool>());
|
||||
// Over last axis
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {2}, false),
|
||||
array({3, 12, 21, 30, 39, 48}, {2, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {2}, true),
|
||||
array({3, 12, 21, 30, 39, 48}, {2, 3, 1}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-1}, false),
|
||||
array({3, 12, 21, 30, 39, 48}, {2, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-1}, true),
|
||||
array({3, 12, 21, 30, 39, 48}, {2, 3, 1}))
|
||||
.item<bool>());
|
||||
// Over middle axis
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {1}, false),
|
||||
array({9, 12, 15, 36, 39, 42}, {2, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {1}, true),
|
||||
array({9, 12, 15, 36, 39, 42}, {2, 1, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-2}, false),
|
||||
array({9, 12, 15, 36, 39, 42}, {2, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-2}, true),
|
||||
array({9, 12, 15, 36, 39, 42}, {2, 1, 3}))
|
||||
.item<bool>());
|
||||
// Over the first axis
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0}, false),
|
||||
array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {3, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0}, true),
|
||||
array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {1, 3, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-3}, false),
|
||||
array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {3, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-3}, true),
|
||||
array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {1, 3, 3}))
|
||||
.item<bool>());
|
||||
// Test 2-norm on a vector
|
||||
CHECK(array_equal(vector_norm({3.0, 4.0}, 2.0, false), array(5.0))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(vector_norm({3.0, 4.0}, 2.0, true), array({5.0}))
|
||||
.item<bool>());
|
||||
// Test that 2 is default ord
|
||||
CHECK(array_equal(vector_norm({3.0, 4.0}, false), array(5.0)).item<bool>());
|
||||
CHECK(array_equal(vector_norm({3.0, 4.0}, true), array({5.0})).item<bool>());
|
||||
// Test "inf" norm on a matrix
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "inf", false), array(8.0))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "inf", true), array({8.0}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "inf", {1}, false),
|
||||
array({2, 5, 8}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "inf", {1}, true),
|
||||
array({2, 5, 8}, {3, 1}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "inf", {0}, false),
|
||||
array({6.0, 7.0, 8.0}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "inf", {0}, true),
|
||||
array({6, 7, 8}, {1, 3}))
|
||||
.item<bool>());
|
||||
// Test "-inf" norm on a matrix
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "-inf", false), array(0))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "-inf", true), array({0}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "-inf", {1}, false),
|
||||
array({0, 3, 6}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "-inf", {1}, true),
|
||||
array({0, 3, 6}, {3, 1}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "-inf", {0}, false),
|
||||
array({0, 1, 2}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "-inf", {0}, true),
|
||||
array({0, 1, 2}, {1, 3}))
|
||||
.item<bool>());
|
||||
}
|
Loading…
Reference in New Issue
Block a user