mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
completed the implementation of the norm
This commit is contained in:
parent
05203ecd78
commit
8c43d820d9
140
mlx/linalg.cpp
140
mlx/linalg.cpp
@ -4,7 +4,7 @@
|
|||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <variant>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
@ -14,10 +14,77 @@
|
|||||||
|
|
||||||
namespace mlx::core::linalg {
|
namespace mlx::core::linalg {
|
||||||
|
|
||||||
inline std::vector<int> get_shape_reducing_over_all_dims(int num_axes) {
|
inline array vector_norm(
|
||||||
std::vector<int> shape(num_axes);
|
const array& a,
|
||||||
std::iota(shape.begin(), shape.end(), 0);
|
const double ord,
|
||||||
return shape;
|
const std::vector<int>& axis,
|
||||||
|
bool keepdims,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
if (ord == 0.0)
|
||||||
|
return sum(a != 0, axis, keepdims, s);
|
||||||
|
else if (ord == 1.0)
|
||||||
|
return sum(abs(a, s), axis, keepdims, s);
|
||||||
|
else if (ord == 2.0)
|
||||||
|
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s));
|
||||||
|
else
|
||||||
|
return power(
|
||||||
|
sum(power(abs(a, s), array(ord), s), axis, keepdims, s),
|
||||||
|
array(1.0 / ord));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline array vector_norm(
|
||||||
|
const array& a,
|
||||||
|
const std::string& ord,
|
||||||
|
const std::vector<int>& axis,
|
||||||
|
bool keepdims,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
if (ord == "inf")
|
||||||
|
return max(abs(a, s), axis, keepdims, s);
|
||||||
|
else if (ord == "-inf")
|
||||||
|
return min(abs(a, s), axis, keepdims, s);
|
||||||
|
std::stringstream error_stream;
|
||||||
|
error_stream << "Invalid ord value " << ord;
|
||||||
|
throw std::invalid_argument(error_stream.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
inline array matrix_norm(
|
||||||
|
const array& a,
|
||||||
|
const double ord,
|
||||||
|
const std::vector<int>& axis,
|
||||||
|
bool keepdims,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
auto row_axis = axis[0];
|
||||||
|
auto col_axis = axis[1];
|
||||||
|
if (!keepdims && col_axis > row_axis)
|
||||||
|
col_axis -= 1;
|
||||||
|
if (ord == -1.0)
|
||||||
|
return min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s);
|
||||||
|
if (ord == 1.0)
|
||||||
|
return max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s);
|
||||||
|
if (ord == 2.0 || ord == -2.0)
|
||||||
|
throw std::logic_error("Singular value norms are not implemented.");
|
||||||
|
std::stringstream error_stream;
|
||||||
|
error_stream << "Invalid ord value " << ord << " for matrix norm";
|
||||||
|
throw std::invalid_argument(error_stream.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
inline array matrix_norm(
|
||||||
|
const array& a,
|
||||||
|
const std::string& ord,
|
||||||
|
const std::vector<int>& axis,
|
||||||
|
bool keepdims,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
if (ord == "f" || ord == "fro")
|
||||||
|
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s));
|
||||||
|
else if (ord == "inf")
|
||||||
|
return matrix_norm(a, 1.0, {axis[1], axis[0]}, keepdims, s);
|
||||||
|
else if (ord == "-inf")
|
||||||
|
return matrix_norm(a, -1.0, {axis[1], axis[0]}, keepdims, s);
|
||||||
|
if (ord == "nuc")
|
||||||
|
throw std::logic_error("Nuclear norm is not implemented.");
|
||||||
|
std::stringstream error_stream;
|
||||||
|
error_stream << "Invalid ord value " << ord << " for matrix norm";
|
||||||
|
throw std::invalid_argument(error_stream.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
array norm(
|
array norm(
|
||||||
@ -28,14 +95,65 @@ array norm(
|
|||||||
auto num_axes = axis.size();
|
auto num_axes = axis.size();
|
||||||
|
|
||||||
if (num_axes == 0 || num_axes == 1 || num_axes == 2)
|
if (num_axes == 0 || num_axes == 1 || num_axes == 2)
|
||||||
return sqrt(sum(
|
return sqrt(
|
||||||
abs(a, s) * abs(a, s),
|
sum(abs(a, s) * abs(a, s),
|
||||||
num_axes ? axis : get_shape_reducing_over_all_dims(a.shape().size()),
|
num_axes ? axis
|
||||||
keepdims,
|
: get_shape_reducing_over_all_axes(a.shape().size()),
|
||||||
s));
|
keepdims,
|
||||||
|
s),
|
||||||
|
s);
|
||||||
|
|
||||||
std::stringstream error_stream;
|
std::stringstream error_stream;
|
||||||
error_stream << "Invalid axis values" << axis;
|
error_stream << "Invalid axis values " << axis;
|
||||||
throw std::invalid_argument(error_stream.str());
|
throw std::invalid_argument(error_stream.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array norm(
|
||||||
|
const array& a,
|
||||||
|
const double ord,
|
||||||
|
const std::vector<int>& axis,
|
||||||
|
bool keepdims,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
std::vector<int> ax = axis;
|
||||||
|
|
||||||
|
if (axis.empty())
|
||||||
|
ax = get_shape_reducing_over_all_axes(a.ndim());
|
||||||
|
else
|
||||||
|
ax = normalize_axes(ax, a.ndim());
|
||||||
|
|
||||||
|
auto num_axes = ax.size();
|
||||||
|
if (num_axes == 1)
|
||||||
|
return vector_norm(a, ord, ax, keepdims, s);
|
||||||
|
else if (num_axes == 2)
|
||||||
|
return matrix_norm(a, ord, ax, keepdims, s);
|
||||||
|
|
||||||
|
std::stringstream error_stream;
|
||||||
|
error_stream << "Invalid axis values " << ax;
|
||||||
|
throw std::invalid_argument(error_stream.str());
|
||||||
|
}
|
||||||
|
|
||||||
|
array norm(
|
||||||
|
const array& a,
|
||||||
|
const std::string& ord,
|
||||||
|
const std::vector<int>& axis,
|
||||||
|
bool keepdims,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
std::vector<int> ax = axis;
|
||||||
|
|
||||||
|
if (axis.empty())
|
||||||
|
ax = get_shape_reducing_over_all_axes(a.ndim());
|
||||||
|
else
|
||||||
|
ax = normalize_axes(ax, a.ndim());
|
||||||
|
|
||||||
|
auto num_axes = ax.size();
|
||||||
|
if (num_axes == 1)
|
||||||
|
return vector_norm(a, ord, ax, keepdims, s);
|
||||||
|
else if (num_axes == 2)
|
||||||
|
return matrix_norm(a, ord, ax, keepdims, s);
|
||||||
|
|
||||||
|
std::stringstream error_stream;
|
||||||
|
error_stream << "Invalid axis values " << ax;
|
||||||
|
throw std::invalid_argument(error_stream.str());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::linalg
|
} // namespace mlx::core::linalg
|
12
mlx/linalg.h
12
mlx/linalg.h
@ -11,6 +11,18 @@
|
|||||||
#include "string.h"
|
#include "string.h"
|
||||||
|
|
||||||
namespace mlx::core::linalg {
|
namespace mlx::core::linalg {
|
||||||
|
array norm(
|
||||||
|
const array& a,
|
||||||
|
const double ord,
|
||||||
|
const std::vector<int>& axis = {},
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
array norm(
|
||||||
|
const array& a,
|
||||||
|
const std::string& ord,
|
||||||
|
const std::vector<int>& axis = {},
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
array norm(
|
array norm(
|
||||||
const array& a,
|
const array& a,
|
||||||
const std::vector<int>& axis = {},
|
const std::vector<int>& axis = {},
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#include <numeric>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -73,6 +74,12 @@ int normalize_axis(int axis, int ndim) {
|
|||||||
}
|
}
|
||||||
return axis;
|
return axis;
|
||||||
}
|
}
|
||||||
|
std::vector<int> normalize_axes(const std::vector<int>& axes, int ndim) {
|
||||||
|
std::vector<int> canonical;
|
||||||
|
for (int ax : axes)
|
||||||
|
canonical.push_back(normalize_axis(ax, ndim));
|
||||||
|
return canonical;
|
||||||
|
}
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, const Device& d) {
|
std::ostream& operator<<(std::ostream& os, const Device& d) {
|
||||||
os << "Device(";
|
os << "Device(";
|
||||||
@ -279,4 +286,10 @@ std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v) {
|
|||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<int> get_shape_reducing_over_all_axes(int ndim) {
|
||||||
|
std::vector<int> shape(ndim);
|
||||||
|
std::iota(shape.begin(), shape.end(), 0);
|
||||||
|
return shape;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -24,6 +24,7 @@ bool is_same_shape(const std::vector<array>& arrays);
|
|||||||
* https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
|
* https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
|
||||||
*/
|
*/
|
||||||
int normalize_axis(int axis, int ndim);
|
int normalize_axis(int axis, int ndim);
|
||||||
|
std::vector<int> normalize_axes(const std::vector<int>& axes, int ndim);
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, const Device& d);
|
std::ostream& operator<<(std::ostream& os, const Device& d);
|
||||||
std::ostream& operator<<(std::ostream& os, const Stream& s);
|
std::ostream& operator<<(std::ostream& os, const Stream& s);
|
||||||
@ -41,4 +42,8 @@ inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
|
|||||||
inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
|
inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
|
||||||
return os << static_cast<float>(v);
|
return os << static_cast<float>(v);
|
||||||
}
|
}
|
||||||
|
/**
|
||||||
|
* Returns the axes vector [0, 1, ... ndim).
|
||||||
|
*/
|
||||||
|
std::vector<int> get_shape_reducing_over_all_axes(int ndim);
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
Loading…
Reference in New Issue
Block a user