mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
renamed vector_norm to norm, implemented norm without provided ord
This commit is contained in:
parent
24da85025f
commit
05203ecd78
@ -3,62 +3,39 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <set>
|
#include <set>
|
||||||
|
#include <sstream>
|
||||||
#include <variant>
|
#include <variant>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/linalg.h"
|
#include "mlx/linalg.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
|
#include "utils.h"
|
||||||
|
|
||||||
namespace mlx::core::linalg {
|
namespace mlx::core::linalg {
|
||||||
|
|
||||||
array vector_norm(
|
inline std::vector<int> get_shape_reducing_over_all_dims(int num_axes) {
|
||||||
|
std::vector<int> shape(num_axes);
|
||||||
|
std::iota(shape.begin(), shape.end(), 0);
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
array norm(
|
||||||
const array& a,
|
const array& a,
|
||||||
const std::variant<double, std::string>& ord,
|
const std::vector<int>& axis,
|
||||||
const std::vector<int>& axes,
|
|
||||||
bool keepdims,
|
bool keepdims,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
return std::visit(
|
auto num_axes = axis.size();
|
||||||
overloaded{
|
|
||||||
[&](double p) {
|
if (num_axes == 0 || num_axes == 1 || num_axes == 2)
|
||||||
if (p >= 1)
|
return sqrt(sum(
|
||||||
return power(
|
abs(a, s) * abs(a, s),
|
||||||
sum(power(abs(a, s), array(p), s), axes, keepdims, s),
|
num_axes ? axis : get_shape_reducing_over_all_dims(a.shape().size()),
|
||||||
array(1.0 / p),
|
keepdims,
|
||||||
s);
|
s));
|
||||||
else if (p == 0)
|
|
||||||
return sum(
|
std::stringstream error_stream;
|
||||||
where(a != 0, array(1), array(0), s), axes, keepdims, s);
|
error_stream << "Invalid axis values" << axis;
|
||||||
else
|
throw std::invalid_argument(error_stream.str());
|
||||||
throw std::invalid_argument(
|
|
||||||
"[core.linalg.norm] p norm is defined only for p >= 1.");
|
|
||||||
},
|
|
||||||
[&](const std::string& norm_type) {
|
|
||||||
if (norm_type == "inf")
|
|
||||||
return max(abs(a, s), axes, keepdims, s);
|
|
||||||
else if (norm_type == "-inf")
|
|
||||||
return min(abs(a, s), axes, keepdims, s);
|
|
||||||
else
|
|
||||||
throw std::invalid_argument(
|
|
||||||
"[core.linalg.norm] Unsupported norm type for a vector.");
|
|
||||||
}},
|
|
||||||
ord);
|
|
||||||
}
|
|
||||||
array vector_norm(
|
|
||||||
const array& a,
|
|
||||||
const std::variant<double, std::string>& ord,
|
|
||||||
bool keepdims,
|
|
||||||
StreamOrDevice s) {
|
|
||||||
return vector_norm(
|
|
||||||
reshape(a, {static_cast<int>(a.size())}), ord, {-1}, keepdims, s);
|
|
||||||
}
|
|
||||||
array vector_norm(
|
|
||||||
const array& a,
|
|
||||||
const std::vector<int>& axes,
|
|
||||||
bool keepdims,
|
|
||||||
StreamOrDevice s) {
|
|
||||||
return vector_norm(a, 2.0, axes, keepdims, s);
|
|
||||||
}
|
|
||||||
array vector_norm(const array& a, bool keepdims, StreamOrDevice s) {
|
|
||||||
return vector_norm(a, 2.0, keepdims, s);
|
|
||||||
}
|
}
|
||||||
} // namespace mlx::core::linalg
|
} // namespace mlx::core::linalg
|
30
mlx/linalg.h
30
mlx/linalg.h
@ -11,35 +11,9 @@
|
|||||||
#include "string.h"
|
#include "string.h"
|
||||||
|
|
||||||
namespace mlx::core::linalg {
|
namespace mlx::core::linalg {
|
||||||
|
array norm(
|
||||||
template <class... Ts>
|
|
||||||
struct overloaded : Ts... {
|
|
||||||
using Ts::operator()...;
|
|
||||||
};
|
|
||||||
template <class... Ts>
|
|
||||||
overloaded(Ts...) -> overloaded<Ts...>;
|
|
||||||
|
|
||||||
/*
|
|
||||||
Computes a vector norm.
|
|
||||||
If axes = {}, x will be flattened before the norm is computed.
|
|
||||||
Otherwise, the norm is computed over axes and the other dimensions are
|
|
||||||
treated as batch dimensions.
|
|
||||||
*/
|
|
||||||
array vector_norm(
|
|
||||||
const array& a,
|
const array& a,
|
||||||
const std::variant<double, std::string>& ord = 2.0,
|
const std::vector<int>& axis = {},
|
||||||
const std::vector<int>& axes = {},
|
|
||||||
bool keepdims = false,
|
bool keepdims = false,
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
array vector_norm(
|
|
||||||
const array& a,
|
|
||||||
const std::variant<double, std::string>& ord = 2.0,
|
|
||||||
bool keepdims = false,
|
|
||||||
StreamOrDevice s = {});
|
|
||||||
array vector_norm(
|
|
||||||
const array& a,
|
|
||||||
const std::vector<int>& axes = {},
|
|
||||||
bool keepdims = false,
|
|
||||||
StreamOrDevice s = {});
|
|
||||||
array vector_norm(const array& a, bool keepdims = false, StreamOrDevice s = {});
|
|
||||||
} // namespace mlx::core::linalg
|
} // namespace mlx::core::linalg
|
@ -10,180 +10,88 @@
|
|||||||
using namespace mlx::core;
|
using namespace mlx::core;
|
||||||
using namespace mlx::core::linalg;
|
using namespace mlx::core::linalg;
|
||||||
|
|
||||||
TEST_CASE("vector_norm") {
|
TEST_CASE("[mlx.core.linalg.norm] no ord") {
|
||||||
// Test 1-norm on a vector
|
array arr_one_d({1, 2, 3});
|
||||||
CHECK(
|
array arr_two_d = reshape(arange(9), {3, 3});
|
||||||
array_equal(vector_norm(ones({3}), 1.0, false), array(3.0)).item<bool>());
|
array arr_three_d = reshape(arange(18), {2, 3, 3});
|
||||||
CHECK(array_equal(vector_norm(ones({3}), 1.0, true), array({3.0}))
|
|
||||||
.item<bool>());
|
CHECK(array_equal(norm(arr_one_d), array(sqrt(1 + 4 + 9))).item<bool>());
|
||||||
// Test 1-norm on a matrix
|
CHECK(array_equal(norm(arr_one_d, {0}), array(sqrt(1 + 4 + 9))).item<bool>());
|
||||||
CHECK(array_equal(
|
CHECK(array_equal(
|
||||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, false), array(36))
|
norm(arr_two_d),
|
||||||
|
array(sqrt(
|
||||||
|
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8)))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(array_equal(
|
CHECK(array_equal(
|
||||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, true), array({36}))
|
norm(arr_two_d, {0}),
|
||||||
|
array(
|
||||||
|
{sqrt(0 + 3 * 3 + 6 * 6),
|
||||||
|
sqrt(1 + 4 * 4 + 7 * 7),
|
||||||
|
sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(array_equal(
|
CHECK(array_equal(
|
||||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {0, 1}, false),
|
norm(arr_two_d, {1}),
|
||||||
array(36))
|
array(
|
||||||
|
{sqrt(0 + 1 + 2 * 2),
|
||||||
|
sqrt(3 * 3 + 4 * 4 + 5 * 5),
|
||||||
|
sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(array_equal(
|
CHECK(array_equal(
|
||||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {0, 1}, true),
|
norm(arr_two_d, {0, 1}),
|
||||||
array({36}, {1, 1}))
|
array(sqrt(
|
||||||
.item<bool>());
|
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8)))
|
||||||
// Over columns
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {1}, false),
|
|
||||||
array({3, 12, 21}))
|
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(array_equal(
|
CHECK(array_equal(
|
||||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {1}, true),
|
norm(arr_three_d, {2}),
|
||||||
array({3, 12, 21}, {3, 1}))
|
array(
|
||||||
|
{
|
||||||
|
sqrt(0 + 1 + 2 * 2),
|
||||||
|
sqrt(3 * 3 + 4 * 4 + 5 * 5),
|
||||||
|
sqrt(6 * 6 + 7 * 7 + 8 * 8),
|
||||||
|
sqrt(9 * 9 + 10 * 10 + 11 * 11),
|
||||||
|
sqrt(12 * 12 + 13 * 13 + 14 * 14),
|
||||||
|
sqrt(15 * 15 + 16 * 16 + 17 * 17),
|
||||||
|
},
|
||||||
|
{2, 3}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(array_equal(
|
CHECK(array_equal(
|
||||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {-1}, false),
|
norm(arr_three_d, {1}),
|
||||||
array({3, 12, 21}))
|
array(
|
||||||
|
{
|
||||||
|
sqrt(0 + 3 * 3 + 6 * 6),
|
||||||
|
sqrt(1 + 4 * 4 + 7 * 7),
|
||||||
|
sqrt(2 * 2 + 5 * 5 + 8 * 8),
|
||||||
|
sqrt(9 * 9 + 12 * 12 + 15 * 15),
|
||||||
|
sqrt(10 * 10 + 13 * 13 + 16 * 16),
|
||||||
|
sqrt(11 * 11 + 14 * 14 + 17 * 17),
|
||||||
|
},
|
||||||
|
{2, 3}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(array_equal(
|
CHECK(array_equal(
|
||||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {-1}, true),
|
norm(arr_three_d, {0}),
|
||||||
array({3, 12, 21}, {3, 1}))
|
array(
|
||||||
.item<bool>());
|
{
|
||||||
// Over rows
|
sqrt(0 + 9 * 9),
|
||||||
CHECK(array_equal(
|
sqrt(1 + 10 * 10),
|
||||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {0}, false),
|
sqrt(2 * 2 + 11 * 11),
|
||||||
array({9, 12, 15}))
|
sqrt(3 * 3 + 12 * 12),
|
||||||
|
sqrt(4 * 4 + 13 * 13),
|
||||||
|
sqrt(5 * 5 + 14 * 14),
|
||||||
|
sqrt(6 * 6 + 15 * 15),
|
||||||
|
sqrt(7 * 7 + 16 * 16),
|
||||||
|
sqrt(8 * 8 + 17 * 17),
|
||||||
|
},
|
||||||
|
{3, 3}))
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
CHECK(array_equal(
|
CHECK(array_equal(
|
||||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {0}, true),
|
norm(arr_three_d, {1, 2}),
|
||||||
array({9, 12, 15}, {1, 3}))
|
array(
|
||||||
.item<bool>());
|
{sqrt(
|
||||||
CHECK(array_equal(
|
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 +
|
||||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {-2}, false),
|
8 * 8),
|
||||||
array({9, 12, 15}))
|
sqrt(
|
||||||
.item<bool>());
|
9 * 9 + 10 * 10 + 11 * 11 + 12 * 12 + 13 * 13 + 14 * 14 +
|
||||||
CHECK(array_equal(
|
15 * 15 + 16 * 16 + 17 * 17)},
|
||||||
vector_norm(reshape(arange(9), {3, 3}), 1.0, {-2}, true),
|
{2}))
|
||||||
array({9, 12, 15}, {1, 3}))
|
|
||||||
.item<bool>());
|
|
||||||
// Test 1-norm on a 3d tensor
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, false), array(153))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(
|
|
||||||
array_equal(
|
|
||||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, true), array({153}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0, 1, 2}, false),
|
|
||||||
array(153))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0, 1, 2}, true),
|
|
||||||
array({153}, {1, 1, 1}))
|
|
||||||
.item<bool>());
|
|
||||||
// Over last axis
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {2}, false),
|
|
||||||
array({3, 12, 21, 30, 39, 48}, {2, 3}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {2}, true),
|
|
||||||
array({3, 12, 21, 30, 39, 48}, {2, 3, 1}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-1}, false),
|
|
||||||
array({3, 12, 21, 30, 39, 48}, {2, 3}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-1}, true),
|
|
||||||
array({3, 12, 21, 30, 39, 48}, {2, 3, 1}))
|
|
||||||
.item<bool>());
|
|
||||||
// Over middle axis
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {1}, false),
|
|
||||||
array({9, 12, 15, 36, 39, 42}, {2, 3}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {1}, true),
|
|
||||||
array({9, 12, 15, 36, 39, 42}, {2, 1, 3}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-2}, false),
|
|
||||||
array({9, 12, 15, 36, 39, 42}, {2, 3}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-2}, true),
|
|
||||||
array({9, 12, 15, 36, 39, 42}, {2, 1, 3}))
|
|
||||||
.item<bool>());
|
|
||||||
// Over the first axis
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0}, false),
|
|
||||||
array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {3, 3}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0}, true),
|
|
||||||
array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {1, 3, 3}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-3}, false),
|
|
||||||
array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {3, 3}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-3}, true),
|
|
||||||
array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {1, 3, 3}))
|
|
||||||
.item<bool>());
|
|
||||||
// Test 2-norm on a vector
|
|
||||||
CHECK(array_equal(vector_norm({3.0, 4.0}, 2.0, false), array(5.0))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(vector_norm({3.0, 4.0}, 2.0, true), array({5.0}))
|
|
||||||
.item<bool>());
|
|
||||||
// Test that 2 is default ord
|
|
||||||
CHECK(array_equal(vector_norm({3.0, 4.0}, false), array(5.0)).item<bool>());
|
|
||||||
CHECK(array_equal(vector_norm({3.0, 4.0}, true), array({5.0})).item<bool>());
|
|
||||||
// Test "inf" norm on a matrix
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(9), {3, 3}), "inf", false), array(8.0))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(9), {3, 3}), "inf", true), array({8.0}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(9), {3, 3}), "inf", {1}, false),
|
|
||||||
array({2, 5, 8}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(9), {3, 3}), "inf", {1}, true),
|
|
||||||
array({2, 5, 8}, {3, 1}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(9), {3, 3}), "inf", {0}, false),
|
|
||||||
array({6.0, 7.0, 8.0}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(9), {3, 3}), "inf", {0}, true),
|
|
||||||
array({6, 7, 8}, {1, 3}))
|
|
||||||
.item<bool>());
|
|
||||||
// Test "-inf" norm on a matrix
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(9), {3, 3}), "-inf", false), array(0))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(9), {3, 3}), "-inf", true), array({0}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(9), {3, 3}), "-inf", {1}, false),
|
|
||||||
array({0, 3, 6}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(9), {3, 3}), "-inf", {1}, true),
|
|
||||||
array({0, 3, 6}, {3, 1}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(9), {3, 3}), "-inf", {0}, false),
|
|
||||||
array({0, 1, 2}))
|
|
||||||
.item<bool>());
|
|
||||||
CHECK(array_equal(
|
|
||||||
vector_norm(reshape(arange(9), {3, 3}), "-inf", {0}, true),
|
|
||||||
array({0, 1, 2}, {1, 3}))
|
|
||||||
.item<bool>());
|
.item<bool>());
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user