some style and API consistency updates to linalg norm

This commit is contained in:
Awni Hannun 2023-12-26 10:54:59 -08:00
parent 4bae4a8239
commit f7cea9563d
10 changed files with 437 additions and 498 deletions

View File

@ -57,6 +57,7 @@ are the CPU and GPU.
python/random
python/transforms
python/fft
python/linalg
python/nn
python/optimizers
python/tree_utils

View File

@ -1,7 +1,7 @@
.. _linalg:
Linear Algebra
=====
==============
.. currentmodule:: mlx.core.linalg

View File

@ -1,47 +1,42 @@
// Copyright © 2023 Apple Inc.
#include <sstream>
#include <string>
#include <numeric>
#include <ostream>
#include <vector>
#include "mlx/array.h"
#include "mlx/linalg.h"
#include "mlx/ops.h"
#include "mlx/utils.h"
namespace mlx::core::linalg {
Dtype at_least_float(const Dtype& d) {
return is_floating_point(d) ? d : promote_types(d, float32);
}
inline array vector_norm(
const array& a,
const double ord,
const std::vector<int>& axis,
bool keepdims,
StreamOrDevice s) {
if (ord == 0.0)
return sum(a != 0, axis, keepdims, s);
else if (ord == 1.0)
return sum(abs(a, s), axis, keepdims, s);
else if (ord == 2.0)
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s));
else
auto dtype = at_least_float(a.dtype());
if (ord == 0.0) {
return astype(sum(not_equal(a, array(0), s), axis, keepdims, s), dtype, s);
} else if (ord == 1.0) {
return astype(sum(abs(a, s), axis, keepdims, s), dtype, s);
} else if (ord == 2.0) {
return sqrt(sum(square(a, s), axis, keepdims, s), s);
} else if (ord == std::numeric_limits<double>::infinity()) {
return astype(max(abs(a, s), axis, keepdims, s), dtype, s);
} else if (ord == -std::numeric_limits<double>::infinity()) {
return astype(min(abs(a, s), axis, keepdims, s), dtype, s);
} else {
return power(
sum(power(abs(a, s), array(ord), s), axis, keepdims, s),
array(1.0 / ord));
}
inline array vector_norm(
const array& a,
const std::string& ord,
const std::vector<int>& axis,
bool keepdims,
StreamOrDevice s) {
if (ord == "inf")
return max(abs(a, s), axis, keepdims, s);
else if (ord == "-inf")
return min(abs(a, s), axis, keepdims, s);
std::ostringstream error_stream;
error_stream << "Invalid ord value " << ord;
throw std::invalid_argument(error_stream.str());
sum(power(abs(a, s), array(ord, dtype), s), axis, keepdims, s),
array(1.0 / ord, dtype),
s);
}
}
inline array matrix_norm(
@ -50,19 +45,30 @@ inline array matrix_norm(
const std::vector<int>& axis,
bool keepdims,
StreamOrDevice s) {
auto dtype = at_least_float(a.dtype());
auto row_axis = axis[0];
auto col_axis = axis[1];
if (!keepdims && col_axis > row_axis)
if (!keepdims && col_axis > row_axis && col_axis > 0) {
col_axis -= 1;
if (ord == -1.0)
return min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s);
if (ord == 1.0)
return max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s);
if (ord == 2.0 || ord == -2.0)
throw std::logic_error("Singular value norms are not implemented.");
std::ostringstream error_stream;
error_stream << "Invalid ord value " << ord << " for matrix norm";
throw std::invalid_argument(error_stream.str());
}
if (ord == -1.0) {
return astype(
min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
dtype,
s);
} else if (ord == 1.0) {
return astype(
max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
dtype,
s);
} else if (ord == 2.0 || ord == -2.0) {
throw std::runtime_error(
"[linalg::norm] Singular value norms are not implemented.");
} else {
std::ostringstream msg;
msg << "[linalg::norm] Invalid ord value " << ord << " for matrix norm";
throw std::invalid_argument(msg.str());
}
}
inline array matrix_norm(
@ -71,85 +77,77 @@ inline array matrix_norm(
const std::vector<int>& axis,
bool keepdims,
StreamOrDevice s) {
if (ord == "f" || ord == "fro")
if (ord == "f" || ord == "fro") {
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s));
else if (ord == "inf")
return matrix_norm(a, 1.0, {axis[1], axis[0]}, keepdims, s);
else if (ord == "-inf")
return matrix_norm(a, -1.0, {axis[1], axis[0]}, keepdims, s);
if (ord == "nuc")
throw std::logic_error("Nuclear norm is not implemented.");
std::ostringstream error_stream;
error_stream << "Invalid ord value " << ord << " for matrix norm";
throw std::invalid_argument(error_stream.str());
} else if (ord == "nuc") {
throw std::runtime_error(
"[linalg::norm] Nuclear norm not yet implemented.");
} else {
std::ostringstream msg;
msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm";
throw std::invalid_argument(msg.str());
}
}
array norm(
const array& a,
const std::vector<int>& axis,
bool keepdims,
StreamOrDevice s) {
auto num_axes = axis.size();
const std::optional<std::vector<int>>& axis /* = std::nullopt */,
bool keepdims /* = false */,
StreamOrDevice s /* = {} */) {
if (!axis) {
return norm(flatten(a, s), std::vector<int>{0}, keepdims, s);
}
if (num_axes == 0 || num_axes == 1 || num_axes == 2)
return sqrt(
sum(abs(a, s) * abs(a, s),
num_axes ? axis : get_reduce_axes({}, a.ndim()),
keepdims,
s),
s);
std::ostringstream error_stream;
error_stream << "Invalid axis values " << axis;
throw std::invalid_argument(error_stream.str());
if (axis.value().size() > 2) {
throw std::invalid_argument(
"[linalg::norm] Received too many axes for norm");
}
return sqrt(sum(square(a, s), axis.value(), keepdims, s), s);
}
array norm(
const array& a,
const double ord,
const std::vector<int>& axis,
bool keepdims,
StreamOrDevice s) {
std::vector<int> ax = axis;
if (axis.empty())
ax = get_reduce_axes({}, a.ndim());
else
ax = normalize_axes(ax, a.ndim());
auto num_axes = ax.size();
if (num_axes == 1)
const std::optional<std::vector<int>>& axis /* = std::nullopt */,
bool keepdims /* = false */,
StreamOrDevice s /* = {} */) {
std::vector<int> ax;
if (!axis) {
ax.resize(a.ndim());
std::iota(ax.begin(), ax.end(), 0);
} else {
ax = axis.value();
}
if (ax.size() == 1) {
return vector_norm(a, ord, ax, keepdims, s);
else if (num_axes == 2)
} else if (ax.size() == 2) {
return matrix_norm(a, ord, ax, keepdims, s);
std::ostringstream error_stream;
error_stream << "Invalid axis values " << ax;
throw std::invalid_argument(error_stream.str());
} else {
throw std::invalid_argument(
"[linalg::norm] Received too many axes for norm");
}
}
array norm(
const array& a,
const std::string& ord,
const std::vector<int>& axis,
bool keepdims,
StreamOrDevice s) {
std::vector<int> ax = axis;
if (axis.empty())
ax = get_reduce_axes({}, a.ndim());
else
ax = normalize_axes(ax, a.ndim());
auto num_axes = ax.size();
if (num_axes == 1)
return vector_norm(a, ord, ax, keepdims, s);
else if (num_axes == 2)
const std::optional<std::vector<int>>& axis /* = std::nullopt */,
bool keepdims /* = false */,
StreamOrDevice s /* = {} */) {
std::vector<int> ax;
if (!axis) {
ax.resize(a.ndim());
std::iota(ax.begin(), ax.end(), 0);
} else {
ax = axis.value();
}
if (ax.size() != 2) {
std::ostringstream msg;
msg << "[linalg::norm] Norm '" << ord << "' only supported for matrices,"
<< " but received " << ax.size() << " axis/axes.";
throw std::invalid_argument(msg.str());
}
return matrix_norm(a, ord, ax, keepdims, s);
std::ostringstream error_stream;
error_stream << "Invalid axis values " << ax;
throw std::invalid_argument(error_stream.str());
}
} // namespace mlx::core::linalg

View File

@ -2,27 +2,61 @@
#pragma once
#include <optional>
#include "array.h"
#include "device.h"
#include "ops.h"
#include "stream.h"
namespace mlx::core::linalg {
/*
* Compute vector or matrix norms.
*
* - If axis and ord are both unspecified, computes the 2-norm of flatten(x).
* - If axis is not provided but ord is, then x must be either 1D or 2D.
* - If axis is provided, but ord is not, then the 2-norm is computed along the
* given axes. At most 2 axes can be specified.
* - If both axis and ord are provided, then the corresponding matrix of vector
* norm is computed. At most 2 axes can be specified.
*/
array norm(
const array& a,
const double ord,
const std::vector<int>& axis = {},
const std::optional<std::vector<int>>& axis = std::nullopt,
bool keepdims = false,
StreamOrDevice s = {});
inline array norm(
const array& a,
const double ord,
int axis,
bool keepdims = false,
StreamOrDevice s = {}) {
return norm(a, ord, std::vector<int>{axis}, keepdims, s);
}
array norm(
const array& a,
const std::string& ord,
const std::vector<int>& axis = {},
const std::optional<std::vector<int>>& axis = std::nullopt,
bool keepdims = false,
StreamOrDevice s = {});
inline array norm(
const array& a,
const std::string& ord,
int axis,
bool keepdims = false,
StreamOrDevice s = {}) {
return norm(a, ord, std::vector<int>{axis}, keepdims, s);
}
array norm(
const array& a,
const std::vector<int>& axis = {},
const std::optional<std::vector<int>>& axis = std::nullopt,
bool keepdims = false,
StreamOrDevice s = {});
inline array
norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
return norm(a, std::vector<int>{axis}, keepdims, s);
}
} // namespace mlx::core::linalg

View File

@ -1,6 +1,5 @@
// Copyright © 2023 Apple Inc.
#include <numeric>
#include <sstream>
#include <vector>
@ -74,12 +73,6 @@ int normalize_axis(int axis, int ndim) {
}
return axis;
}
std::vector<int> normalize_axes(const std::vector<int>& axes, int ndim) {
std::vector<int> canonical;
for (int ax : axes)
canonical.push_back(normalize_axis(ax, ndim));
return canonical;
}
std::ostream& operator<<(std::ostream& os, const Device& d) {
os << "Device(";

View File

@ -2,7 +2,6 @@
#pragma once
#include <numeric>
#include "array.h"
#include "device.h"
#include "dtype.h"
@ -25,7 +24,6 @@ bool is_same_shape(const std::vector<array>& arrays);
* https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
*/
int normalize_axis(int axis, int ndim);
std::vector<int> normalize_axes(const std::vector<int>& axes, int ndim);
std::ostream& operator<<(std::ostream& os, const Device& d);
std::ostream& operator<<(std::ostream& os, const Stream& s);
@ -43,18 +41,4 @@ inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
return os << static_cast<float>(v);
}
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
std::vector<int> axes;
if (std::holds_alternative<std::monostate>(v)) {
axes.resize(dims);
std::iota(axes.begin(), axes.end(), 0);
} else if (auto pv = std::get_if<int>(&v); pv) {
axes.push_back(*pv);
} else {
axes = std::get<std::vector<int>>(v);
}
return axes;
}
} // namespace mlx::core

View File

@ -7,7 +7,6 @@
#include "mlx/fft.h"
#include "mlx/ops.h"
#include "mlx/utils.h"
namespace py = pybind11;
using namespace py::literals;

View File

@ -26,106 +26,78 @@ using namespace mlx::core;
using namespace mlx::core::linalg;
void init_linalg(py::module_& parent_module) {
py::options options;
options.disable_function_signatures();
auto m =
parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra.");
m.def(
"norm",
[](const array& a,
const std::variant<std::monostate, int, double, std::string>& ord,
const std::variant<std::monostate, int, std::vector<int>>& axis,
const std::variant<std::monostate, int, double, std::string>& ord_,
const std::variant<std::monostate, int, std::vector<int>>& axis_,
const bool keepdims,
const StreamOrDevice stream) {
return std::visit(
overloaded{
[&](const double p) {
if (std::isinf((float)p) || std::isinf(p)) {
if (p > 0) {
return norm(
a,
"inf",
std::holds_alternative<std::monostate>(axis)
? std::vector<int>()
: get_reduce_axes(axis, a.ndim()),
keepdims,
stream);
std::optional<std::vector<int>> axis = std::nullopt;
if (auto pv = std::get_if<int>(&axis_); pv) {
axis = std::vector<int>{*pv};
} else if (auto pv = std::get_if<std::vector<int>>(&axis_); pv) {
axis = *pv;
}
return norm(
a,
"-inf",
get_reduce_axes(axis, a.ndim()),
keepdims,
stream);
if (std::holds_alternative<std::monostate>(ord_)) {
return norm(a, axis, keepdims, stream);
} else {
if (auto pv = std::get_if<std::string>(&ord_); pv) {
return norm(a, *pv, axis, keepdims, stream);
}
double ord;
if (auto pv = std::get_if<int>(&ord_); pv) {
ord = *pv;
} else {
ord = std::get<double>(ord_);
}
return norm(a, ord, axis, keepdims, stream);
}
return norm(
a,
p,
std::holds_alternative<std::monostate>(axis)
? std::vector<int>()
: get_reduce_axes(axis, a.ndim()),
keepdims,
stream);
},
[&](const std::string& p) {
return norm(
a,
p,
std::holds_alternative<std::monostate>(axis)
? std::vector<int>()
: get_reduce_axes(axis, a.ndim()),
keepdims,
stream);
},
[&](const std::monostate _) {
return norm(
a,
std::holds_alternative<std::monostate>(axis)
? std::vector<int>()
: get_reduce_axes(axis, a.ndim()),
keepdims,
stream);
}},
ord);
},
"a"_a,
py::pos_only(),
"ord"_a = none,
"axis"_a = none,
"keepdims"_a = false,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
Matrix or vector norm.
This function is able to return matrix or vector norms,
depending on the value of the ``ord`` parameter.
This function computes vector or matrix norms depending on the value of
the ``ord`` and ``axis`` parameters.
Parameters
----------
a : array_like
Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord`
is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned.
ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional
Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None.
axis : {None, int, 2-tuple of ints}, optional.
If `axis` is an integer, it specifies the axis of `a` along which to
compute the vector norms. If `axis` is a 2-tuple, it specifies the
axes that hold 2-D matrices, and the matrix norms of these matrices
are computed. If `axis` is None then either a vector norm (when `a`
is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default
is None.
keepdims : bool, optional
If this is set to True, the axes which are normed over are left in the
result as dimensions with size one. With this option the result will
broadcast correctly against the original `a`.
Args:
a (array): Input array. If ``axis`` is ``None``, ``a`` must be 1-D or 2-D,
unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the
2-norm of ``a.flatten`` will be returned.
ord (scalar or str, optional): Order of the norm (see table under ``Notes``).
If ``None``, the 2-norm will be computed along the given ``axis``.
Default: ``None``.
axis (int or list(int), optional): If ``axis`` is an integer, it specifies the
axis of ``a`` along which to compute the vector norms. If ``axis`` is a
2-tuple, it specifies the axes that hold 2-D matrices, and the matrix
norms of these matrices are computed. If `axis` is ``None`` then
either a vector norm (when ``a`` is 1-D) or a matrix norm (when ``a`` is
2-D) is returned. Default: ``None``.
keepdims (bool, optional): If ``True``, the axes which are normed over are
left in the result as dimensions with size one. Default ``False``.
Returns
-------
n : array
Norm of the matrix or vector(s).
Returns:
array: The output containing the norm(s).
Notes
-----
Notes:
For values of ``ord < 1``, the result is, strictly speaking, not a
mathematical 'norm', but it may still be useful for various numerical
mathematical norm, but it may still be useful for various numerical
purposes.
The following norms can be calculated:
@ -145,26 +117,25 @@ void init_linalg(py::module_& parent_module) {
other -- sum(abs(x)**ord)**(1./ord)
===== ============================ ==========================
.. warning::
Nuclear norm and norms based on singular values are not yet implemented.
The Frobenius norm is given by [1]_:
:math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
:math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
The nuclear norm is the sum of the singular values.
Both the Frobenius and nuclear norm orders are only defined for
matrices and raise a ValueError when ``a.ndim != 2``.
matrices and raise a ``ValueError`` when ``a.ndim != 2``.
References
----------
References:
.. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
Examples
--------
Examples:
>>> import mlx.core as mx
>>> from mlx.core import linalg as LA
>>> from mlx.core import linalg as la
>>> a = mx.arange(9) - 4
>>> a
array([-4, -3, -2, ..., 2, 3, 4], dtype=int32)
@ -173,46 +144,46 @@ void init_linalg(py::module_& parent_module) {
array([[-4, -3, -2],
[-1, 0, 1],
[ 2, 3, 4]], dtype=int32)
>>> LA.norm(a)
>>> la.norm(a)
array(7.74597, dtype=float32)
>>> LA.norm(b)
>>> la.norm(b)
array(7.74597, dtype=float32)
>>> LA.norm(b, 'fro')
>>> la.norm(b, 'fro')
array(7.74597, dtype=float32)
>>> LA.norm(a, float("inf"))
array(4, dtype=int32)
>>> LA.norm(b, float("inf"))
array(9, dtype=int32)
>>> LA.norm(a, -float("inf"))
array(0, dtype=int32)
>>> LA.norm(b, -float("inf"))
array(2, dtype=int32)
>>> LA.norm(a, 1)
array(20, dtype=int32)
>>> LA.norm(b, 1)
array(7, dtype=int32)
>>> LA.norm(a, -1)
>>> la.norm(a, float("inf"))
array(4, dtype=float32)
>>> la.norm(b, float("inf"))
array(9, dtype=float32)
>>> la.norm(a, -float("inf"))
array(0, dtype=float32)
>>> LA.norm(b, -1)
array(6, dtype=int32)
>>> LA.norm(a, 2)
>>> la.norm(b, -float("inf"))
array(2, dtype=float32)
>>> la.norm(a, 1)
array(20, dtype=float32)
>>> la.norm(b, 1)
array(7, dtype=float32)
>>> la.norm(a, -1)
array(0, dtype=float32)
>>> la.norm(b, -1)
array(6, dtype=float32)
>>> la.norm(a, 2)
array(7.74597, dtype=float32)
>>> LA.norm(a, 3)
>>> la.norm(a, 3)
array(5.84804, dtype=float32)
>>> LA.norm(a, -3)
>>> la.norm(a, -3)
array(0, dtype=float32)
>>> c = mx.array([[ 1, 2, 3],
... [-1, 1, 4]])
>>> LA.norm(c, axis=0)
>>> la.norm(c, axis=0)
array([1.41421, 2.23607, 5], dtype=float32)
>>> LA.norm(c, axis=1)
>>> la.norm(c, axis=1)
array([3.74166, 4.24264], dtype=float32)
>>> LA.norm(c, ord=1, axis=1)
array([6, 6], dtype=int32)
>>> la.norm(c, ord=1, axis=1)
array([6, 6], dtype=float32)
>>> m = mx.arange(8).reshape(2,2,2)
>>> LA.norm(m, axis=(1,2))
>>> la.norm(m, axis=(1,2))
array([3.74166, 11.225], dtype=float32)
>>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :])
>>> la.norm(m[0, :, :]), LA.norm(m[1, :, :])
(array(3.74166, dtype=float32), array(11.225, dtype=float32))
)pbdoc");
}

View File

@ -1,6 +1,7 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <numeric>
#include <variant>
#include <pybind11/complex.h>
@ -13,10 +14,24 @@ namespace py = pybind11;
using namespace mlx::core;
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
using ScalarOrArray = std::
variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>;
static constexpr std::monostate none{};
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
std::vector<int> axes;
if (std::holds_alternative<std::monostate>(v)) {
axes.resize(dims);
std::iota(axes.begin(), axes.end(), 0);
} else if (auto pv = std::get_if<int>(&v); pv) {
axes.push_back(*pv);
} else {
axes = std::get<std::vector<int>>(v);
}
return axes;
}
inline array to_array_with_accessor(py::object obj) {
if (py::hasattr(obj, "__mlx_array__")) {
return obj.attr("__mlx_array__")().cast<array>();

View File

@ -3,170 +3,155 @@
#include "doctest/doctest.h"
#include <cmath>
#include <iostream>
#include "mlx/linalg.h"
#include "mlx/mlx.h"
using namespace mlx::core;
using namespace mlx::core::linalg;
TEST_CASE("[mlx.core.linalg.norm] no ord") {
array arr_one_d({1, 2, 3});
array arr_two_d = reshape(arange(9), {3, 3});
array arr_three_d = reshape(arange(18), {2, 3, 3});
// Zero dimensions
array x(2.0);
CHECK_EQ(norm(x).item<float>(), 2.0f);
CHECK_THROWS(norm(x, 0));
CHECK(array_equal(norm(arr_one_d), array(sqrt(1 + 4 + 9))).item<bool>());
CHECK(array_equal(norm(arr_one_d, {0}, false), array(sqrt(1 + 4 + 9)))
.item<bool>());
x = array({1, 2, 3});
float expected = std::sqrt(1 + 4 + 9);
CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
CHECK_EQ(norm(x, 0, false).item<float>(), doctest::Approx(expected));
CHECK_EQ(norm(x, -1, false).item<float>(), doctest::Approx(expected));
CHECK_EQ(norm(x, -1, true).ndim(), 1);
CHECK_THROWS(norm(x, 1));
x = reshape(arange(9), {3, 3});
expected =
std::sqrt(0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8);
CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
CHECK_EQ(
norm(x, std::vector<int>{0, 1}).item<float>(), doctest::Approx(expected));
CHECK(array_equal(
norm(arr_two_d, {}, false),
array(sqrt(
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8)))
.item<bool>());
CHECK(array_equal(
norm(arr_two_d, {0}, false),
norm(x, 0, false),
array(
{sqrt(0 + 3 * 3 + 6 * 6),
sqrt(1 + 4 * 4 + 7 * 7),
sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
{std::sqrt(0 + 3 * 3 + 6 * 6),
std::sqrt(1 + 4 * 4 + 7 * 7),
std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
.item<bool>());
CHECK(array_equal(
norm(arr_two_d, {1}, false),
CHECK(allclose(
norm(x, 1, false),
array(
{sqrt(0 + 1 + 2 * 2),
sqrt(3 * 3 + 4 * 4 + 5 * 5),
sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
{std::sqrt(0 + 1 + 2 * 2),
std::sqrt(3 * 3 + 4 * 4 + 5 * 5),
std::sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
.item<bool>());
CHECK(array_equal(
norm(arr_two_d, {0, 1}, false),
array(sqrt(
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8)))
.item<bool>());
CHECK(array_equal(
norm(arr_three_d, {2}, false),
x = reshape(arange(18), {2, 3, 3});
CHECK(allclose(
norm(x, 2, false),
array(
{
sqrt(0 + 1 + 2 * 2),
sqrt(3 * 3 + 4 * 4 + 5 * 5),
sqrt(6 * 6 + 7 * 7 + 8 * 8),
sqrt(9 * 9 + 10 * 10 + 11 * 11),
sqrt(12 * 12 + 13 * 13 + 14 * 14),
sqrt(15 * 15 + 16 * 16 + 17 * 17),
std::sqrt(0 + 1 + 2 * 2),
std::sqrt(3 * 3 + 4 * 4 + 5 * 5),
std::sqrt(6 * 6 + 7 * 7 + 8 * 8),
std::sqrt(9 * 9 + 10 * 10 + 11 * 11),
std::sqrt(12 * 12 + 13 * 13 + 14 * 14),
std::sqrt(15 * 15 + 16 * 16 + 17 * 17),
},
{2, 3}))
.item<bool>());
CHECK(array_equal(
norm(arr_three_d, {1}, false),
CHECK(allclose(
norm(x, std::vector<int>{1, 2}, false),
array(
{
sqrt(0 + 3 * 3 + 6 * 6),
sqrt(1 + 4 * 4 + 7 * 7),
sqrt(2 * 2 + 5 * 5 + 8 * 8),
sqrt(9 * 9 + 12 * 12 + 15 * 15),
sqrt(10 * 10 + 13 * 13 + 16 * 16),
sqrt(11 * 11 + 14 * 14 + 17 * 17),
},
{2, 3}))
.item<bool>());
CHECK(array_equal(
norm(arr_three_d, {0}, false),
array(
{
sqrt(0 + 9 * 9),
sqrt(1 + 10 * 10),
sqrt(2 * 2 + 11 * 11),
sqrt(3 * 3 + 12 * 12),
sqrt(4 * 4 + 13 * 13),
sqrt(5 * 5 + 14 * 14),
sqrt(6 * 6 + 15 * 15),
sqrt(7 * 7 + 16 * 16),
sqrt(8 * 8 + 17 * 17),
},
{3, 3}))
.item<bool>());
CHECK(array_equal(
norm(arr_three_d, {1, 2}, false),
array(
{sqrt(
{std::sqrt(
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 +
8 * 8),
sqrt(
std::sqrt(
9 * 9 + 10 * 10 + 11 * 11 + 12 * 12 + 13 * 13 + 14 * 14 +
15 * 15 + 16 * 16 + 17 * 17)},
{2}))
.item<bool>());
CHECK_THROWS(norm(x, std::vector<int>{0, 1, 2}));
}
TEST_CASE("[mlx.core.linalg.norm] double ord") {
array arr_one_d({1, 2, 3});
array arr_two_d = reshape(arange(9), {3, 3});
array arr_three_d = reshape(arange(18), {2, 3, 3});
CHECK_THROWS(norm(array(0), 2.0));
CHECK(array_equal(norm(arr_one_d, 2.0), array(sqrt(1 + 4 + 9))).item<bool>());
CHECK(array_equal(norm(arr_one_d, 1.0), array(1 + 2 + 3)).item<bool>());
CHECK(array_equal(norm(arr_one_d, 0.0), array(3)).item<bool>());
array x({1, 2, 3});
CHECK(array_equal(norm(arr_one_d, 2.0, {0}, false), array(sqrt(1 + 4 + 9)))
.item<bool>());
CHECK(array_equal(
norm(arr_two_d, 2.0, {0}, false),
float expected = std::sqrt(1 + 4 + 9);
CHECK_EQ(norm(x, 2.0).item<float>(), doctest::Approx(expected));
CHECK_EQ(norm(x, 2.0, 0).item<float>(), doctest::Approx(expected));
CHECK_THROWS(norm(x, 2.0, 1));
expected = 1 + 2 + 3;
CHECK_EQ(norm(x, 1.0).item<float>(), doctest::Approx(expected));
expected = 3;
CHECK_EQ(norm(x, 0.0).item<float>(), doctest::Approx(expected));
expected = 3;
CHECK_EQ(
norm(x, std::numeric_limits<double>::infinity()).item<float>(),
doctest::Approx(expected));
expected = 1;
CHECK_EQ(
norm(x, -std::numeric_limits<double>::infinity()).item<float>(),
doctest::Approx(expected));
x = reshape(arange(9), {3, 3});
CHECK(allclose(
norm(x, 2.0, 0, false),
array(
{sqrt(0 + 3 * 3 + 6 * 6),
sqrt(1 + 4 * 4 + 7 * 7),
sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
{std::sqrt(0 + 3 * 3 + 6 * 6),
std::sqrt(1 + 4 * 4 + 7 * 7),
std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
.item<bool>());
CHECK(array_equal(
norm(arr_two_d, 2.0, {1}, false),
CHECK(allclose(
norm(x, 2.0, 1, false),
array(
{sqrt(0 + 1 + 2 * 2),
sqrt(3 * 3 + 4 * 4 + 5 * 5),
sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
.item<bool>());
CHECK(array_equal(
norm(arr_three_d, 2.0, {2}, false),
array(
{
sqrt(0 + 1 + 2 * 2),
sqrt(3 * 3 + 4 * 4 + 5 * 5),
sqrt(6 * 6 + 7 * 7 + 8 * 8),
sqrt(9 * 9 + 10 * 10 + 11 * 11),
sqrt(12 * 12 + 13 * 13 + 14 * 14),
sqrt(15 * 15 + 16 * 16 + 17 * 17),
},
{2, 3}))
.item<bool>());
CHECK(array_equal(
norm(arr_three_d, 2.0, {1}, false),
array(
{
sqrt(0 + 3 * 3 + 6 * 6),
sqrt(1 + 4 * 4 + 7 * 7),
sqrt(2 * 2 + 5 * 5 + 8 * 8),
sqrt(9 * 9 + 12 * 12 + 15 * 15),
sqrt(10 * 10 + 13 * 13 + 16 * 16),
sqrt(11 * 11 + 14 * 14 + 17 * 17),
},
{2, 3}))
.item<bool>());
CHECK(array_equal(
norm(arr_three_d, 2.0, {0}, false),
array(
{
sqrt(0 + 9 * 9),
sqrt(1 + 10 * 10),
sqrt(2 * 2 + 11 * 11),
sqrt(3 * 3 + 12 * 12),
sqrt(4 * 4 + 13 * 13),
sqrt(5 * 5 + 14 * 14),
sqrt(6 * 6 + 15 * 15),
sqrt(7 * 7 + 16 * 16),
sqrt(8 * 8 + 17 * 17),
},
{3, 3}))
.item<bool>());
CHECK_EQ(
norm(x, 1.0, std::vector<int>{0, 1}).item<float>(),
doctest::Approx(15.0));
CHECK_EQ(
norm(x, 1.0, std::vector<int>{1, 0}).item<float>(),
doctest::Approx(21.0));
CHECK_EQ(
norm(x, -1.0, std::vector<int>{0, 1}).item<float>(),
doctest::Approx(9.0));
CHECK_EQ(
norm(x, -1.0, std::vector<int>{1, 0}).item<float>(),
doctest::Approx(3.0));
CHECK_EQ(
norm(x, 1.0, std::vector<int>{0, 1}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, 1.0, std::vector<int>{1, 0}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, -1.0, std::vector<int>{0, 1}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, -1.0, std::vector<int>{1, 0}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, -1.0, std::vector<int>{-2, -1}, false).item<float>(),
doctest::Approx(9.0));
CHECK_EQ(
norm(x, 1.0, std::vector<int>{-2, -1}, false).item<float>(),
doctest::Approx(15.0));
x = reshape(arange(18), {2, 3, 3});
CHECK_THROWS(norm(x, 2.0, std::vector{0, 1, 2}));
CHECK(allclose(
norm(arr_three_d, 3.0, {0}),
norm(x, 3.0, 0),
array(
{9.,
10.00333222,
@ -179,15 +164,8 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
17.57113899},
{3, 3}))
.item<bool>());
CHECK(
allclose(
norm(arr_three_d, 3.0, {1}),
array(
{6.24025147, 7.41685954, 8.6401226, 18., 19.39257164, 20.7915893},
{2, 3}))
.item<bool>());
CHECK(allclose(
norm(arr_three_d, 3.0, {2}),
norm(x, 3.0, 2),
array(
{2.08008382,
6.,
@ -197,110 +175,76 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
23.13593104},
{2, 3}))
.item<bool>());
CHECK(allclose(
norm(arr_three_d, 0.0, {0}),
array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3}))
.item<bool>());
CHECK(
allclose(
norm(arr_three_d, 0.0, {1}), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
norm(x, 0.0, 0), array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3}))
.item<bool>());
CHECK(
allclose(
norm(arr_three_d, 0.0, {2}), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
CHECK(allclose(norm(x, 0.0, 1), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
.item<bool>());
CHECK(allclose(norm(x, 0.0, 2), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
.item<bool>());
CHECK(allclose(
norm(arr_three_d, 1.0, {0}),
norm(x, 1.0, 0),
array({9., 11., 13., 15., 17., 19., 21., 23., 25.}, {3, 3}))
.item<bool>());
CHECK(allclose(
norm(arr_three_d, 1.0, {1}),
array({9., 12., 15., 36., 39., 42.}, {2, 3}))
CHECK(allclose(norm(x, 1.0, 1), array({9., 12., 15., 36., 39., 42.}, {2, 3}))
.item<bool>());
CHECK(allclose(
norm(arr_three_d, 1.0, {2}),
array({3., 12., 21., 30., 39., 48.}, {2, 3}))
CHECK(allclose(norm(x, 1.0, 2), array({3., 12., 21., 30., 39., 48.}, {2, 3}))
.item<bool>());
CHECK(allclose(norm(arr_two_d, 1.0, {0, 1}), array({15.0})).item<bool>());
CHECK(allclose(norm(arr_two_d, 1.0, {1, 0}), array({21.0})).item<bool>());
CHECK(allclose(norm(arr_two_d, -1.0, {0, 1}), array({9.0})).item<bool>());
CHECK(allclose(norm(arr_two_d, -1.0, {1, 0}), array({3.0})).item<bool>());
CHECK(allclose(norm(arr_two_d, 1.0, {0, 1}, true), array({15.0}, {1, 1}))
CHECK(allclose(norm(x, 1.0, std::vector<int>{0, 1}), array({21., 23., 25.}))
.item<bool>());
CHECK(allclose(norm(arr_two_d, 1.0, {1, 0}, true), array({21.0}, {1, 1}))
CHECK(allclose(norm(x, 1.0, std::vector<int>{1, 2}), array({15., 42.}))
.item<bool>());
CHECK(allclose(norm(arr_two_d, -1.0, {0, 1}, true), array({9.0}, {1, 1}))
CHECK(allclose(norm(x, -1.0, std::vector<int>{0, 1}), array({9., 11., 13.}))
.item<bool>());
CHECK(allclose(norm(arr_two_d, -1.0, {1, 0}, true), array({3.0}, {1, 1}))
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9., 36.}))
.item<bool>());
CHECK(array_equal(norm(arr_two_d, -1.0, {-2, -1}, false), array(9.0))
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 0}), array({9., 12., 15.}))
.item<bool>());
CHECK(array_equal(norm(arr_two_d, 1.0, {-2, -1}, false), array(15.0))
CHECK(allclose(norm(x, -1.0, std::vector<int>{2, 1}), array({3, 30}))
.item<bool>());
//
CHECK(allclose(norm(arr_three_d, 1.0, {0, 1}), array({21., 23., 25.}))
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9, 36}))
.item<bool>());
CHECK(
allclose(norm(arr_three_d, 1.0, {1, 2}), array({15., 42.})).item<bool>());
CHECK(allclose(norm(arr_three_d, -1.0, {0, 1}), array({9., 11., 13.}))
.item<bool>());
CHECK(
allclose(norm(arr_three_d, -1.0, {1, 2}), array({9., 36.})).item<bool>());
CHECK(allclose(norm(arr_three_d, -1.0, {1, 0}), array({9., 12., 15.}))
.item<bool>());
CHECK(allclose(norm(arr_three_d, -1.0, {2, 1}), array({3, 30})).item<bool>());
CHECK(allclose(norm(arr_three_d, -1.0, {1, 2}), array({9, 36})).item<bool>());
}
TEST_CASE("[mlx.core.linalg.norm] string ord") {
array arr_one_d({1, 2, 3});
array arr_two_d = reshape(arange(9), {3, 3});
array arr_three_d = reshape(arange(18), {2, 3, 3});
array x({1, 2, 3});
CHECK_THROWS(norm(x, "fro"));
CHECK(allclose(norm(arr_one_d, "inf", {}), array({3.0})).item<bool>());
CHECK(allclose(norm(arr_one_d, "-inf", {}), array({1.0})).item<bool>());
x = reshape(arange(9), {3, 3});
CHECK_THROWS(norm(x, "bad ord"));
CHECK(allclose(norm(arr_two_d, "f", {0, 1}), array({14.2828568570857}))
.item<bool>());
CHECK(allclose(norm(arr_two_d, "fro", {0, 1}), array({14.2828568570857}))
.item<bool>());
CHECK(allclose(norm(arr_two_d, "inf", {0, 1}), array({21.0})).item<bool>());
CHECK(allclose(norm(arr_two_d, "-inf", {0, 1}), array({3.0})).item<bool>());
CHECK_EQ(
norm(x, "f", std::vector<int>{0, 1}).item<float>(),
doctest::Approx(14.2828568570857));
CHECK_EQ(
norm(x, "fro", std::vector<int>{0, 1}).item<float>(),
doctest::Approx(14.2828568570857));
x = reshape(arange(18), {2, 3, 3});
CHECK(allclose(
norm(arr_three_d, "fro", {0, 1}),
norm(x, "fro", std::vector<int>{0, 1}),
array({22.24859546, 24.31049156, 26.43860813}))
.item<bool>());
CHECK(allclose(
norm(arr_three_d, "fro", {1, 2}), array({14.28285686, 39.7617907}))
norm(x, "fro", std::vector<int>{1, 2}),
array({14.28285686, 39.7617907}))
.item<bool>());
CHECK(allclose(
norm(arr_three_d, "f", {0, 1}),
norm(x, "f", std::vector<int>{0, 1}),
array({22.24859546, 24.31049156, 26.43860813}))
.item<bool>());
CHECK(allclose(
norm(arr_three_d, "f", {1, 0}),
norm(x, "f", std::vector<int>{1, 0}),
array({22.24859546, 24.31049156, 26.43860813}))
.item<bool>());
CHECK(
allclose(norm(arr_three_d, "f", {1, 2}), array({14.28285686, 39.7617907}))
CHECK(allclose(
norm(x, "f", std::vector<int>{1, 2}),
array({14.28285686, 39.7617907}))
.item<bool>());
CHECK(
allclose(norm(arr_three_d, "f", {2, 1}), array({14.28285686, 39.7617907}))
.item<bool>());
CHECK(allclose(norm(arr_three_d, "inf", {0, 1}), array({36., 39., 42.}))
.item<bool>());
CHECK(allclose(norm(arr_three_d, "inf", {1, 2}), array({21., 48.}))
.item<bool>());
CHECK(allclose(norm(arr_three_d, "-inf", {0, 1}), array({9., 12., 15.}))
.item<bool>());
CHECK(allclose(norm(arr_three_d, "-inf", {1, 2}), array({3., 30.}))
.item<bool>());
CHECK(allclose(norm(arr_three_d, "-inf", {1, 0}), array({9., 11., 13.}))
.item<bool>());
CHECK(allclose(norm(arr_three_d, "-inf", {2, 1}), array({9., 36.}))
CHECK(allclose(
norm(x, "f", std::vector<int>{2, 1}),
array({14.28285686, 39.7617907}))
.item<bool>());
}