completed the implementation of the norm

This commit is contained in:
Gabrijel Boduljak 2023-12-21 18:33:23 +01:00 committed by Awni Hannun
parent 05203ecd78
commit 8c43d820d9
4 changed files with 159 additions and 11 deletions

View File

@ -4,7 +4,7 @@
#include <numeric>
#include <set>
#include <sstream>
#include <variant>
#include <string>
#include <vector>
#include "mlx/array.h"
@ -14,10 +14,77 @@
namespace mlx::core::linalg {
inline std::vector<int> get_shape_reducing_over_all_dims(int num_axes) {
std::vector<int> shape(num_axes);
std::iota(shape.begin(), shape.end(), 0);
return shape;
/**
 * Vector norm of `a` for a numeric order, reduced over `axis`.
 *
 *  - ord == 0: number of non-zero entries (sum of `a != 0`).
 *  - ord == 1: sum of absolute values.
 *  - ord == 2: square root of the sum of squared absolute values.
 *  - otherwise: (sum |a|^ord)^(1 / ord).
 *
 * @param a        Input array.
 * @param ord      Order of the norm.
 * @param axis     Axes to reduce over.
 * @param keepdims Forwarded to the reductions.
 * @param s        Stream or device every operation is scheduled on.
 * @return The reduced array.
 */
inline array vector_norm(
    const array& a,
    const double ord,
    const std::vector<int>& axis,
    bool keepdims,
    StreamOrDevice s) {
  if (ord == 0.0) {
    // NOTE(review): the overloaded `!=` takes no stream argument, so this
    // comparison presumably runs on the default stream -- confirm whether an
    // explicit stream-aware comparison should be used here.
    return sum(a != 0, axis, keepdims, s);
  }
  if (ord == 1.0) {
    return sum(abs(a, s), axis, keepdims, s);
  }
  if (ord == 2.0) {
    // Build |a| once and reuse it; `sqrt` is given the caller's stream like
    // every other operation in this function.
    const auto magnitude = abs(a, s);
    return sqrt(sum(magnitude * magnitude, axis, keepdims, s), s);
  }
  // General p-norm. The outer `power` also receives the stream so the whole
  // computation stays on `s`.
  return power(
      sum(power(abs(a, s), array(ord), s), axis, keepdims, s),
      array(1.0 / ord),
      s);
}
/**
 * Vector norm of `a` for a string order, reduced over `axis`.
 *
 * "inf" yields the maximum absolute value and "-inf" the minimum absolute
 * value. Any other string raises std::invalid_argument.
 */
inline array vector_norm(
    const array& a,
    const std::string& ord,
    const std::vector<int>& axis,
    bool keepdims,
    StreamOrDevice s) {
  const bool wants_max = (ord == "inf");
  // Reject unknown orders before building any operation.
  if (!wants_max && ord != "-inf") {
    throw std::invalid_argument("Invalid ord value " + ord);
  }
  if (wants_max) {
    return max(abs(a, s), axis, keepdims, s);
  }
  return min(abs(a, s), axis, keepdims, s);
}
/**
 * Matrix norm of `a` for a numeric order over the two axes in `axis`
 * (`axis[0]` is treated as the row axis, `axis[1]` as the column axis).
 *
 *  - ord ==  1: maximum over the column axis of the absolute sums taken
 *               along the row axis.
 *  - ord == -1: minimum of those same sums.
 *  - ord == +/-2: throws std::logic_error (not implemented).
 *  - anything else: throws std::invalid_argument.
 */
inline array matrix_norm(
    const array& a,
    const double ord,
    const std::vector<int>& axis,
    bool keepdims,
    StreamOrDevice s) {
  const auto row_axis = axis[0];
  // When the row axis is dropped by the first reduction (keepdims == false),
  // a column axis that came after it moves down by one position.
  const auto col_axis =
      (!keepdims && axis[1] > row_axis) ? axis[1] - 1 : axis[1];

  if (ord == -1.0) {
    const auto abs_sums = sum(abs(a, s), row_axis, keepdims, s);
    return min(abs_sums, col_axis, keepdims, s);
  }
  if (ord == 1.0) {
    const auto abs_sums = sum(abs(a, s), row_axis, keepdims, s);
    return max(abs_sums, col_axis, keepdims, s);
  }
  if (ord == 2.0 || ord == -2.0) {
    throw std::logic_error("Singular value norms are not implemented.");
  }

  // A stream is used so `ord` is formatted with the default double formatting.
  std::stringstream message;
  message << "Invalid ord value " << ord << " for matrix norm";
  throw std::invalid_argument(message.str());
}
/**
 * Matrix norm of `a` for a string order over the two axes in `axis`.
 *
 *  - "f" / "fro": square root of the sum of squared absolute values.
 *  - "inf":  the numeric order-1 matrix norm with the two axes swapped.
 *  - "-inf": the numeric order-(-1) matrix norm with the two axes swapped.
 *  - "nuc":  throws std::logic_error (not implemented).
 *  - anything else: throws std::invalid_argument.
 *
 * @param a        Input array.
 * @param ord      Name of the norm.
 * @param axis     The two axes to reduce over.
 * @param keepdims Forwarded to the reductions.
 * @param s        Stream or device every operation is scheduled on.
 */
inline array matrix_norm(
    const array& a,
    const std::string& ord,
    const std::vector<int>& axis,
    bool keepdims,
    StreamOrDevice s) {
  if (ord == "f" || ord == "fro") {
    // Build |a| once, and give `sqrt` the caller's stream so the final
    // operation is not scheduled on the default stream.
    const auto magnitude = abs(a, s);
    return sqrt(sum(magnitude * magnitude, axis, keepdims, s), s);
  }
  if (ord == "inf") {
    return matrix_norm(a, 1.0, {axis[1], axis[0]}, keepdims, s);
  }
  if (ord == "-inf") {
    return matrix_norm(a, -1.0, {axis[1], axis[0]}, keepdims, s);
  }
  if (ord == "nuc") {
    throw std::logic_error("Nuclear norm is not implemented.");
  }
  std::stringstream error_stream;
  error_stream << "Invalid ord value " << ord << " for matrix norm";
  throw std::invalid_argument(error_stream.str());
}
array norm(
@ -28,14 +95,65 @@ array norm(
auto num_axes = axis.size();
if (num_axes == 0 || num_axes == 1 || num_axes == 2)
return sqrt(sum(
abs(a, s) * abs(a, s),
num_axes ? axis : get_shape_reducing_over_all_dims(a.shape().size()),
return sqrt(
sum(abs(a, s) * abs(a, s),
num_axes ? axis
: get_shape_reducing_over_all_axes(a.shape().size()),
keepdims,
s));
s),
s);
std::stringstream error_stream;
error_stream << "Invalid axis values " << axis;
throw std::invalid_argument(error_stream.str());
}
/**
 * Norm of `a` for a numeric order.
 *
 * An empty `axis` means "reduce over every axis of `a`"; otherwise the given
 * axes are passed through normalize_axes. One resulting axis selects the
 * vector norm, two select the matrix norm, and any other count raises
 * std::invalid_argument.
 */
array norm(
    const array& a,
    const double ord,
    const std::vector<int>& axis,
    bool keepdims,
    StreamOrDevice s) {
  const std::vector<int> reduce_axes = axis.empty()
      ? get_shape_reducing_over_all_axes(a.ndim())
      : normalize_axes(axis, a.ndim());

  switch (reduce_axes.size()) {
    case 1:
      return vector_norm(a, ord, reduce_axes, keepdims, s);
    case 2:
      return matrix_norm(a, ord, reduce_axes, keepdims, s);
    default:
      break;
  }

  std::stringstream message;
  message << "Invalid axis values " << reduce_axes;
  throw std::invalid_argument(message.str());
}
/**
 * Norm of `a` for a string order.
 *
 * An empty `axis` means "reduce over every axis of `a`"; otherwise the given
 * axes are passed through normalize_axes. One resulting axis selects the
 * vector norm, two select the matrix norm, and any other count raises
 * std::invalid_argument.
 */
array norm(
    const array& a,
    const std::string& ord,
    const std::vector<int>& axis,
    bool keepdims,
    StreamOrDevice s) {
  const std::vector<int> reduce_axes = axis.empty()
      ? get_shape_reducing_over_all_axes(a.ndim())
      : normalize_axes(axis, a.ndim());

  switch (reduce_axes.size()) {
    case 1:
      return vector_norm(a, ord, reduce_axes, keepdims, s);
    case 2:
      return matrix_norm(a, ord, reduce_axes, keepdims, s);
    default:
      break;
  }

  std::stringstream message;
  message << "Invalid axis values " << reduce_axes;
  throw std::invalid_argument(message.str());
}
} // namespace mlx::core::linalg

View File

@ -11,6 +11,18 @@
#include "string.h"
namespace mlx::core::linalg {
/**
 * Computes a norm of `a` for a numeric order `ord`.
 *
 * With an empty `axis` the reduction covers all axes of `a`. Reducing over
 * one axis computes a vector norm, reducing over two axes computes a matrix
 * norm; any other number of axes throws std::invalid_argument.
 */
array norm(
const array& a,
const double ord,
const std::vector<int>& axis = {},
bool keepdims = false,
StreamOrDevice s = {});
/**
 * Computes a norm of `a` for a string order `ord`.
 *
 * Vector norms accept "inf" and "-inf"; matrix norms accept "f", "fro",
 * "inf" and "-inf" ("nuc" is recognised but throws std::logic_error because
 * it is not implemented). With an empty `axis` the reduction covers all axes
 * of `a`. Reducing over one axis computes a vector norm, reducing over two
 * axes computes a matrix norm; any other number of axes, or an unrecognised
 * `ord`, throws std::invalid_argument.
 */
array norm(
const array& a,
const std::string& ord,
const std::vector<int>& axis = {},
bool keepdims = false,
StreamOrDevice s = {});
array norm(
const array& a,
const std::vector<int>& axis = {},

View File

@ -1,5 +1,6 @@
// Copyright © 2023 Apple Inc.
#include <numeric>
#include <sstream>
#include <vector>
@ -73,6 +74,12 @@ int normalize_axis(int axis, int ndim) {
}
return axis;
}
// Applies normalize_axis(axis, ndim) to each entry of `axes`, returning the
// results in the same order.
std::vector<int> normalize_axes(const std::vector<int>& axes, int ndim) {
  std::vector<int> normalized;
  normalized.reserve(axes.size());
  for (auto it = axes.begin(); it != axes.end(); ++it) {
    normalized.push_back(normalize_axis(*it, ndim));
  }
  return normalized;
}
std::ostream& operator<<(std::ostream& os, const Device& d) {
os << "Device(";
@ -279,4 +286,10 @@ std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v) {
return os;
}
// Builds the axis list 0, 1, ..., ndim - 1 (empty when ndim is 0).
std::vector<int> get_shape_reducing_over_all_axes(int ndim) {
  std::vector<int> all_axes(ndim);
  for (int position = 0; position < ndim; ++position) {
    all_axes[position] = position;
  }
  return all_axes;
}
} // namespace mlx::core

View File

@ -24,6 +24,7 @@ bool is_same_shape(const std::vector<array>& arrays);
* https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
*/
int normalize_axis(int axis, int ndim);
/**
 * Applies normalize_axis to every entry of `axes` for an array with `ndim`
 * dimensions and returns the results in the same order.
 */
std::vector<int> normalize_axes(const std::vector<int>& axes, int ndim);
std::ostream& operator<<(std::ostream& os, const Device& d);
std::ostream& operator<<(std::ostream& os, const Stream& s);
@ -41,4 +42,8 @@ inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
return os << static_cast<float>(v);
}
/**
 * Returns the list of axis indices 0, 1, ..., ndim - 1, i.e. the axes
 * argument for a reduction over every dimension of an array with `ndim`
 * dimensions. Despite the name, the result holds axis indices, not a shape.
 */
std::vector<int> get_shape_reducing_over_all_axes(int ndim);
} // namespace mlx::core