mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 04:56:41 +08:00
some style and API consistency updates to linalg norm
This commit is contained in:
parent
4bae4a8239
commit
f7cea9563d
@ -57,6 +57,7 @@ are the CPU and GPU.
|
||||
python/random
|
||||
python/transforms
|
||||
python/fft
|
||||
python/linalg
|
||||
python/nn
|
||||
python/optimizers
|
||||
python/tree_utils
|
||||
|
@ -1,11 +1,11 @@
|
||||
.. _linalg:
|
||||
|
||||
Linear Algebra
|
||||
=====
|
||||
==============
|
||||
|
||||
.. currentmodule:: mlx.core.linalg
|
||||
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
norm
|
||||
norm
|
||||
|
196
mlx/linalg.cpp
196
mlx/linalg.cpp
@ -1,47 +1,42 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <numeric>
|
||||
#include <ostream>
|
||||
#include <vector>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::linalg {
|
||||
|
||||
Dtype at_least_float(const Dtype& d) {
|
||||
return is_floating_point(d) ? d : promote_types(d, float32);
|
||||
}
|
||||
|
||||
inline array vector_norm(
|
||||
const array& a,
|
||||
const double ord,
|
||||
const std::vector<int>& axis,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
if (ord == 0.0)
|
||||
return sum(a != 0, axis, keepdims, s);
|
||||
else if (ord == 1.0)
|
||||
return sum(abs(a, s), axis, keepdims, s);
|
||||
else if (ord == 2.0)
|
||||
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s));
|
||||
else
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
if (ord == 0.0) {
|
||||
return astype(sum(not_equal(a, array(0), s), axis, keepdims, s), dtype, s);
|
||||
} else if (ord == 1.0) {
|
||||
return astype(sum(abs(a, s), axis, keepdims, s), dtype, s);
|
||||
} else if (ord == 2.0) {
|
||||
return sqrt(sum(square(a, s), axis, keepdims, s), s);
|
||||
} else if (ord == std::numeric_limits<double>::infinity()) {
|
||||
return astype(max(abs(a, s), axis, keepdims, s), dtype, s);
|
||||
} else if (ord == -std::numeric_limits<double>::infinity()) {
|
||||
return astype(min(abs(a, s), axis, keepdims, s), dtype, s);
|
||||
} else {
|
||||
return power(
|
||||
sum(power(abs(a, s), array(ord), s), axis, keepdims, s),
|
||||
array(1.0 / ord));
|
||||
}
|
||||
|
||||
inline array vector_norm(
|
||||
const array& a,
|
||||
const std::string& ord,
|
||||
const std::vector<int>& axis,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
if (ord == "inf")
|
||||
return max(abs(a, s), axis, keepdims, s);
|
||||
else if (ord == "-inf")
|
||||
return min(abs(a, s), axis, keepdims, s);
|
||||
std::ostringstream error_stream;
|
||||
error_stream << "Invalid ord value " << ord;
|
||||
throw std::invalid_argument(error_stream.str());
|
||||
sum(power(abs(a, s), array(ord, dtype), s), axis, keepdims, s),
|
||||
array(1.0 / ord, dtype),
|
||||
s);
|
||||
}
|
||||
}
|
||||
|
||||
inline array matrix_norm(
|
||||
@ -50,19 +45,30 @@ inline array matrix_norm(
|
||||
const std::vector<int>& axis,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
auto dtype = at_least_float(a.dtype());
|
||||
auto row_axis = axis[0];
|
||||
auto col_axis = axis[1];
|
||||
if (!keepdims && col_axis > row_axis)
|
||||
if (!keepdims && col_axis > row_axis && col_axis > 0) {
|
||||
col_axis -= 1;
|
||||
if (ord == -1.0)
|
||||
return min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s);
|
||||
if (ord == 1.0)
|
||||
return max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s);
|
||||
if (ord == 2.0 || ord == -2.0)
|
||||
throw std::logic_error("Singular value norms are not implemented.");
|
||||
std::ostringstream error_stream;
|
||||
error_stream << "Invalid ord value " << ord << " for matrix norm";
|
||||
throw std::invalid_argument(error_stream.str());
|
||||
}
|
||||
if (ord == -1.0) {
|
||||
return astype(
|
||||
min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
|
||||
dtype,
|
||||
s);
|
||||
} else if (ord == 1.0) {
|
||||
return astype(
|
||||
max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
|
||||
dtype,
|
||||
s);
|
||||
} else if (ord == 2.0 || ord == -2.0) {
|
||||
throw std::runtime_error(
|
||||
"[linalg::norm] Singular value norms are not implemented.");
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::norm] Invalid ord value " << ord << " for matrix norm";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
inline array matrix_norm(
|
||||
@ -71,85 +77,77 @@ inline array matrix_norm(
|
||||
const std::vector<int>& axis,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
if (ord == "f" || ord == "fro")
|
||||
if (ord == "f" || ord == "fro") {
|
||||
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s));
|
||||
else if (ord == "inf")
|
||||
return matrix_norm(a, 1.0, {axis[1], axis[0]}, keepdims, s);
|
||||
else if (ord == "-inf")
|
||||
return matrix_norm(a, -1.0, {axis[1], axis[0]}, keepdims, s);
|
||||
if (ord == "nuc")
|
||||
throw std::logic_error("Nuclear norm is not implemented.");
|
||||
std::ostringstream error_stream;
|
||||
error_stream << "Invalid ord value " << ord << " for matrix norm";
|
||||
throw std::invalid_argument(error_stream.str());
|
||||
} else if (ord == "nuc") {
|
||||
throw std::runtime_error(
|
||||
"[linalg::norm] Nuclear norm not yet implemented.");
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
array norm(
|
||||
const array& a,
|
||||
const std::vector<int>& axis,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
auto num_axes = axis.size();
|
||||
const std::optional<std::vector<int>>& axis /* = std::nullopt */,
|
||||
bool keepdims /* = false */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (!axis) {
|
||||
return norm(flatten(a, s), std::vector<int>{0}, keepdims, s);
|
||||
}
|
||||
|
||||
if (num_axes == 0 || num_axes == 1 || num_axes == 2)
|
||||
return sqrt(
|
||||
sum(abs(a, s) * abs(a, s),
|
||||
num_axes ? axis : get_reduce_axes({}, a.ndim()),
|
||||
keepdims,
|
||||
s),
|
||||
s);
|
||||
|
||||
std::ostringstream error_stream;
|
||||
error_stream << "Invalid axis values " << axis;
|
||||
throw std::invalid_argument(error_stream.str());
|
||||
if (axis.value().size() > 2) {
|
||||
throw std::invalid_argument(
|
||||
"[linalg::norm] Received too many axes for norm");
|
||||
}
|
||||
return sqrt(sum(square(a, s), axis.value(), keepdims, s), s);
|
||||
}
|
||||
|
||||
array norm(
|
||||
const array& a,
|
||||
const double ord,
|
||||
const std::vector<int>& axis,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
std::vector<int> ax = axis;
|
||||
|
||||
if (axis.empty())
|
||||
ax = get_reduce_axes({}, a.ndim());
|
||||
else
|
||||
ax = normalize_axes(ax, a.ndim());
|
||||
|
||||
auto num_axes = ax.size();
|
||||
if (num_axes == 1)
|
||||
const std::optional<std::vector<int>>& axis /* = std::nullopt */,
|
||||
bool keepdims /* = false */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
std::vector<int> ax;
|
||||
if (!axis) {
|
||||
ax.resize(a.ndim());
|
||||
std::iota(ax.begin(), ax.end(), 0);
|
||||
} else {
|
||||
ax = axis.value();
|
||||
}
|
||||
if (ax.size() == 1) {
|
||||
return vector_norm(a, ord, ax, keepdims, s);
|
||||
else if (num_axes == 2)
|
||||
} else if (ax.size() == 2) {
|
||||
return matrix_norm(a, ord, ax, keepdims, s);
|
||||
|
||||
std::ostringstream error_stream;
|
||||
error_stream << "Invalid axis values " << ax;
|
||||
throw std::invalid_argument(error_stream.str());
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"[linalg::norm] Received too many axes for norm");
|
||||
}
|
||||
}
|
||||
|
||||
array norm(
|
||||
const array& a,
|
||||
const std::string& ord,
|
||||
const std::vector<int>& axis,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
std::vector<int> ax = axis;
|
||||
|
||||
if (axis.empty())
|
||||
ax = get_reduce_axes({}, a.ndim());
|
||||
else
|
||||
ax = normalize_axes(ax, a.ndim());
|
||||
|
||||
auto num_axes = ax.size();
|
||||
if (num_axes == 1)
|
||||
return vector_norm(a, ord, ax, keepdims, s);
|
||||
else if (num_axes == 2)
|
||||
return matrix_norm(a, ord, ax, keepdims, s);
|
||||
|
||||
std::ostringstream error_stream;
|
||||
error_stream << "Invalid axis values " << ax;
|
||||
throw std::invalid_argument(error_stream.str());
|
||||
const std::optional<std::vector<int>>& axis /* = std::nullopt */,
|
||||
bool keepdims /* = false */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
std::vector<int> ax;
|
||||
if (!axis) {
|
||||
ax.resize(a.ndim());
|
||||
std::iota(ax.begin(), ax.end(), 0);
|
||||
} else {
|
||||
ax = axis.value();
|
||||
}
|
||||
if (ax.size() != 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::norm] Norm '" << ord << "' only supported for matrices,"
|
||||
<< " but received " << ax.size() << " axis/axes.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
return matrix_norm(a, ord, ax, keepdims, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::linalg
|
||||
} // namespace mlx::core::linalg
|
||||
|
42
mlx/linalg.h
42
mlx/linalg.h
@ -2,27 +2,61 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "array.h"
|
||||
#include "device.h"
|
||||
#include "ops.h"
|
||||
#include "stream.h"
|
||||
|
||||
namespace mlx::core::linalg {
|
||||
|
||||
/*
|
||||
* Compute vector or matrix norms.
|
||||
*
|
||||
* - If axis and ord are both unspecified, computes the 2-norm of flatten(x).
|
||||
* - If axis is not provided but ord is, then x must be either 1D or 2D.
|
||||
* - If axis is provided, but ord is not, then the 2-norm is computed along the
|
||||
* given axes. At most 2 axes can be specified.
|
||||
* - If both axis and ord are provided, then the corresponding matrix of vector
|
||||
* norm is computed. At most 2 axes can be specified.
|
||||
*/
|
||||
array norm(
|
||||
const array& a,
|
||||
const double ord,
|
||||
const std::vector<int>& axis = {},
|
||||
const std::optional<std::vector<int>>& axis = std::nullopt,
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {});
|
||||
inline array norm(
|
||||
const array& a,
|
||||
const double ord,
|
||||
int axis,
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {}) {
|
||||
return norm(a, ord, std::vector<int>{axis}, keepdims, s);
|
||||
}
|
||||
array norm(
|
||||
const array& a,
|
||||
const std::string& ord,
|
||||
const std::vector<int>& axis = {},
|
||||
const std::optional<std::vector<int>>& axis = std::nullopt,
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {});
|
||||
inline array norm(
|
||||
const array& a,
|
||||
const std::string& ord,
|
||||
int axis,
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {}) {
|
||||
return norm(a, ord, std::vector<int>{axis}, keepdims, s);
|
||||
}
|
||||
array norm(
|
||||
const array& a,
|
||||
const std::vector<int>& axis = {},
|
||||
const std::optional<std::vector<int>>& axis = std::nullopt,
|
||||
bool keepdims = false,
|
||||
StreamOrDevice s = {});
|
||||
} // namespace mlx::core::linalg
|
||||
inline array
|
||||
norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
|
||||
return norm(a, std::vector<int>{axis}, keepdims, s);
|
||||
}
|
||||
|
||||
} // namespace mlx::core::linalg
|
||||
|
@ -1,6 +1,5 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
@ -74,12 +73,6 @@ int normalize_axis(int axis, int ndim) {
|
||||
}
|
||||
return axis;
|
||||
}
|
||||
std::vector<int> normalize_axes(const std::vector<int>& axes, int ndim) {
|
||||
std::vector<int> canonical;
|
||||
for (int ax : axes)
|
||||
canonical.push_back(normalize_axis(ax, ndim));
|
||||
return canonical;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Device& d) {
|
||||
os << "Device(";
|
||||
|
16
mlx/utils.h
16
mlx/utils.h
@ -2,7 +2,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <numeric>
|
||||
#include "array.h"
|
||||
#include "device.h"
|
||||
#include "dtype.h"
|
||||
@ -25,7 +24,6 @@ bool is_same_shape(const std::vector<array>& arrays);
|
||||
* https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
|
||||
*/
|
||||
int normalize_axis(int axis, int ndim);
|
||||
std::vector<int> normalize_axes(const std::vector<int>& axes, int ndim);
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Device& d);
|
||||
std::ostream& operator<<(std::ostream& os, const Stream& s);
|
||||
@ -43,18 +41,4 @@ inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
|
||||
inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
|
||||
return os << static_cast<float>(v);
|
||||
}
|
||||
|
||||
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
|
||||
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
||||
std::vector<int> axes;
|
||||
if (std::holds_alternative<std::monostate>(v)) {
|
||||
axes.resize(dims);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
} else if (auto pv = std::get_if<int>(&v); pv) {
|
||||
axes.push_back(*pv);
|
||||
} else {
|
||||
axes = std::get<std::vector<int>>(v);
|
||||
}
|
||||
return axes;
|
||||
}
|
||||
} // namespace mlx::core
|
||||
|
@ -7,7 +7,6 @@
|
||||
|
||||
#include "mlx/fft.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
|
@ -26,193 +26,164 @@ using namespace mlx::core;
|
||||
using namespace mlx::core::linalg;
|
||||
|
||||
void init_linalg(py::module_& parent_module) {
|
||||
py::options options;
|
||||
options.disable_function_signatures();
|
||||
|
||||
auto m =
|
||||
parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra.");
|
||||
|
||||
m.def(
|
||||
"norm",
|
||||
[](const array& a,
|
||||
const std::variant<std::monostate, int, double, std::string>& ord,
|
||||
const std::variant<std::monostate, int, std::vector<int>>& axis,
|
||||
const std::variant<std::monostate, int, double, std::string>& ord_,
|
||||
const std::variant<std::monostate, int, std::vector<int>>& axis_,
|
||||
const bool keepdims,
|
||||
const StreamOrDevice stream) {
|
||||
return std::visit(
|
||||
overloaded{
|
||||
[&](const double p) {
|
||||
if (std::isinf((float)p) || std::isinf(p)) {
|
||||
if (p > 0) {
|
||||
return norm(
|
||||
a,
|
||||
"inf",
|
||||
std::holds_alternative<std::monostate>(axis)
|
||||
? std::vector<int>()
|
||||
: get_reduce_axes(axis, a.ndim()),
|
||||
keepdims,
|
||||
stream);
|
||||
}
|
||||
return norm(
|
||||
a,
|
||||
"-inf",
|
||||
get_reduce_axes(axis, a.ndim()),
|
||||
keepdims,
|
||||
stream);
|
||||
}
|
||||
return norm(
|
||||
a,
|
||||
p,
|
||||
std::holds_alternative<std::monostate>(axis)
|
||||
? std::vector<int>()
|
||||
: get_reduce_axes(axis, a.ndim()),
|
||||
keepdims,
|
||||
stream);
|
||||
},
|
||||
[&](const std::string& p) {
|
||||
return norm(
|
||||
a,
|
||||
p,
|
||||
std::holds_alternative<std::monostate>(axis)
|
||||
? std::vector<int>()
|
||||
: get_reduce_axes(axis, a.ndim()),
|
||||
keepdims,
|
||||
stream);
|
||||
},
|
||||
[&](const std::monostate _) {
|
||||
return norm(
|
||||
a,
|
||||
std::holds_alternative<std::monostate>(axis)
|
||||
? std::vector<int>()
|
||||
: get_reduce_axes(axis, a.ndim()),
|
||||
keepdims,
|
||||
stream);
|
||||
}},
|
||||
ord);
|
||||
std::optional<std::vector<int>> axis = std::nullopt;
|
||||
if (auto pv = std::get_if<int>(&axis_); pv) {
|
||||
axis = std::vector<int>{*pv};
|
||||
} else if (auto pv = std::get_if<std::vector<int>>(&axis_); pv) {
|
||||
axis = *pv;
|
||||
}
|
||||
|
||||
if (std::holds_alternative<std::monostate>(ord_)) {
|
||||
return norm(a, axis, keepdims, stream);
|
||||
} else {
|
||||
if (auto pv = std::get_if<std::string>(&ord_); pv) {
|
||||
return norm(a, *pv, axis, keepdims, stream);
|
||||
}
|
||||
double ord;
|
||||
if (auto pv = std::get_if<int>(&ord_); pv) {
|
||||
ord = *pv;
|
||||
} else {
|
||||
ord = std::get<double>(ord_);
|
||||
}
|
||||
return norm(a, ord, axis, keepdims, stream);
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
py::pos_only(),
|
||||
"ord"_a = none,
|
||||
"axis"_a = none,
|
||||
"keepdims"_a = false,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
Matrix or vector norm.
|
||||
norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
This function is able to return matrix or vector norms,
|
||||
depending on the value of the ``ord`` parameter.
|
||||
Matrix or vector norm.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a : array_like
|
||||
Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord`
|
||||
is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned.
|
||||
ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional
|
||||
Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None.
|
||||
axis : {None, int, 2-tuple of ints}, optional.
|
||||
If `axis` is an integer, it specifies the axis of `a` along which to
|
||||
compute the vector norms. If `axis` is a 2-tuple, it specifies the
|
||||
axes that hold 2-D matrices, and the matrix norms of these matrices
|
||||
are computed. If `axis` is None then either a vector norm (when `a`
|
||||
is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default
|
||||
is None.
|
||||
keepdims : bool, optional
|
||||
If this is set to True, the axes which are normed over are left in the
|
||||
result as dimensions with size one. With this option the result will
|
||||
broadcast correctly against the original `a`.
|
||||
This function computes vector or matrix norms depending on the value of
|
||||
the ``ord`` and ``axis`` parameters.
|
||||
|
||||
Returns
|
||||
-------
|
||||
n : array
|
||||
Norm of the matrix or vector(s).
|
||||
Args:
|
||||
a (array): Input array. If ``axis`` is ``None``, ``a`` must be 1-D or 2-D,
|
||||
unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the
|
||||
2-norm of ``a.flatten`` will be returned.
|
||||
ord (scalar or str, optional): Order of the norm (see table under ``Notes``).
|
||||
If ``None``, the 2-norm will be computed along the given ``axis``.
|
||||
Default: ``None``.
|
||||
axis (int or list(int), optional): If ``axis`` is an integer, it specifies the
|
||||
axis of ``a`` along which to compute the vector norms. If ``axis`` is a
|
||||
2-tuple, it specifies the axes that hold 2-D matrices, and the matrix
|
||||
norms of these matrices are computed. If `axis` is ``None`` then
|
||||
either a vector norm (when ``a`` is 1-D) or a matrix norm (when ``a`` is
|
||||
2-D) is returned. Default: ``None``.
|
||||
keepdims (bool, optional): If ``True``, the axes which are normed over are
|
||||
left in the result as dimensions with size one. Default ``False``.
|
||||
|
||||
Notes
|
||||
-----
|
||||
For values of ``ord < 1``, the result is, strictly speaking, not a
|
||||
mathematical 'norm', but it may still be useful for various numerical
|
||||
purposes.
|
||||
Returns:
|
||||
array: The output containing the norm(s).
|
||||
|
||||
The following norms can be calculated:
|
||||
Notes:
|
||||
For values of ``ord < 1``, the result is, strictly speaking, not a
|
||||
mathematical norm, but it may still be useful for various numerical
|
||||
purposes.
|
||||
|
||||
===== ============================ ==========================
|
||||
ord norm for matrices norm for vectors
|
||||
===== ============================ ==========================
|
||||
None Frobenius norm 2-norm
|
||||
'fro' Frobenius norm --
|
||||
inf max(sum(abs(x), axis=1)) max(abs(x))
|
||||
-inf min(sum(abs(x), axis=1)) min(abs(x))
|
||||
0 -- sum(x != 0)
|
||||
1 max(sum(abs(x), axis=0)) as below
|
||||
-1 min(sum(abs(x), axis=0)) as below
|
||||
2 2-norm (largest sing. value) as below
|
||||
-2 smallest singular value as below
|
||||
other -- sum(abs(x)**ord)**(1./ord)
|
||||
===== ============================ ==========================
|
||||
The following norms can be calculated:
|
||||
|
||||
Nuclear norm and norms based on singular values are not yet implemented.
|
||||
===== ============================ ==========================
|
||||
ord norm for matrices norm for vectors
|
||||
===== ============================ ==========================
|
||||
None Frobenius norm 2-norm
|
||||
'fro' Frobenius norm --
|
||||
inf max(sum(abs(x), axis=1)) max(abs(x))
|
||||
-inf min(sum(abs(x), axis=1)) min(abs(x))
|
||||
0 -- sum(x != 0)
|
||||
1 max(sum(abs(x), axis=0)) as below
|
||||
-1 min(sum(abs(x), axis=0)) as below
|
||||
2 2-norm (largest sing. value) as below
|
||||
-2 smallest singular value as below
|
||||
other -- sum(abs(x)**ord)**(1./ord)
|
||||
===== ============================ ==========================
|
||||
|
||||
The Frobenius norm is given by [1]_:
|
||||
.. warning::
|
||||
Nuclear norm and norms based on singular values are not yet implemented.
|
||||
|
||||
:math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
|
||||
The Frobenius norm is given by [1]_:
|
||||
|
||||
The nuclear norm is the sum of the singular values.
|
||||
:math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
|
||||
|
||||
Both the Frobenius and nuclear norm orders are only defined for
|
||||
matrices and raise a ValueError when ``a.ndim != 2``.
|
||||
The nuclear norm is the sum of the singular values.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
|
||||
Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
|
||||
Both the Frobenius and nuclear norm orders are only defined for
|
||||
matrices and raise a ``ValueError`` when ``a.ndim != 2``.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import mlx.core as mx
|
||||
>>> from mlx.core import linalg as LA
|
||||
>>> a = mx.arange(9) - 4
|
||||
>>> a
|
||||
array([-4, -3, -2, ..., 2, 3, 4], dtype=int32)
|
||||
>>> b = a.reshape((3,3))
|
||||
>>> b
|
||||
array([[-4, -3, -2],
|
||||
[-1, 0, 1],
|
||||
[ 2, 3, 4]], dtype=int32)
|
||||
>>> LA.norm(a)
|
||||
array(7.74597, dtype=float32)
|
||||
>>> LA.norm(b)
|
||||
array(7.74597, dtype=float32)
|
||||
>>> LA.norm(b, 'fro')
|
||||
array(7.74597, dtype=float32)
|
||||
>>> LA.norm(a, float("inf"))
|
||||
array(4, dtype=int32)
|
||||
>>> LA.norm(b, float("inf"))
|
||||
array(9, dtype=int32)
|
||||
>>> LA.norm(a, -float("inf"))
|
||||
array(0, dtype=int32)
|
||||
>>> LA.norm(b, -float("inf"))
|
||||
array(2, dtype=int32)
|
||||
>>> LA.norm(a, 1)
|
||||
array(20, dtype=int32)
|
||||
>>> LA.norm(b, 1)
|
||||
array(7, dtype=int32)
|
||||
>>> LA.norm(a, -1)
|
||||
array(0, dtype=float32)
|
||||
>>> LA.norm(b, -1)
|
||||
array(6, dtype=int32)
|
||||
>>> LA.norm(a, 2)
|
||||
array(7.74597, dtype=float32)
|
||||
>>> LA.norm(a, 3)
|
||||
array(5.84804, dtype=float32)
|
||||
>>> LA.norm(a, -3)
|
||||
array(0, dtype=float32)
|
||||
>>> c = mx.array([[ 1, 2, 3],
|
||||
... [-1, 1, 4]])
|
||||
>>> LA.norm(c, axis=0)
|
||||
array([1.41421, 2.23607, 5], dtype=float32)
|
||||
>>> LA.norm(c, axis=1)
|
||||
array([3.74166, 4.24264], dtype=float32)
|
||||
>>> LA.norm(c, ord=1, axis=1)
|
||||
array([6, 6], dtype=int32)
|
||||
>>> m = mx.arange(8).reshape(2,2,2)
|
||||
>>> LA.norm(m, axis=(1,2))
|
||||
array([3.74166, 11.225], dtype=float32)
|
||||
>>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :])
|
||||
(array(3.74166, dtype=float32), array(11.225, dtype=float32))
|
||||
)pbdoc");
|
||||
References:
|
||||
.. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
|
||||
Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
|
||||
|
||||
Examples:
|
||||
>>> import mlx.core as mx
|
||||
>>> from mlx.core import linalg as la
|
||||
>>> a = mx.arange(9) - 4
|
||||
>>> a
|
||||
array([-4, -3, -2, ..., 2, 3, 4], dtype=int32)
|
||||
>>> b = a.reshape((3,3))
|
||||
>>> b
|
||||
array([[-4, -3, -2],
|
||||
[-1, 0, 1],
|
||||
[ 2, 3, 4]], dtype=int32)
|
||||
>>> la.norm(a)
|
||||
array(7.74597, dtype=float32)
|
||||
>>> la.norm(b)
|
||||
array(7.74597, dtype=float32)
|
||||
>>> la.norm(b, 'fro')
|
||||
array(7.74597, dtype=float32)
|
||||
>>> la.norm(a, float("inf"))
|
||||
array(4, dtype=float32)
|
||||
>>> la.norm(b, float("inf"))
|
||||
array(9, dtype=float32)
|
||||
>>> la.norm(a, -float("inf"))
|
||||
array(0, dtype=float32)
|
||||
>>> la.norm(b, -float("inf"))
|
||||
array(2, dtype=float32)
|
||||
>>> la.norm(a, 1)
|
||||
array(20, dtype=float32)
|
||||
>>> la.norm(b, 1)
|
||||
array(7, dtype=float32)
|
||||
>>> la.norm(a, -1)
|
||||
array(0, dtype=float32)
|
||||
>>> la.norm(b, -1)
|
||||
array(6, dtype=float32)
|
||||
>>> la.norm(a, 2)
|
||||
array(7.74597, dtype=float32)
|
||||
>>> la.norm(a, 3)
|
||||
array(5.84804, dtype=float32)
|
||||
>>> la.norm(a, -3)
|
||||
array(0, dtype=float32)
|
||||
>>> c = mx.array([[ 1, 2, 3],
|
||||
... [-1, 1, 4]])
|
||||
>>> la.norm(c, axis=0)
|
||||
array([1.41421, 2.23607, 5], dtype=float32)
|
||||
>>> la.norm(c, axis=1)
|
||||
array([3.74166, 4.24264], dtype=float32)
|
||||
>>> la.norm(c, ord=1, axis=1)
|
||||
array([6, 6], dtype=float32)
|
||||
>>> m = mx.arange(8).reshape(2,2,2)
|
||||
>>> la.norm(m, axis=(1,2))
|
||||
array([3.74166, 11.225], dtype=float32)
|
||||
>>> la.norm(m[0, :, :]), LA.norm(m[1, :, :])
|
||||
(array(3.74166, dtype=float32), array(11.225, dtype=float32))
|
||||
)pbdoc");
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
#include <numeric>
|
||||
#include <variant>
|
||||
|
||||
#include <pybind11/complex.h>
|
||||
@ -13,10 +14,24 @@ namespace py = pybind11;
|
||||
|
||||
using namespace mlx::core;
|
||||
|
||||
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
|
||||
using ScalarOrArray = std::
|
||||
variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>;
|
||||
static constexpr std::monostate none{};
|
||||
|
||||
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
||||
std::vector<int> axes;
|
||||
if (std::holds_alternative<std::monostate>(v)) {
|
||||
axes.resize(dims);
|
||||
std::iota(axes.begin(), axes.end(), 0);
|
||||
} else if (auto pv = std::get_if<int>(&v); pv) {
|
||||
axes.push_back(*pv);
|
||||
} else {
|
||||
axes = std::get<std::vector<int>>(v);
|
||||
}
|
||||
return axes;
|
||||
}
|
||||
|
||||
inline array to_array_with_accessor(py::object obj) {
|
||||
if (py::hasattr(obj, "__mlx_array__")) {
|
||||
return obj.attr("__mlx_array__")().cast<array>();
|
||||
|
@ -3,170 +3,155 @@
|
||||
#include "doctest/doctest.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include "mlx/linalg.h"
|
||||
|
||||
#include "mlx/mlx.h"
|
||||
|
||||
using namespace mlx::core;
|
||||
using namespace mlx::core::linalg;
|
||||
|
||||
TEST_CASE("[mlx.core.linalg.norm] no ord") {
|
||||
array arr_one_d({1, 2, 3});
|
||||
array arr_two_d = reshape(arange(9), {3, 3});
|
||||
array arr_three_d = reshape(arange(18), {2, 3, 3});
|
||||
// Zero dimensions
|
||||
array x(2.0);
|
||||
CHECK_EQ(norm(x).item<float>(), 2.0f);
|
||||
CHECK_THROWS(norm(x, 0));
|
||||
|
||||
CHECK(array_equal(norm(arr_one_d), array(sqrt(1 + 4 + 9))).item<bool>());
|
||||
CHECK(array_equal(norm(arr_one_d, {0}, false), array(sqrt(1 + 4 + 9)))
|
||||
.item<bool>());
|
||||
x = array({1, 2, 3});
|
||||
float expected = std::sqrt(1 + 4 + 9);
|
||||
CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
|
||||
CHECK_EQ(norm(x, 0, false).item<float>(), doctest::Approx(expected));
|
||||
CHECK_EQ(norm(x, -1, false).item<float>(), doctest::Approx(expected));
|
||||
CHECK_EQ(norm(x, -1, true).ndim(), 1);
|
||||
CHECK_THROWS(norm(x, 1));
|
||||
|
||||
x = reshape(arange(9), {3, 3});
|
||||
expected =
|
||||
std::sqrt(0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8);
|
||||
|
||||
CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
|
||||
CHECK_EQ(
|
||||
norm(x, std::vector<int>{0, 1}).item<float>(), doctest::Approx(expected));
|
||||
CHECK(array_equal(
|
||||
norm(arr_two_d, {}, false),
|
||||
array(sqrt(
|
||||
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8)))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
norm(arr_two_d, {0}, false),
|
||||
norm(x, 0, false),
|
||||
array(
|
||||
{sqrt(0 + 3 * 3 + 6 * 6),
|
||||
sqrt(1 + 4 * 4 + 7 * 7),
|
||||
sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
|
||||
{std::sqrt(0 + 3 * 3 + 6 * 6),
|
||||
std::sqrt(1 + 4 * 4 + 7 * 7),
|
||||
std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
norm(arr_two_d, {1}, false),
|
||||
CHECK(allclose(
|
||||
norm(x, 1, false),
|
||||
array(
|
||||
{sqrt(0 + 1 + 2 * 2),
|
||||
sqrt(3 * 3 + 4 * 4 + 5 * 5),
|
||||
sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
|
||||
{std::sqrt(0 + 1 + 2 * 2),
|
||||
std::sqrt(3 * 3 + 4 * 4 + 5 * 5),
|
||||
std::sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
norm(arr_two_d, {0, 1}, false),
|
||||
array(sqrt(
|
||||
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8)))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
norm(arr_three_d, {2}, false),
|
||||
|
||||
x = reshape(arange(18), {2, 3, 3});
|
||||
CHECK(allclose(
|
||||
norm(x, 2, false),
|
||||
array(
|
||||
{
|
||||
sqrt(0 + 1 + 2 * 2),
|
||||
sqrt(3 * 3 + 4 * 4 + 5 * 5),
|
||||
sqrt(6 * 6 + 7 * 7 + 8 * 8),
|
||||
sqrt(9 * 9 + 10 * 10 + 11 * 11),
|
||||
sqrt(12 * 12 + 13 * 13 + 14 * 14),
|
||||
sqrt(15 * 15 + 16 * 16 + 17 * 17),
|
||||
std::sqrt(0 + 1 + 2 * 2),
|
||||
std::sqrt(3 * 3 + 4 * 4 + 5 * 5),
|
||||
std::sqrt(6 * 6 + 7 * 7 + 8 * 8),
|
||||
std::sqrt(9 * 9 + 10 * 10 + 11 * 11),
|
||||
std::sqrt(12 * 12 + 13 * 13 + 14 * 14),
|
||||
std::sqrt(15 * 15 + 16 * 16 + 17 * 17),
|
||||
},
|
||||
{2, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
norm(arr_three_d, {1}, false),
|
||||
CHECK(allclose(
|
||||
norm(x, std::vector<int>{1, 2}, false),
|
||||
array(
|
||||
{
|
||||
sqrt(0 + 3 * 3 + 6 * 6),
|
||||
sqrt(1 + 4 * 4 + 7 * 7),
|
||||
sqrt(2 * 2 + 5 * 5 + 8 * 8),
|
||||
sqrt(9 * 9 + 12 * 12 + 15 * 15),
|
||||
sqrt(10 * 10 + 13 * 13 + 16 * 16),
|
||||
sqrt(11 * 11 + 14 * 14 + 17 * 17),
|
||||
},
|
||||
{2, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
norm(arr_three_d, {0}, false),
|
||||
array(
|
||||
{
|
||||
sqrt(0 + 9 * 9),
|
||||
sqrt(1 + 10 * 10),
|
||||
sqrt(2 * 2 + 11 * 11),
|
||||
sqrt(3 * 3 + 12 * 12),
|
||||
sqrt(4 * 4 + 13 * 13),
|
||||
sqrt(5 * 5 + 14 * 14),
|
||||
sqrt(6 * 6 + 15 * 15),
|
||||
sqrt(7 * 7 + 16 * 16),
|
||||
sqrt(8 * 8 + 17 * 17),
|
||||
},
|
||||
{3, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
norm(arr_three_d, {1, 2}, false),
|
||||
array(
|
||||
{sqrt(
|
||||
{std::sqrt(
|
||||
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 +
|
||||
8 * 8),
|
||||
sqrt(
|
||||
std::sqrt(
|
||||
9 * 9 + 10 * 10 + 11 * 11 + 12 * 12 + 13 * 13 + 14 * 14 +
|
||||
15 * 15 + 16 * 16 + 17 * 17)},
|
||||
{2}))
|
||||
.item<bool>());
|
||||
CHECK_THROWS(norm(x, std::vector<int>{0, 1, 2}));
|
||||
}
|
||||
|
||||
TEST_CASE("[mlx.core.linalg.norm] double ord") {
|
||||
array arr_one_d({1, 2, 3});
|
||||
array arr_two_d = reshape(arange(9), {3, 3});
|
||||
array arr_three_d = reshape(arange(18), {2, 3, 3});
|
||||
CHECK_THROWS(norm(array(0), 2.0));
|
||||
|
||||
CHECK(array_equal(norm(arr_one_d, 2.0), array(sqrt(1 + 4 + 9))).item<bool>());
|
||||
CHECK(array_equal(norm(arr_one_d, 1.0), array(1 + 2 + 3)).item<bool>());
|
||||
CHECK(array_equal(norm(arr_one_d, 0.0), array(3)).item<bool>());
|
||||
array x({1, 2, 3});
|
||||
|
||||
CHECK(array_equal(norm(arr_one_d, 2.0, {0}, false), array(sqrt(1 + 4 + 9)))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
norm(arr_two_d, 2.0, {0}, false),
|
||||
float expected = std::sqrt(1 + 4 + 9);
|
||||
CHECK_EQ(norm(x, 2.0).item<float>(), doctest::Approx(expected));
|
||||
CHECK_EQ(norm(x, 2.0, 0).item<float>(), doctest::Approx(expected));
|
||||
CHECK_THROWS(norm(x, 2.0, 1));
|
||||
|
||||
expected = 1 + 2 + 3;
|
||||
CHECK_EQ(norm(x, 1.0).item<float>(), doctest::Approx(expected));
|
||||
|
||||
expected = 3;
|
||||
CHECK_EQ(norm(x, 0.0).item<float>(), doctest::Approx(expected));
|
||||
|
||||
expected = 3;
|
||||
CHECK_EQ(
|
||||
norm(x, std::numeric_limits<double>::infinity()).item<float>(),
|
||||
doctest::Approx(expected));
|
||||
|
||||
expected = 1;
|
||||
CHECK_EQ(
|
||||
norm(x, -std::numeric_limits<double>::infinity()).item<float>(),
|
||||
doctest::Approx(expected));
|
||||
|
||||
x = reshape(arange(9), {3, 3});
|
||||
|
||||
CHECK(allclose(
|
||||
norm(x, 2.0, 0, false),
|
||||
array(
|
||||
{sqrt(0 + 3 * 3 + 6 * 6),
|
||||
sqrt(1 + 4 * 4 + 7 * 7),
|
||||
sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
|
||||
{std::sqrt(0 + 3 * 3 + 6 * 6),
|
||||
std::sqrt(1 + 4 * 4 + 7 * 7),
|
||||
std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
norm(arr_two_d, 2.0, {1}, false),
|
||||
CHECK(allclose(
|
||||
norm(x, 2.0, 1, false),
|
||||
array(
|
||||
{sqrt(0 + 1 + 2 * 2),
|
||||
sqrt(3 * 3 + 4 * 4 + 5 * 5),
|
||||
sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
norm(arr_three_d, 2.0, {2}, false),
|
||||
array(
|
||||
{
|
||||
sqrt(0 + 1 + 2 * 2),
|
||||
sqrt(3 * 3 + 4 * 4 + 5 * 5),
|
||||
sqrt(6 * 6 + 7 * 7 + 8 * 8),
|
||||
sqrt(9 * 9 + 10 * 10 + 11 * 11),
|
||||
sqrt(12 * 12 + 13 * 13 + 14 * 14),
|
||||
sqrt(15 * 15 + 16 * 16 + 17 * 17),
|
||||
},
|
||||
{2, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
norm(arr_three_d, 2.0, {1}, false),
|
||||
array(
|
||||
{
|
||||
sqrt(0 + 3 * 3 + 6 * 6),
|
||||
sqrt(1 + 4 * 4 + 7 * 7),
|
||||
sqrt(2 * 2 + 5 * 5 + 8 * 8),
|
||||
sqrt(9 * 9 + 12 * 12 + 15 * 15),
|
||||
sqrt(10 * 10 + 13 * 13 + 16 * 16),
|
||||
sqrt(11 * 11 + 14 * 14 + 17 * 17),
|
||||
},
|
||||
{2, 3}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(
|
||||
norm(arr_three_d, 2.0, {0}, false),
|
||||
array(
|
||||
{
|
||||
sqrt(0 + 9 * 9),
|
||||
sqrt(1 + 10 * 10),
|
||||
sqrt(2 * 2 + 11 * 11),
|
||||
sqrt(3 * 3 + 12 * 12),
|
||||
sqrt(4 * 4 + 13 * 13),
|
||||
sqrt(5 * 5 + 14 * 14),
|
||||
sqrt(6 * 6 + 15 * 15),
|
||||
sqrt(7 * 7 + 16 * 16),
|
||||
sqrt(8 * 8 + 17 * 17),
|
||||
},
|
||||
{3, 3}))
|
||||
.item<bool>());
|
||||
|
||||
CHECK_EQ(
|
||||
norm(x, 1.0, std::vector<int>{0, 1}).item<float>(),
|
||||
doctest::Approx(15.0));
|
||||
CHECK_EQ(
|
||||
norm(x, 1.0, std::vector<int>{1, 0}).item<float>(),
|
||||
doctest::Approx(21.0));
|
||||
CHECK_EQ(
|
||||
norm(x, -1.0, std::vector<int>{0, 1}).item<float>(),
|
||||
doctest::Approx(9.0));
|
||||
CHECK_EQ(
|
||||
norm(x, -1.0, std::vector<int>{1, 0}).item<float>(),
|
||||
doctest::Approx(3.0));
|
||||
CHECK_EQ(
|
||||
norm(x, 1.0, std::vector<int>{0, 1}, true).shape(),
|
||||
std::vector<int>{1, 1});
|
||||
CHECK_EQ(
|
||||
norm(x, 1.0, std::vector<int>{1, 0}, true).shape(),
|
||||
std::vector<int>{1, 1});
|
||||
CHECK_EQ(
|
||||
norm(x, -1.0, std::vector<int>{0, 1}, true).shape(),
|
||||
std::vector<int>{1, 1});
|
||||
CHECK_EQ(
|
||||
norm(x, -1.0, std::vector<int>{1, 0}, true).shape(),
|
||||
std::vector<int>{1, 1});
|
||||
|
||||
CHECK_EQ(
|
||||
norm(x, -1.0, std::vector<int>{-2, -1}, false).item<float>(),
|
||||
doctest::Approx(9.0));
|
||||
CHECK_EQ(
|
||||
norm(x, 1.0, std::vector<int>{-2, -1}, false).item<float>(),
|
||||
doctest::Approx(15.0));
|
||||
|
||||
x = reshape(arange(18), {2, 3, 3});
|
||||
CHECK_THROWS(norm(x, 2.0, std::vector{0, 1, 2}));
|
||||
CHECK(allclose(
|
||||
norm(arr_three_d, 3.0, {0}),
|
||||
norm(x, 3.0, 0),
|
||||
array(
|
||||
{9.,
|
||||
10.00333222,
|
||||
@ -179,15 +164,8 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
|
||||
17.57113899},
|
||||
{3, 3}))
|
||||
.item<bool>());
|
||||
CHECK(
|
||||
allclose(
|
||||
norm(arr_three_d, 3.0, {1}),
|
||||
array(
|
||||
{6.24025147, 7.41685954, 8.6401226, 18., 19.39257164, 20.7915893},
|
||||
{2, 3}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(arr_three_d, 3.0, {2}),
|
||||
norm(x, 3.0, 2),
|
||||
array(
|
||||
{2.08008382,
|
||||
6.,
|
||||
@ -197,110 +175,76 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
|
||||
23.13593104},
|
||||
{2, 3}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(arr_three_d, 0.0, {0}),
|
||||
array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3}))
|
||||
CHECK(
|
||||
allclose(
|
||||
norm(x, 0.0, 0), array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(norm(x, 0.0, 1), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(norm(x, 0.0, 2), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
|
||||
.item<bool>());
|
||||
CHECK(
|
||||
allclose(
|
||||
norm(arr_three_d, 0.0, {1}), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
|
||||
.item<bool>());
|
||||
CHECK(
|
||||
allclose(
|
||||
norm(arr_three_d, 0.0, {2}), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(arr_three_d, 1.0, {0}),
|
||||
norm(x, 1.0, 0),
|
||||
array({9., 11., 13., 15., 17., 19., 21., 23., 25.}, {3, 3}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(arr_three_d, 1.0, {1}),
|
||||
array({9., 12., 15., 36., 39., 42.}, {2, 3}))
|
||||
CHECK(allclose(norm(x, 1.0, 1), array({9., 12., 15., 36., 39., 42.}, {2, 3}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(arr_three_d, 1.0, {2}),
|
||||
array({3., 12., 21., 30., 39., 48.}, {2, 3}))
|
||||
CHECK(allclose(norm(x, 1.0, 2), array({3., 12., 21., 30., 39., 48.}, {2, 3}))
|
||||
.item<bool>());
|
||||
|
||||
CHECK(allclose(norm(arr_two_d, 1.0, {0, 1}), array({15.0})).item<bool>());
|
||||
CHECK(allclose(norm(arr_two_d, 1.0, {1, 0}), array({21.0})).item<bool>());
|
||||
CHECK(allclose(norm(arr_two_d, -1.0, {0, 1}), array({9.0})).item<bool>());
|
||||
CHECK(allclose(norm(arr_two_d, -1.0, {1, 0}), array({3.0})).item<bool>());
|
||||
|
||||
CHECK(allclose(norm(arr_two_d, 1.0, {0, 1}, true), array({15.0}, {1, 1}))
|
||||
CHECK(allclose(norm(x, 1.0, std::vector<int>{0, 1}), array({21., 23., 25.}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(norm(arr_two_d, 1.0, {1, 0}, true), array({21.0}, {1, 1}))
|
||||
CHECK(allclose(norm(x, 1.0, std::vector<int>{1, 2}), array({15., 42.}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(norm(arr_two_d, -1.0, {0, 1}, true), array({9.0}, {1, 1}))
|
||||
CHECK(allclose(norm(x, -1.0, std::vector<int>{0, 1}), array({9., 11., 13.}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(norm(arr_two_d, -1.0, {1, 0}, true), array({3.0}, {1, 1}))
|
||||
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9., 36.}))
|
||||
.item<bool>());
|
||||
|
||||
CHECK(array_equal(norm(arr_two_d, -1.0, {-2, -1}, false), array(9.0))
|
||||
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 0}), array({9., 12., 15.}))
|
||||
.item<bool>());
|
||||
CHECK(array_equal(norm(arr_two_d, 1.0, {-2, -1}, false), array(15.0))
|
||||
CHECK(allclose(norm(x, -1.0, std::vector<int>{2, 1}), array({3, 30}))
|
||||
.item<bool>());
|
||||
//
|
||||
CHECK(allclose(norm(arr_three_d, 1.0, {0, 1}), array({21., 23., 25.}))
|
||||
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9, 36}))
|
||||
.item<bool>());
|
||||
CHECK(
|
||||
allclose(norm(arr_three_d, 1.0, {1, 2}), array({15., 42.})).item<bool>());
|
||||
CHECK(allclose(norm(arr_three_d, -1.0, {0, 1}), array({9., 11., 13.}))
|
||||
.item<bool>());
|
||||
CHECK(
|
||||
allclose(norm(arr_three_d, -1.0, {1, 2}), array({9., 36.})).item<bool>());
|
||||
CHECK(allclose(norm(arr_three_d, -1.0, {1, 0}), array({9., 12., 15.}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(norm(arr_three_d, -1.0, {2, 1}), array({3, 30})).item<bool>());
|
||||
CHECK(allclose(norm(arr_three_d, -1.0, {1, 2}), array({9, 36})).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("[mlx.core.linalg.norm] string ord") {
|
||||
array arr_one_d({1, 2, 3});
|
||||
array arr_two_d = reshape(arange(9), {3, 3});
|
||||
array arr_three_d = reshape(arange(18), {2, 3, 3});
|
||||
array x({1, 2, 3});
|
||||
CHECK_THROWS(norm(x, "fro"));
|
||||
|
||||
CHECK(allclose(norm(arr_one_d, "inf", {}), array({3.0})).item<bool>());
|
||||
CHECK(allclose(norm(arr_one_d, "-inf", {}), array({1.0})).item<bool>());
|
||||
x = reshape(arange(9), {3, 3});
|
||||
CHECK_THROWS(norm(x, "bad ord"));
|
||||
|
||||
CHECK(allclose(norm(arr_two_d, "f", {0, 1}), array({14.2828568570857}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(norm(arr_two_d, "fro", {0, 1}), array({14.2828568570857}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(norm(arr_two_d, "inf", {0, 1}), array({21.0})).item<bool>());
|
||||
CHECK(allclose(norm(arr_two_d, "-inf", {0, 1}), array({3.0})).item<bool>());
|
||||
CHECK_EQ(
|
||||
norm(x, "f", std::vector<int>{0, 1}).item<float>(),
|
||||
doctest::Approx(14.2828568570857));
|
||||
CHECK_EQ(
|
||||
norm(x, "fro", std::vector<int>{0, 1}).item<float>(),
|
||||
doctest::Approx(14.2828568570857));
|
||||
|
||||
x = reshape(arange(18), {2, 3, 3});
|
||||
CHECK(allclose(
|
||||
norm(arr_three_d, "fro", {0, 1}),
|
||||
norm(x, "fro", std::vector<int>{0, 1}),
|
||||
array({22.24859546, 24.31049156, 26.43860813}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(arr_three_d, "fro", {1, 2}), array({14.28285686, 39.7617907}))
|
||||
norm(x, "fro", std::vector<int>{1, 2}),
|
||||
array({14.28285686, 39.7617907}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(arr_three_d, "f", {0, 1}),
|
||||
norm(x, "f", std::vector<int>{0, 1}),
|
||||
array({22.24859546, 24.31049156, 26.43860813}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(
|
||||
norm(arr_three_d, "f", {1, 0}),
|
||||
norm(x, "f", std::vector<int>{1, 0}),
|
||||
array({22.24859546, 24.31049156, 26.43860813}))
|
||||
.item<bool>());
|
||||
CHECK(
|
||||
allclose(norm(arr_three_d, "f", {1, 2}), array({14.28285686, 39.7617907}))
|
||||
.item<bool>());
|
||||
CHECK(
|
||||
allclose(norm(arr_three_d, "f", {2, 1}), array({14.28285686, 39.7617907}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(norm(arr_three_d, "inf", {0, 1}), array({36., 39., 42.}))
|
||||
CHECK(allclose(
|
||||
norm(x, "f", std::vector<int>{1, 2}),
|
||||
array({14.28285686, 39.7617907}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(norm(arr_three_d, "inf", {1, 2}), array({21., 48.}))
|
||||
CHECK(allclose(
|
||||
norm(x, "f", std::vector<int>{2, 1}),
|
||||
array({14.28285686, 39.7617907}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(norm(arr_three_d, "-inf", {0, 1}), array({9., 12., 15.}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(norm(arr_three_d, "-inf", {1, 2}), array({3., 30.}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(norm(arr_three_d, "-inf", {1, 0}), array({9., 11., 13.}))
|
||||
.item<bool>());
|
||||
CHECK(allclose(norm(arr_three_d, "-inf", {2, 1}), array({9., 36.}))
|
||||
.item<bool>());
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user