renamed vector_norm to norm, implemented norm without provided ord

This commit is contained in:
Gabrijel Boduljak 2023-12-20 03:13:18 +01:00 committed by Awni Hannun
parent 24da85025f
commit 05203ecd78
3 changed files with 93 additions and 234 deletions

View File

@ -3,62 +3,39 @@
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <set> #include <set>
#include <sstream>
#include <variant> #include <variant>
#include <vector>
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/linalg.h" #include "mlx/linalg.h"
#include "mlx/ops.h" #include "mlx/ops.h"
#include "utils.h"
namespace mlx::core::linalg { namespace mlx::core::linalg {
array vector_norm( inline std::vector<int> get_shape_reducing_over_all_dims(int num_axes) {
std::vector<int> shape(num_axes);
std::iota(shape.begin(), shape.end(), 0);
return shape;
}
array norm(
const array& a, const array& a,
const std::variant<double, std::string>& ord, const std::vector<int>& axis,
const std::vector<int>& axes,
bool keepdims, bool keepdims,
StreamOrDevice s) { StreamOrDevice s) {
return std::visit( auto num_axes = axis.size();
overloaded{
[&](double p) { if (num_axes == 0 || num_axes == 1 || num_axes == 2)
if (p >= 1) return sqrt(sum(
return power( abs(a, s) * abs(a, s),
sum(power(abs(a, s), array(p), s), axes, keepdims, s), num_axes ? axis : get_shape_reducing_over_all_dims(a.shape().size()),
array(1.0 / p), keepdims,
s); s));
else if (p == 0)
return sum( std::stringstream error_stream;
where(a != 0, array(1), array(0), s), axes, keepdims, s); error_stream << "Invalid axis values" << axis;
else throw std::invalid_argument(error_stream.str());
throw std::invalid_argument(
"[core.linalg.norm] p norm is defined only for p >= 1.");
},
[&](const std::string& norm_type) {
if (norm_type == "inf")
return max(abs(a, s), axes, keepdims, s);
else if (norm_type == "-inf")
return min(abs(a, s), axes, keepdims, s);
else
throw std::invalid_argument(
"[core.linalg.norm] Unsupported norm type for a vector.");
}},
ord);
}
array vector_norm(
const array& a,
const std::variant<double, std::string>& ord,
bool keepdims,
StreamOrDevice s) {
return vector_norm(
reshape(a, {static_cast<int>(a.size())}), ord, {-1}, keepdims, s);
}
array vector_norm(
const array& a,
const std::vector<int>& axes,
bool keepdims,
StreamOrDevice s) {
return vector_norm(a, 2.0, axes, keepdims, s);
}
array vector_norm(const array& a, bool keepdims, StreamOrDevice s) {
return vector_norm(a, 2.0, keepdims, s);
} }
} // namespace mlx::core::linalg } // namespace mlx::core::linalg

View File

@ -11,35 +11,9 @@
#include "string.h" #include "string.h"
namespace mlx::core::linalg { namespace mlx::core::linalg {
array norm(
template <class... Ts>
struct overloaded : Ts... {
using Ts::operator()...;
};
template <class... Ts>
overloaded(Ts...) -> overloaded<Ts...>;
/*
Computes a vector norm.
If axes = {}, x will be flattened before the norm is computed.
Otherwise, the norm is computed over axes and the other dimensions are
treated as batch dimensions.
*/
array vector_norm(
const array& a, const array& a,
const std::variant<double, std::string>& ord = 2.0, const std::vector<int>& axis = {},
const std::vector<int>& axes = {},
bool keepdims = false, bool keepdims = false,
StreamOrDevice s = {}); StreamOrDevice s = {});
array vector_norm(
const array& a,
const std::variant<double, std::string>& ord = 2.0,
bool keepdims = false,
StreamOrDevice s = {});
array vector_norm(
const array& a,
const std::vector<int>& axes = {},
bool keepdims = false,
StreamOrDevice s = {});
array vector_norm(const array& a, bool keepdims = false, StreamOrDevice s = {});
} // namespace mlx::core::linalg } // namespace mlx::core::linalg

View File

@ -10,180 +10,88 @@
using namespace mlx::core; using namespace mlx::core;
using namespace mlx::core::linalg; using namespace mlx::core::linalg;
TEST_CASE("vector_norm") { TEST_CASE("[mlx.core.linalg.norm] no ord") {
// Test 1-norm on a vector array arr_one_d({1, 2, 3});
CHECK( array arr_two_d = reshape(arange(9), {3, 3});
array_equal(vector_norm(ones({3}), 1.0, false), array(3.0)).item<bool>()); array arr_three_d = reshape(arange(18), {2, 3, 3});
CHECK(array_equal(vector_norm(ones({3}), 1.0, true), array({3.0}))
.item<bool>()); CHECK(array_equal(norm(arr_one_d), array(sqrt(1 + 4 + 9))).item<bool>());
// Test 1-norm on a matrix CHECK(array_equal(norm(arr_one_d, {0}), array(sqrt(1 + 4 + 9))).item<bool>());
CHECK(array_equal( CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), 1.0, false), array(36)) norm(arr_two_d),
array(sqrt(
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8)))
.item<bool>()); .item<bool>());
CHECK(array_equal( CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), 1.0, true), array({36})) norm(arr_two_d, {0}),
array(
{sqrt(0 + 3 * 3 + 6 * 6),
sqrt(1 + 4 * 4 + 7 * 7),
sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
.item<bool>()); .item<bool>());
CHECK(array_equal( CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), 1.0, {0, 1}, false), norm(arr_two_d, {1}),
array(36)) array(
{sqrt(0 + 1 + 2 * 2),
sqrt(3 * 3 + 4 * 4 + 5 * 5),
sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
.item<bool>()); .item<bool>());
CHECK(array_equal( CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), 1.0, {0, 1}, true), norm(arr_two_d, {0, 1}),
array({36}, {1, 1})) array(sqrt(
.item<bool>()); 0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8)))
// Over columns
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), 1.0, {1}, false),
array({3, 12, 21}))
.item<bool>()); .item<bool>());
CHECK(array_equal( CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), 1.0, {1}, true), norm(arr_three_d, {2}),
array({3, 12, 21}, {3, 1})) array(
{
sqrt(0 + 1 + 2 * 2),
sqrt(3 * 3 + 4 * 4 + 5 * 5),
sqrt(6 * 6 + 7 * 7 + 8 * 8),
sqrt(9 * 9 + 10 * 10 + 11 * 11),
sqrt(12 * 12 + 13 * 13 + 14 * 14),
sqrt(15 * 15 + 16 * 16 + 17 * 17),
},
{2, 3}))
.item<bool>()); .item<bool>());
CHECK(array_equal( CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), 1.0, {-1}, false), norm(arr_three_d, {1}),
array({3, 12, 21})) array(
{
sqrt(0 + 3 * 3 + 6 * 6),
sqrt(1 + 4 * 4 + 7 * 7),
sqrt(2 * 2 + 5 * 5 + 8 * 8),
sqrt(9 * 9 + 12 * 12 + 15 * 15),
sqrt(10 * 10 + 13 * 13 + 16 * 16),
sqrt(11 * 11 + 14 * 14 + 17 * 17),
},
{2, 3}))
.item<bool>()); .item<bool>());
CHECK(array_equal( CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), 1.0, {-1}, true), norm(arr_three_d, {0}),
array({3, 12, 21}, {3, 1})) array(
.item<bool>()); {
// Over rows sqrt(0 + 9 * 9),
CHECK(array_equal( sqrt(1 + 10 * 10),
vector_norm(reshape(arange(9), {3, 3}), 1.0, {0}, false), sqrt(2 * 2 + 11 * 11),
array({9, 12, 15})) sqrt(3 * 3 + 12 * 12),
sqrt(4 * 4 + 13 * 13),
sqrt(5 * 5 + 14 * 14),
sqrt(6 * 6 + 15 * 15),
sqrt(7 * 7 + 16 * 16),
sqrt(8 * 8 + 17 * 17),
},
{3, 3}))
.item<bool>()); .item<bool>());
CHECK(array_equal( CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), 1.0, {0}, true), norm(arr_three_d, {1, 2}),
array({9, 12, 15}, {1, 3})) array(
.item<bool>()); {sqrt(
CHECK(array_equal( 0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 +
vector_norm(reshape(arange(9), {3, 3}), 1.0, {-2}, false), 8 * 8),
array({9, 12, 15})) sqrt(
.item<bool>()); 9 * 9 + 10 * 10 + 11 * 11 + 12 * 12 + 13 * 13 + 14 * 14 +
CHECK(array_equal( 15 * 15 + 16 * 16 + 17 * 17)},
vector_norm(reshape(arange(9), {3, 3}), 1.0, {-2}, true), {2}))
array({9, 12, 15}, {1, 3}))
.item<bool>());
// Test 1-norm on a 3d tensor
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, false), array(153))
.item<bool>());
CHECK(
array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, true), array({153}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0, 1, 2}, false),
array(153))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0, 1, 2}, true),
array({153}, {1, 1, 1}))
.item<bool>());
// Over last axis
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {2}, false),
array({3, 12, 21, 30, 39, 48}, {2, 3}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {2}, true),
array({3, 12, 21, 30, 39, 48}, {2, 3, 1}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-1}, false),
array({3, 12, 21, 30, 39, 48}, {2, 3}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-1}, true),
array({3, 12, 21, 30, 39, 48}, {2, 3, 1}))
.item<bool>());
// Over middle axis
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {1}, false),
array({9, 12, 15, 36, 39, 42}, {2, 3}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {1}, true),
array({9, 12, 15, 36, 39, 42}, {2, 1, 3}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-2}, false),
array({9, 12, 15, 36, 39, 42}, {2, 3}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-2}, true),
array({9, 12, 15, 36, 39, 42}, {2, 1, 3}))
.item<bool>());
// Over the first axis
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0}, false),
array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {3, 3}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0}, true),
array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {1, 3, 3}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-3}, false),
array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {3, 3}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-3}, true),
array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {1, 3, 3}))
.item<bool>());
// Test 2-norm on a vector
CHECK(array_equal(vector_norm({3.0, 4.0}, 2.0, false), array(5.0))
.item<bool>());
CHECK(array_equal(vector_norm({3.0, 4.0}, 2.0, true), array({5.0}))
.item<bool>());
// Test that 2 is default ord
CHECK(array_equal(vector_norm({3.0, 4.0}, false), array(5.0)).item<bool>());
CHECK(array_equal(vector_norm({3.0, 4.0}, true), array({5.0})).item<bool>());
// Test "inf" norm on a matrix
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "inf", false), array(8.0))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "inf", true), array({8.0}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "inf", {1}, false),
array({2, 5, 8}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "inf", {1}, true),
array({2, 5, 8}, {3, 1}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "inf", {0}, false),
array({6.0, 7.0, 8.0}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "inf", {0}, true),
array({6, 7, 8}, {1, 3}))
.item<bool>());
// Test "-inf" norm on a matrix
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "-inf", false), array(0))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "-inf", true), array({0}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "-inf", {1}, false),
array({0, 3, 6}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "-inf", {1}, true),
array({0, 3, 6}, {3, 1}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "-inf", {0}, false),
array({0, 1, 2}))
.item<bool>());
CHECK(array_equal(
vector_norm(reshape(arange(9), {3, 3}), "-inf", {0}, true),
array({0, 1, 2}, {1, 3}))
.item<bool>()); .item<bool>());
} }