mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
completed the implementation of the norm
This commit is contained in:
parent
05203ecd78
commit
8c43d820d9
136
mlx/linalg.cpp
136
mlx/linalg.cpp
@ -4,7 +4,7 @@
|
||||
#include <numeric>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <variant>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mlx/array.h"
|
||||
@ -14,10 +14,77 @@
|
||||
|
||||
namespace mlx::core::linalg {
|
||||
|
||||
inline std::vector<int> get_shape_reducing_over_all_dims(int num_axes) {
|
||||
std::vector<int> shape(num_axes);
|
||||
std::iota(shape.begin(), shape.end(), 0);
|
||||
return shape;
|
||||
inline array vector_norm(
|
||||
const array& a,
|
||||
const double ord,
|
||||
const std::vector<int>& axis,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
if (ord == 0.0)
|
||||
return sum(a != 0, axis, keepdims, s);
|
||||
else if (ord == 1.0)
|
||||
return sum(abs(a, s), axis, keepdims, s);
|
||||
else if (ord == 2.0)
|
||||
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s));
|
||||
else
|
||||
return power(
|
||||
sum(power(abs(a, s), array(ord), s), axis, keepdims, s),
|
||||
array(1.0 / ord));
|
||||
}
|
||||
|
||||
inline array vector_norm(
|
||||
const array& a,
|
||||
const std::string& ord,
|
||||
const std::vector<int>& axis,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
if (ord == "inf")
|
||||
return max(abs(a, s), axis, keepdims, s);
|
||||
else if (ord == "-inf")
|
||||
return min(abs(a, s), axis, keepdims, s);
|
||||
std::stringstream error_stream;
|
||||
error_stream << "Invalid ord value " << ord;
|
||||
throw std::invalid_argument(error_stream.str());
|
||||
}
|
||||
|
||||
inline array matrix_norm(
|
||||
const array& a,
|
||||
const double ord,
|
||||
const std::vector<int>& axis,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
auto row_axis = axis[0];
|
||||
auto col_axis = axis[1];
|
||||
if (!keepdims && col_axis > row_axis)
|
||||
col_axis -= 1;
|
||||
if (ord == -1.0)
|
||||
return min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s);
|
||||
if (ord == 1.0)
|
||||
return max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s);
|
||||
if (ord == 2.0 || ord == -2.0)
|
||||
throw std::logic_error("Singular value norms are not implemented.");
|
||||
std::stringstream error_stream;
|
||||
error_stream << "Invalid ord value " << ord << " for matrix norm";
|
||||
throw std::invalid_argument(error_stream.str());
|
||||
}
|
||||
|
||||
inline array matrix_norm(
|
||||
const array& a,
|
||||
const std::string& ord,
|
||||
const std::vector<int>& axis,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
if (ord == "f" || ord == "fro")
|
||||
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s));
|
||||
else if (ord == "inf")
|
||||
return matrix_norm(a, 1.0, {axis[1], axis[0]}, keepdims, s);
|
||||
else if (ord == "-inf")
|
||||
return matrix_norm(a, -1.0, {axis[1], axis[0]}, keepdims, s);
|
||||
if (ord == "nuc")
|
||||
throw std::logic_error("Nuclear norm is not implemented.");
|
||||
std::stringstream error_stream;
|
||||
error_stream << "Invalid ord value " << ord << " for matrix norm";
|
||||
throw std::invalid_argument(error_stream.str());
|
||||
}
|
||||
|
||||
array norm(
|
||||
@ -28,14 +95,65 @@ array norm(
|
||||
auto num_axes = axis.size();
|
||||
|
||||
if (num_axes == 0 || num_axes == 1 || num_axes == 2)
|
||||
return sqrt(sum(
|
||||
abs(a, s) * abs(a, s),
|
||||
num_axes ? axis : get_shape_reducing_over_all_dims(a.shape().size()),
|
||||
return sqrt(
|
||||
sum(abs(a, s) * abs(a, s),
|
||||
num_axes ? axis
|
||||
: get_shape_reducing_over_all_axes(a.shape().size()),
|
||||
keepdims,
|
||||
s));
|
||||
s),
|
||||
s);
|
||||
|
||||
std::stringstream error_stream;
|
||||
error_stream << "Invalid axis values " << axis;
|
||||
throw std::invalid_argument(error_stream.str());
|
||||
}
|
||||
|
||||
array norm(
|
||||
const array& a,
|
||||
const double ord,
|
||||
const std::vector<int>& axis,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
std::vector<int> ax = axis;
|
||||
|
||||
if (axis.empty())
|
||||
ax = get_shape_reducing_over_all_axes(a.ndim());
|
||||
else
|
||||
ax = normalize_axes(ax, a.ndim());
|
||||
|
||||
auto num_axes = ax.size();
|
||||
if (num_axes == 1)
|
||||
return vector_norm(a, ord, ax, keepdims, s);
|
||||
else if (num_axes == 2)
|
||||
return matrix_norm(a, ord, ax, keepdims, s);
|
||||
|
||||
std::stringstream error_stream;
|
||||
error_stream << "Invalid axis values " << ax;
|
||||
throw std::invalid_argument(error_stream.str());
|
||||
}
|
||||
|
||||
array norm(
|
||||
const array& a,
|
||||
const std::string& ord,
|
||||
const std::vector<int>& axis,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
std::vector<int> ax = axis;
|
||||
|
||||
if (axis.empty())
|
||||
ax = get_shape_reducing_over_all_axes(a.ndim());
|
||||
else
|
||||
ax = normalize_axes(ax, a.ndim());
|
||||
|
||||
auto num_axes = ax.size();
|
||||
if (num_axes == 1)
|
||||
return vector_norm(a, ord, ax, keepdims, s);
|
||||
else if (num_axes == 2)
|
||||
return matrix_norm(a, ord, ax, keepdims, s);
|
||||
|
||||
std::stringstream error_stream;
|
||||
error_stream << "Invalid axis values " << ax;
|
||||
throw std::invalid_argument(error_stream.str());
|
||||
}
|
||||
|
||||
} // namespace mlx::core::linalg
|
12
mlx/linalg.h
12
mlx/linalg.h
@ -11,6 +11,18 @@
|
||||
#include "string.h"
|
||||
|
||||
namespace mlx::core::linalg {
|
||||
array norm(
|
||||
const array& a,
|
||||
const double ord,
|
||||
const std::vector<int>& axis = {},
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {});
|
||||
array norm(
|
||||
const array& a,
|
||||
const std::string& ord,
|
||||
const std::vector<int>& axis = {},
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {});
|
||||
array norm(
|
||||
const array& a,
|
||||
const std::vector<int>& axis = {},
|
||||
|
@ -1,5 +1,6 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
@ -73,6 +74,12 @@ int normalize_axis(int axis, int ndim) {
|
||||
}
|
||||
return axis;
|
||||
}
|
||||
std::vector<int> normalize_axes(const std::vector<int>& axes, int ndim) {
|
||||
std::vector<int> canonical;
|
||||
for (int ax : axes)
|
||||
canonical.push_back(normalize_axis(ax, ndim));
|
||||
return canonical;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Device& d) {
|
||||
os << "Device(";
|
||||
@ -279,4 +286,10 @@ std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v) {
|
||||
return os;
|
||||
}
|
||||
|
||||
std::vector<int> get_shape_reducing_over_all_axes(int ndim) {
|
||||
std::vector<int> shape(ndim);
|
||||
std::iota(shape.begin(), shape.end(), 0);
|
||||
return shape;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -24,6 +24,7 @@ bool is_same_shape(const std::vector<array>& arrays);
|
||||
* https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
|
||||
*/
|
||||
int normalize_axis(int axis, int ndim);
|
||||
std::vector<int> normalize_axes(const std::vector<int>& axes, int ndim);
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Device& d);
|
||||
std::ostream& operator<<(std::ostream& os, const Stream& s);
|
||||
@ -41,4 +42,8 @@ inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
|
||||
inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
|
||||
return os << static_cast<float>(v);
|
||||
}
|
||||
/**
|
||||
* Returns the axes vector [0, 1, ... ndim).
|
||||
*/
|
||||
std::vector<int> get_shape_reducing_over_all_axes(int ndim);
|
||||
} // namespace mlx::core
|
||||
|
Loading…
Reference in New Issue
Block a user