mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
renamed vector_norm to norm, implemented norm without provided ord
This commit is contained in:
parent
24da85025f
commit
05203ecd78
@ -3,62 +3,39 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace mlx::core::linalg {
|
||||
|
||||
array vector_norm(
|
||||
inline std::vector<int> get_shape_reducing_over_all_dims(int num_axes) {
|
||||
std::vector<int> shape(num_axes);
|
||||
std::iota(shape.begin(), shape.end(), 0);
|
||||
return shape;
|
||||
}
|
||||
|
||||
array norm(
|
||||
const array& a,
|
||||
const std::variant<double, std::string>& ord,
|
||||
const std::vector<int>& axes,
|
||||
const std::vector<int>& axis,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
return std::visit(
|
||||
overloaded{
|
||||
[&](double p) {
|
||||
if (p >= 1)
|
||||
return power(
|
||||
sum(power(abs(a, s), array(p), s), axes, keepdims, s),
|
||||
array(1.0 / p),
|
||||
s);
|
||||
else if (p == 0)
|
||||
return sum(
|
||||
where(a != 0, array(1), array(0), s), axes, keepdims, s);
|
||||
else
|
||||
throw std::invalid_argument(
|
||||
"[core.linalg.norm] p norm is defined only for p >= 1.");
|
||||
},
|
||||
[&](const std::string& norm_type) {
|
||||
if (norm_type == "inf")
|
||||
return max(abs(a, s), axes, keepdims, s);
|
||||
else if (norm_type == "-inf")
|
||||
return min(abs(a, s), axes, keepdims, s);
|
||||
else
|
||||
throw std::invalid_argument(
|
||||
"[core.linalg.norm] Unsupported norm type for a vector.");
|
||||
}},
|
||||
ord);
|
||||
}
|
||||
array vector_norm(
|
||||
const array& a,
|
||||
const std::variant<double, std::string>& ord,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
return vector_norm(
|
||||
reshape(a, {static_cast<int>(a.size())}), ord, {-1}, keepdims, s);
|
||||
}
|
||||
array vector_norm(
|
||||
const array& a,
|
||||
const std::vector<int>& axes,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
return vector_norm(a, 2.0, axes, keepdims, s);
|
||||
}
|
||||
array vector_norm(const array& a, bool keepdims, StreamOrDevice s) {
|
||||
return vector_norm(a, 2.0, keepdims, s);
|
||||
auto num_axes = axis.size();
|
||||
|
||||
if (num_axes == 0 || num_axes == 1 || num_axes == 2)
|
||||
return sqrt(sum(
|
||||
abs(a, s) * abs(a, s),
|
||||
num_axes ? axis : get_shape_reducing_over_all_dims(a.shape().size()),
|
||||
keepdims,
|
||||
s));
|
||||
|
||||
std::stringstream error_stream;
|
||||
error_stream << "Invalid axis values" << axis;
|
||||
throw std::invalid_argument(error_stream.str());
|
||||
}
|
||||
} // namespace mlx::core::linalg
|
30
mlx/linalg.h
30
mlx/linalg.h
@ -11,35 +11,9 @@
|
||||
#include "string.h"
|
||||
|
||||
namespace mlx::core::linalg {
|
||||
|
||||
template <class... Ts>
|
||||
struct overloaded : Ts... {
|
||||
using Ts::operator()...;
|
||||
};
|
||||
template <class... Ts>
|
||||
overloaded(Ts...) -> overloaded<Ts...>;
|
||||
|
||||
/*
|
||||
Computes a vector norm.
|
||||
If axes = {}, x will be flattened before the norm is computed.
|
||||
Otherwise, the norm is computed over axes and the other dimensions are
|
||||
treated as batch dimensions.
|
||||
*/
|
||||
array vector_norm(
|
||||
array norm(
|
||||
const array& a,
|
||||
const std::variant<double, std::string>& ord = 2.0,
|
||||
const std::vector<int>& axes = {},
|
||||
const std::vector<int>& axis = {},
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {});
|
||||
array vector_norm(
|
||||
const array& a,
|
||||
const std::variant<double, std::string>& ord = 2.0,
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {});
|
||||
array vector_norm(
|
||||
const array& a,
|
||||
const std::vector<int>& axes = {},
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {});
|
||||
array vector_norm(const array& a, bool keepdims = false, StreamOrDevice s = {});
|
||||
} // namespace mlx::core::linalg
|
@ -10,180 +10,88 @@
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::linalg;
|
||||
|
||||
TEST_CASE("vector_norm") {
|
||||
// Test 1-norm on a vector
|
||||
CHECK(
|
||||
array_equal(vector_norm(ones({3}), 1.0, false), array(3.0)).item<bool>());
|
||||
CHECK(array_equal(vector_norm(ones({3}), 1.0, true), array({3.0}))
|
||||
.item<bool>());
|
||||
// Test 1-norm on a matrix
|
||||
TEST_CASE("[mlx.core.linalg.norm] no ord") {
|
||||
array arr_one_d({1, 2, 3});
|
||||
array arr_two_d = reshape(arange(9), {3, 3});
|
||||
array arr_three_d = reshape(arange(18), {2, 3, 3});
|
||||
|
||||
CHECK(array_equal(norm(arr_one_d), array(sqrt(1 + 4 + 9))).item<bool>());
|
||||
CHECK(array_equal(norm(arr_one_d, {0}), array(sqrt(1 + 4 + 9))).item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, false), array(36))
|
||||
norm(arr_two_d),
|
||||
array(sqrt(
|
||||
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8)))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, true), array({36}))
|
||||
norm(arr_two_d, {0}),
|
||||
array(
|
||||
{sqrt(0 + 3 * 3 + 6 * 6),
|
||||
sqrt(1 + 4 * 4 + 7 * 7),
|
||||
sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {0, 1}, false),
|
||||
array(36))
|
||||
norm(arr_two_d, {1}),
|
||||
array(
|
||||
{sqrt(0 + 1 + 2 * 2),
|
||||
sqrt(3 * 3 + 4 * 4 + 5 * 5),
|
||||
sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {0, 1}, true),
|
||||
array({36}, {1, 1}))
|
||||
.item<bool>());
|
||||
// Over columns
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {1}, false),
|
||||
array({3, 12, 21}))
|
||||
norm(arr_two_d, {0, 1}),
|
||||
array(sqrt(
|
||||
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8)))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {1}, true),
|
||||
array({3, 12, 21}, {3, 1}))
|
||||
norm(arr_three_d, {2}),
|
||||
array(
|
||||
{
|
||||
sqrt(0 + 1 + 2 * 2),
|
||||
sqrt(3 * 3 + 4 * 4 + 5 * 5),
|
||||
sqrt(6 * 6 + 7 * 7 + 8 * 8),
|
||||
sqrt(9 * 9 + 10 * 10 + 11 * 11),
|
||||
sqrt(12 * 12 + 13 * 13 + 14 * 14),
|
||||
sqrt(15 * 15 + 16 * 16 + 17 * 17),
|
||||
},
|
||||
{2, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {-1}, false),
|
||||
array({3, 12, 21}))
|
||||
norm(arr_three_d, {1}),
|
||||
array(
|
||||
{
|
||||
sqrt(0 + 3 * 3 + 6 * 6),
|
||||
sqrt(1 + 4 * 4 + 7 * 7),
|
||||
sqrt(2 * 2 + 5 * 5 + 8 * 8),
|
||||
sqrt(9 * 9 + 12 * 12 + 15 * 15),
|
||||
sqrt(10 * 10 + 13 * 13 + 16 * 16),
|
||||
sqrt(11 * 11 + 14 * 14 + 17 * 17),
|
||||
},
|
||||
{2, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {-1}, true),
|
||||
array({3, 12, 21}, {3, 1}))
|
||||
.item<bool>());
|
||||
// Over rows
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {0}, false),
|
||||
array({9, 12, 15}))
|
||||
norm(arr_three_d, {0}),
|
||||
array(
|
||||
{
|
||||
sqrt(0 + 9 * 9),
|
||||
sqrt(1 + 10 * 10),
|
||||
sqrt(2 * 2 + 11 * 11),
|
||||
sqrt(3 * 3 + 12 * 12),
|
||||
sqrt(4 * 4 + 13 * 13),
|
||||
sqrt(5 * 5 + 14 * 14),
|
||||
sqrt(6 * 6 + 15 * 15),
|
||||
sqrt(7 * 7 + 16 * 16),
|
||||
sqrt(8 * 8 + 17 * 17),
|
||||
},
|
||||
{3, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {0}, true),
|
||||
array({9, 12, 15}, {1, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {-2}, false),
|
||||
array({9, 12, 15}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {-2}, true),
|
||||
array({9, 12, 15}, {1, 3}))
|
||||
.item<bool>());
|
||||
// Test 1-norm on a 3d tensor
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, false), array(153))
|
||||
.item<bool>());
|
||||
CHECK(
|
||||
array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, true), array({153}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0, 1, 2}, false),
|
||||
array(153))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0, 1, 2}, true),
|
||||
array({153}, {1, 1, 1}))
|
||||
.item<bool>());
|
||||
// Over last axis
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {2}, false),
|
||||
array({3, 12, 21, 30, 39, 48}, {2, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {2}, true),
|
||||
array({3, 12, 21, 30, 39, 48}, {2, 3, 1}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-1}, false),
|
||||
array({3, 12, 21, 30, 39, 48}, {2, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-1}, true),
|
||||
array({3, 12, 21, 30, 39, 48}, {2, 3, 1}))
|
||||
.item<bool>());
|
||||
// Over middle axis
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {1}, false),
|
||||
array({9, 12, 15, 36, 39, 42}, {2, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {1}, true),
|
||||
array({9, 12, 15, 36, 39, 42}, {2, 1, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-2}, false),
|
||||
array({9, 12, 15, 36, 39, 42}, {2, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-2}, true),
|
||||
array({9, 12, 15, 36, 39, 42}, {2, 1, 3}))
|
||||
.item<bool>());
|
||||
// Over the first axis
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0}, false),
|
||||
array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {3, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0}, true),
|
||||
array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {1, 3, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-3}, false),
|
||||
array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {3, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-3}, true),
|
||||
array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {1, 3, 3}))
|
||||
.item<bool>());
|
||||
// Test 2-norm on a vector
|
||||
CHECK(array_equal(vector_norm({3.0, 4.0}, 2.0, false), array(5.0))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(vector_norm({3.0, 4.0}, 2.0, true), array({5.0}))
|
||||
.item<bool>());
|
||||
// Test that 2 is default ord
|
||||
CHECK(array_equal(vector_norm({3.0, 4.0}, false), array(5.0)).item<bool>());
|
||||
CHECK(array_equal(vector_norm({3.0, 4.0}, true), array({5.0})).item<bool>());
|
||||
// Test "inf" norm on a matrix
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "inf", false), array(8.0))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "inf", true), array({8.0}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "inf", {1}, false),
|
||||
array({2, 5, 8}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "inf", {1}, true),
|
||||
array({2, 5, 8}, {3, 1}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "inf", {0}, false),
|
||||
array({6.0, 7.0, 8.0}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "inf", {0}, true),
|
||||
array({6, 7, 8}, {1, 3}))
|
||||
.item<bool>());
|
||||
// Test "-inf" norm on a matrix
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "-inf", false), array(0))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "-inf", true), array({0}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "-inf", {1}, false),
|
||||
array({0, 3, 6}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "-inf", {1}, true),
|
||||
array({0, 3, 6}, {3, 1}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "-inf", {0}, false),
|
||||
array({0, 1, 2}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
vector_norm(reshape(arange(9), {3, 3}), "-inf", {0}, true),
|
||||
array({0, 1, 2}, {1, 3}))
|
||||
norm(arr_three_d, {1, 2}),
|
||||
array(
|
||||
{sqrt(
|
||||
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 +
|
||||
8 * 8),
|
||||
sqrt(
|
||||
9 * 9 + 10 * 10 + 11 * 11 + 12 * 12 + 13 * 13 + 14 * 14 +
|
||||
15 * 15 + 16 * 16 + 17 * 17)},
|
||||
{2}))
|
||||
.item<bool>());
|
||||
}
|
Loading…
Reference in New Issue
Block a user