some style and API consistency updates to linalg norm

This commit is contained in:
Awni Hannun 2023-12-26 10:54:59 -08:00
parent 4bae4a8239
commit f7cea9563d
10 changed files with 437 additions and 498 deletions

View File

@ -57,6 +57,7 @@ are the CPU and GPU.
python/random python/random
python/transforms python/transforms
python/fft python/fft
python/linalg
python/nn python/nn
python/optimizers python/optimizers
python/tree_utils python/tree_utils

View File

@ -1,7 +1,7 @@
.. _linalg: .. _linalg:
Linear Algebra Linear Algebra
===== ==============
.. currentmodule:: mlx.core.linalg .. currentmodule:: mlx.core.linalg

View File

@ -1,47 +1,42 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include <sstream> #include <numeric>
#include <string> #include <ostream>
#include <vector> #include <vector>
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/linalg.h" #include "mlx/linalg.h"
#include "mlx/ops.h" #include "mlx/ops.h"
#include "mlx/utils.h"
namespace mlx::core::linalg { namespace mlx::core::linalg {
Dtype at_least_float(const Dtype& d) {
return is_floating_point(d) ? d : promote_types(d, float32);
}
inline array vector_norm( inline array vector_norm(
const array& a, const array& a,
const double ord, const double ord,
const std::vector<int>& axis, const std::vector<int>& axis,
bool keepdims, bool keepdims,
StreamOrDevice s) { StreamOrDevice s) {
if (ord == 0.0) auto dtype = at_least_float(a.dtype());
return sum(a != 0, axis, keepdims, s); if (ord == 0.0) {
else if (ord == 1.0) return astype(sum(not_equal(a, array(0), s), axis, keepdims, s), dtype, s);
return sum(abs(a, s), axis, keepdims, s); } else if (ord == 1.0) {
else if (ord == 2.0) return astype(sum(abs(a, s), axis, keepdims, s), dtype, s);
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s)); } else if (ord == 2.0) {
else return sqrt(sum(square(a, s), axis, keepdims, s), s);
} else if (ord == std::numeric_limits<double>::infinity()) {
return astype(max(abs(a, s), axis, keepdims, s), dtype, s);
} else if (ord == -std::numeric_limits<double>::infinity()) {
return astype(min(abs(a, s), axis, keepdims, s), dtype, s);
} else {
return power( return power(
sum(power(abs(a, s), array(ord), s), axis, keepdims, s), sum(power(abs(a, s), array(ord, dtype), s), axis, keepdims, s),
array(1.0 / ord)); array(1.0 / ord, dtype),
} s);
}
inline array vector_norm(
const array& a,
const std::string& ord,
const std::vector<int>& axis,
bool keepdims,
StreamOrDevice s) {
if (ord == "inf")
return max(abs(a, s), axis, keepdims, s);
else if (ord == "-inf")
return min(abs(a, s), axis, keepdims, s);
std::ostringstream error_stream;
error_stream << "Invalid ord value " << ord;
throw std::invalid_argument(error_stream.str());
} }
inline array matrix_norm( inline array matrix_norm(
@ -50,19 +45,30 @@ inline array matrix_norm(
const std::vector<int>& axis, const std::vector<int>& axis,
bool keepdims, bool keepdims,
StreamOrDevice s) { StreamOrDevice s) {
auto dtype = at_least_float(a.dtype());
auto row_axis = axis[0]; auto row_axis = axis[0];
auto col_axis = axis[1]; auto col_axis = axis[1];
if (!keepdims && col_axis > row_axis) if (!keepdims && col_axis > row_axis && col_axis > 0) {
col_axis -= 1; col_axis -= 1;
if (ord == -1.0) }
return min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s); if (ord == -1.0) {
if (ord == 1.0) return astype(
return max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s); min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
if (ord == 2.0 || ord == -2.0) dtype,
throw std::logic_error("Singular value norms are not implemented."); s);
std::ostringstream error_stream; } else if (ord == 1.0) {
error_stream << "Invalid ord value " << ord << " for matrix norm"; return astype(
throw std::invalid_argument(error_stream.str()); max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
dtype,
s);
} else if (ord == 2.0 || ord == -2.0) {
throw std::runtime_error(
"[linalg::norm] Singular value norms are not implemented.");
} else {
std::ostringstream msg;
msg << "[linalg::norm] Invalid ord value " << ord << " for matrix norm";
throw std::invalid_argument(msg.str());
}
} }
inline array matrix_norm( inline array matrix_norm(
@ -71,85 +77,77 @@ inline array matrix_norm(
const std::vector<int>& axis, const std::vector<int>& axis,
bool keepdims, bool keepdims,
StreamOrDevice s) { StreamOrDevice s) {
if (ord == "f" || ord == "fro") if (ord == "f" || ord == "fro") {
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s)); return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s));
else if (ord == "inf") } else if (ord == "nuc") {
return matrix_norm(a, 1.0, {axis[1], axis[0]}, keepdims, s); throw std::runtime_error(
else if (ord == "-inf") "[linalg::norm] Nuclear norm not yet implemented.");
return matrix_norm(a, -1.0, {axis[1], axis[0]}, keepdims, s); } else {
if (ord == "nuc") std::ostringstream msg;
throw std::logic_error("Nuclear norm is not implemented."); msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm";
std::ostringstream error_stream; throw std::invalid_argument(msg.str());
error_stream << "Invalid ord value " << ord << " for matrix norm"; }
throw std::invalid_argument(error_stream.str());
} }
array norm( array norm(
const array& a, const array& a,
const std::vector<int>& axis, const std::optional<std::vector<int>>& axis /* = std::nullopt */,
bool keepdims, bool keepdims /* = false */,
StreamOrDevice s) { StreamOrDevice s /* = {} */) {
auto num_axes = axis.size(); if (!axis) {
return norm(flatten(a, s), std::vector<int>{0}, keepdims, s);
}
if (num_axes == 0 || num_axes == 1 || num_axes == 2) if (axis.value().size() > 2) {
return sqrt( throw std::invalid_argument(
sum(abs(a, s) * abs(a, s), "[linalg::norm] Received too many axes for norm");
num_axes ? axis : get_reduce_axes({}, a.ndim()), }
keepdims, return sqrt(sum(square(a, s), axis.value(), keepdims, s), s);
s),
s);
std::ostringstream error_stream;
error_stream << "Invalid axis values " << axis;
throw std::invalid_argument(error_stream.str());
} }
array norm( array norm(
const array& a, const array& a,
const double ord, const double ord,
const std::vector<int>& axis, const std::optional<std::vector<int>>& axis /* = std::nullopt */,
bool keepdims, bool keepdims /* = false */,
StreamOrDevice s) { StreamOrDevice s /* = {} */) {
std::vector<int> ax = axis; std::vector<int> ax;
if (!axis) {
if (axis.empty()) ax.resize(a.ndim());
ax = get_reduce_axes({}, a.ndim()); std::iota(ax.begin(), ax.end(), 0);
else } else {
ax = normalize_axes(ax, a.ndim()); ax = axis.value();
}
auto num_axes = ax.size(); if (ax.size() == 1) {
if (num_axes == 1)
return vector_norm(a, ord, ax, keepdims, s); return vector_norm(a, ord, ax, keepdims, s);
else if (num_axes == 2) } else if (ax.size() == 2) {
return matrix_norm(a, ord, ax, keepdims, s); return matrix_norm(a, ord, ax, keepdims, s);
} else {
std::ostringstream error_stream; throw std::invalid_argument(
error_stream << "Invalid axis values " << ax; "[linalg::norm] Received too many axes for norm");
throw std::invalid_argument(error_stream.str()); }
} }
array norm( array norm(
const array& a, const array& a,
const std::string& ord, const std::string& ord,
const std::vector<int>& axis, const std::optional<std::vector<int>>& axis /* = std::nullopt */,
bool keepdims, bool keepdims /* = false */,
StreamOrDevice s) { StreamOrDevice s /* = {} */) {
std::vector<int> ax = axis; std::vector<int> ax;
if (!axis) {
if (axis.empty()) ax.resize(a.ndim());
ax = get_reduce_axes({}, a.ndim()); std::iota(ax.begin(), ax.end(), 0);
else } else {
ax = normalize_axes(ax, a.ndim()); ax = axis.value();
}
auto num_axes = ax.size(); if (ax.size() != 2) {
if (num_axes == 1) std::ostringstream msg;
return vector_norm(a, ord, ax, keepdims, s); msg << "[linalg::norm] Norm '" << ord << "' only supported for matrices,"
else if (num_axes == 2) << " but received " << ax.size() << " axis/axes.";
return matrix_norm(a, ord, ax, keepdims, s); throw std::invalid_argument(msg.str());
}
std::ostringstream error_stream; return matrix_norm(a, ord, ax, keepdims, s);
error_stream << "Invalid axis values " << ax;
throw std::invalid_argument(error_stream.str());
} }
} // namespace mlx::core::linalg } // namespace mlx::core::linalg

View File

@ -2,27 +2,61 @@
#pragma once #pragma once
#include <optional>
#include "array.h" #include "array.h"
#include "device.h" #include "device.h"
#include "ops.h" #include "ops.h"
#include "stream.h" #include "stream.h"
namespace mlx::core::linalg { namespace mlx::core::linalg {
/*
* Compute vector or matrix norms.
*
* - If axis and ord are both unspecified, computes the 2-norm of flatten(x).
* - If axis is not provided but ord is, then x must be either 1D or 2D.
* - If axis is provided, but ord is not, then the 2-norm is computed along the
* given axes. At most 2 axes can be specified.
* - If both axis and ord are provided, then the corresponding matrix of vector
* norm is computed. At most 2 axes can be specified.
*/
array norm( array norm(
const array& a, const array& a,
const double ord, const double ord,
const std::vector<int>& axis = {}, const std::optional<std::vector<int>>& axis = std::nullopt,
bool keepdims = false, bool keepdims = false,
StreamOrDevice s = {}); StreamOrDevice s = {});
inline array norm(
const array& a,
const double ord,
int axis,
bool keepdims = false,
StreamOrDevice s = {}) {
return norm(a, ord, std::vector<int>{axis}, keepdims, s);
}
array norm( array norm(
const array& a, const array& a,
const std::string& ord, const std::string& ord,
const std::vector<int>& axis = {}, const std::optional<std::vector<int>>& axis = std::nullopt,
bool keepdims = false, bool keepdims = false,
StreamOrDevice s = {}); StreamOrDevice s = {});
inline array norm(
const array& a,
const std::string& ord,
int axis,
bool keepdims = false,
StreamOrDevice s = {}) {
return norm(a, ord, std::vector<int>{axis}, keepdims, s);
}
array norm( array norm(
const array& a, const array& a,
const std::vector<int>& axis = {}, const std::optional<std::vector<int>>& axis = std::nullopt,
bool keepdims = false, bool keepdims = false,
StreamOrDevice s = {}); StreamOrDevice s = {});
inline array
norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
return norm(a, std::vector<int>{axis}, keepdims, s);
}
} // namespace mlx::core::linalg } // namespace mlx::core::linalg

View File

@ -1,6 +1,5 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#include <numeric>
#include <sstream> #include <sstream>
#include <vector> #include <vector>
@ -74,12 +73,6 @@ int normalize_axis(int axis, int ndim) {
} }
return axis; return axis;
} }
std::vector<int> normalize_axes(const std::vector<int>& axes, int ndim) {
std::vector<int> canonical;
for (int ax : axes)
canonical.push_back(normalize_axis(ax, ndim));
return canonical;
}
std::ostream& operator<<(std::ostream& os, const Device& d) { std::ostream& operator<<(std::ostream& os, const Device& d) {
os << "Device("; os << "Device(";

View File

@ -2,7 +2,6 @@
#pragma once #pragma once
#include <numeric>
#include "array.h" #include "array.h"
#include "device.h" #include "device.h"
#include "dtype.h" #include "dtype.h"
@ -25,7 +24,6 @@ bool is_same_shape(const std::vector<array>& arrays);
* https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html * https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
*/ */
int normalize_axis(int axis, int ndim); int normalize_axis(int axis, int ndim);
std::vector<int> normalize_axes(const std::vector<int>& axes, int ndim);
std::ostream& operator<<(std::ostream& os, const Device& d); std::ostream& operator<<(std::ostream& os, const Device& d);
std::ostream& operator<<(std::ostream& os, const Stream& s); std::ostream& operator<<(std::ostream& os, const Stream& s);
@ -43,18 +41,4 @@ inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) { inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
return os << static_cast<float>(v); return os << static_cast<float>(v);
} }
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
std::vector<int> axes;
if (std::holds_alternative<std::monostate>(v)) {
axes.resize(dims);
std::iota(axes.begin(), axes.end(), 0);
} else if (auto pv = std::get_if<int>(&v); pv) {
axes.push_back(*pv);
} else {
axes = std::get<std::vector<int>>(v);
}
return axes;
}
} // namespace mlx::core } // namespace mlx::core

View File

@ -7,7 +7,6 @@
#include "mlx/fft.h" #include "mlx/fft.h"
#include "mlx/ops.h" #include "mlx/ops.h"
#include "mlx/utils.h"
namespace py = pybind11; namespace py = pybind11;
using namespace py::literals; using namespace py::literals;

View File

@ -26,193 +26,164 @@ using namespace mlx::core;
using namespace mlx::core::linalg; using namespace mlx::core::linalg;
void init_linalg(py::module_& parent_module) { void init_linalg(py::module_& parent_module) {
py::options options;
options.disable_function_signatures();
auto m = auto m =
parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra."); parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra.");
m.def( m.def(
"norm", "norm",
[](const array& a, [](const array& a,
const std::variant<std::monostate, int, double, std::string>& ord, const std::variant<std::monostate, int, double, std::string>& ord_,
const std::variant<std::monostate, int, std::vector<int>>& axis, const std::variant<std::monostate, int, std::vector<int>>& axis_,
const bool keepdims, const bool keepdims,
const StreamOrDevice stream) { const StreamOrDevice stream) {
return std::visit( std::optional<std::vector<int>> axis = std::nullopt;
overloaded{ if (auto pv = std::get_if<int>(&axis_); pv) {
[&](const double p) { axis = std::vector<int>{*pv};
if (std::isinf((float)p) || std::isinf(p)) { } else if (auto pv = std::get_if<std::vector<int>>(&axis_); pv) {
if (p > 0) { axis = *pv;
return norm( }
a,
"inf", if (std::holds_alternative<std::monostate>(ord_)) {
std::holds_alternative<std::monostate>(axis) return norm(a, axis, keepdims, stream);
? std::vector<int>() } else {
: get_reduce_axes(axis, a.ndim()), if (auto pv = std::get_if<std::string>(&ord_); pv) {
keepdims, return norm(a, *pv, axis, keepdims, stream);
stream); }
} double ord;
return norm( if (auto pv = std::get_if<int>(&ord_); pv) {
a, ord = *pv;
"-inf", } else {
get_reduce_axes(axis, a.ndim()), ord = std::get<double>(ord_);
keepdims, }
stream); return norm(a, ord, axis, keepdims, stream);
} }
return norm(
a,
p,
std::holds_alternative<std::monostate>(axis)
? std::vector<int>()
: get_reduce_axes(axis, a.ndim()),
keepdims,
stream);
},
[&](const std::string& p) {
return norm(
a,
p,
std::holds_alternative<std::monostate>(axis)
? std::vector<int>()
: get_reduce_axes(axis, a.ndim()),
keepdims,
stream);
},
[&](const std::monostate _) {
return norm(
a,
std::holds_alternative<std::monostate>(axis)
? std::vector<int>()
: get_reduce_axes(axis, a.ndim()),
keepdims,
stream);
}},
ord);
}, },
"a"_a, "a"_a,
py::pos_only(),
"ord"_a = none, "ord"_a = none,
"axis"_a = none, "axis"_a = none,
"keepdims"_a = false, "keepdims"_a = false,
py::kw_only(),
"stream"_a = none, "stream"_a = none,
R"pbdoc( R"pbdoc(
Matrix or vector norm. norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
This function is able to return matrix or vector norms, Matrix or vector norm.
depending on the value of the ``ord`` parameter.
Parameters This function computes vector or matrix norms depending on the value of
---------- the ``ord`` and ``axis`` parameters.
a : array_like
Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord`
is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned.
ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional
Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None.
axis : {None, int, 2-tuple of ints}, optional.
If `axis` is an integer, it specifies the axis of `a` along which to
compute the vector norms. If `axis` is a 2-tuple, it specifies the
axes that hold 2-D matrices, and the matrix norms of these matrices
are computed. If `axis` is None then either a vector norm (when `a`
is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default
is None.
keepdims : bool, optional
If this is set to True, the axes which are normed over are left in the
result as dimensions with size one. With this option the result will
broadcast correctly against the original `a`.
Returns Args:
------- a (array): Input array. If ``axis`` is ``None``, ``a`` must be 1-D or 2-D,
n : array unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the
Norm of the matrix or vector(s). 2-norm of ``a.flatten`` will be returned.
ord (scalar or str, optional): Order of the norm (see table under ``Notes``).
If ``None``, the 2-norm will be computed along the given ``axis``.
Default: ``None``.
axis (int or list(int), optional): If ``axis`` is an integer, it specifies the
axis of ``a`` along which to compute the vector norms. If ``axis`` is a
2-tuple, it specifies the axes that hold 2-D matrices, and the matrix
norms of these matrices are computed. If `axis` is ``None`` then
either a vector norm (when ``a`` is 1-D) or a matrix norm (when ``a`` is
2-D) is returned. Default: ``None``.
keepdims (bool, optional): If ``True``, the axes which are normed over are
left in the result as dimensions with size one. Default ``False``.
Notes Returns:
----- array: The output containing the norm(s).
For values of ``ord < 1``, the result is, strictly speaking, not a
mathematical 'norm', but it may still be useful for various numerical
purposes.
The following norms can be calculated: Notes:
For values of ``ord < 1``, the result is, strictly speaking, not a
mathematical norm, but it may still be useful for various numerical
purposes.
===== ============================ ========================== The following norms can be calculated:
ord norm for matrices norm for vectors
===== ============================ ==========================
None Frobenius norm 2-norm
'fro' Frobenius norm --
inf max(sum(abs(x), axis=1)) max(abs(x))
-inf min(sum(abs(x), axis=1)) min(abs(x))
0 -- sum(x != 0)
1 max(sum(abs(x), axis=0)) as below
-1 min(sum(abs(x), axis=0)) as below
2 2-norm (largest sing. value) as below
-2 smallest singular value as below
other -- sum(abs(x)**ord)**(1./ord)
===== ============================ ==========================
Nuclear norm and norms based on singular values are not yet implemented. ===== ============================ ==========================
ord norm for matrices norm for vectors
===== ============================ ==========================
None Frobenius norm 2-norm
'fro' Frobenius norm --
inf max(sum(abs(x), axis=1)) max(abs(x))
-inf min(sum(abs(x), axis=1)) min(abs(x))
0 -- sum(x != 0)
1 max(sum(abs(x), axis=0)) as below
-1 min(sum(abs(x), axis=0)) as below
2 2-norm (largest sing. value) as below
-2 smallest singular value as below
other -- sum(abs(x)**ord)**(1./ord)
===== ============================ ==========================
The Frobenius norm is given by [1]_: .. warning::
Nuclear norm and norms based on singular values are not yet implemented.
:math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}` The Frobenius norm is given by [1]_:
The nuclear norm is the sum of the singular values. :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
Both the Frobenius and nuclear norm orders are only defined for The nuclear norm is the sum of the singular values.
matrices and raise a ValueError when ``a.ndim != 2``.
References Both the Frobenius and nuclear norm orders are only defined for
---------- matrices and raise a ``ValueError`` when ``a.ndim != 2``.
.. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
Examples References:
-------- .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
>>> import mlx.core as mx Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
>>> from mlx.core import linalg as LA
>>> a = mx.arange(9) - 4 Examples:
>>> a >>> import mlx.core as mx
array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) >>> from mlx.core import linalg as la
>>> b = a.reshape((3,3)) >>> a = mx.arange(9) - 4
>>> b >>> a
array([[-4, -3, -2], array([-4, -3, -2, ..., 2, 3, 4], dtype=int32)
[-1, 0, 1], >>> b = a.reshape((3,3))
[ 2, 3, 4]], dtype=int32) >>> b
>>> LA.norm(a) array([[-4, -3, -2],
array(7.74597, dtype=float32) [-1, 0, 1],
>>> LA.norm(b) [ 2, 3, 4]], dtype=int32)
array(7.74597, dtype=float32) >>> la.norm(a)
>>> LA.norm(b, 'fro') array(7.74597, dtype=float32)
array(7.74597, dtype=float32) >>> la.norm(b)
>>> LA.norm(a, float("inf")) array(7.74597, dtype=float32)
array(4, dtype=int32) >>> la.norm(b, 'fro')
>>> LA.norm(b, float("inf")) array(7.74597, dtype=float32)
array(9, dtype=int32) >>> la.norm(a, float("inf"))
>>> LA.norm(a, -float("inf")) array(4, dtype=float32)
array(0, dtype=int32) >>> la.norm(b, float("inf"))
>>> LA.norm(b, -float("inf")) array(9, dtype=float32)
array(2, dtype=int32) >>> la.norm(a, -float("inf"))
>>> LA.norm(a, 1) array(0, dtype=float32)
array(20, dtype=int32) >>> la.norm(b, -float("inf"))
>>> LA.norm(b, 1) array(2, dtype=float32)
array(7, dtype=int32) >>> la.norm(a, 1)
>>> LA.norm(a, -1) array(20, dtype=float32)
array(0, dtype=float32) >>> la.norm(b, 1)
>>> LA.norm(b, -1) array(7, dtype=float32)
array(6, dtype=int32) >>> la.norm(a, -1)
>>> LA.norm(a, 2) array(0, dtype=float32)
array(7.74597, dtype=float32) >>> la.norm(b, -1)
>>> LA.norm(a, 3) array(6, dtype=float32)
array(5.84804, dtype=float32) >>> la.norm(a, 2)
>>> LA.norm(a, -3) array(7.74597, dtype=float32)
array(0, dtype=float32) >>> la.norm(a, 3)
>>> c = mx.array([[ 1, 2, 3], array(5.84804, dtype=float32)
... [-1, 1, 4]]) >>> la.norm(a, -3)
>>> LA.norm(c, axis=0) array(0, dtype=float32)
array([1.41421, 2.23607, 5], dtype=float32) >>> c = mx.array([[ 1, 2, 3],
>>> LA.norm(c, axis=1) ... [-1, 1, 4]])
array([3.74166, 4.24264], dtype=float32) >>> la.norm(c, axis=0)
>>> LA.norm(c, ord=1, axis=1) array([1.41421, 2.23607, 5], dtype=float32)
array([6, 6], dtype=int32) >>> la.norm(c, axis=1)
>>> m = mx.arange(8).reshape(2,2,2) array([3.74166, 4.24264], dtype=float32)
>>> LA.norm(m, axis=(1,2)) >>> la.norm(c, ord=1, axis=1)
array([3.74166, 11.225], dtype=float32) array([6, 6], dtype=float32)
>>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) >>> m = mx.arange(8).reshape(2,2,2)
(array(3.74166, dtype=float32), array(11.225, dtype=float32)) >>> la.norm(m, axis=(1,2))
)pbdoc"); array([3.74166, 11.225], dtype=float32)
>>> la.norm(m[0, :, :]), LA.norm(m[1, :, :])
(array(3.74166, dtype=float32), array(11.225, dtype=float32))
)pbdoc");
} }

View File

@ -1,6 +1,7 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#pragma once #pragma once
#include <numeric>
#include <variant> #include <variant>
#include <pybind11/complex.h> #include <pybind11/complex.h>
@ -13,10 +14,24 @@ namespace py = pybind11;
using namespace mlx::core; using namespace mlx::core;
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
using ScalarOrArray = std:: using ScalarOrArray = std::
variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>; variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>;
static constexpr std::monostate none{}; static constexpr std::monostate none{};
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
std::vector<int> axes;
if (std::holds_alternative<std::monostate>(v)) {
axes.resize(dims);
std::iota(axes.begin(), axes.end(), 0);
} else if (auto pv = std::get_if<int>(&v); pv) {
axes.push_back(*pv);
} else {
axes = std::get<std::vector<int>>(v);
}
return axes;
}
inline array to_array_with_accessor(py::object obj) { inline array to_array_with_accessor(py::object obj) {
if (py::hasattr(obj, "__mlx_array__")) { if (py::hasattr(obj, "__mlx_array__")) {
return obj.attr("__mlx_array__")().cast<array>(); return obj.attr("__mlx_array__")().cast<array>();

View File

@ -3,170 +3,155 @@
#include "doctest/doctest.h" #include "doctest/doctest.h"
#include <cmath> #include <cmath>
#include <iostream>
#include "mlx/linalg.h"
#include "mlx/mlx.h" #include "mlx/mlx.h"
using namespace mlx::core; using namespace mlx::core;
using namespace mlx::core::linalg; using namespace mlx::core::linalg;
TEST_CASE("[mlx.core.linalg.norm] no ord") { TEST_CASE("[mlx.core.linalg.norm] no ord") {
array arr_one_d({1, 2, 3}); // Zero dimensions
array arr_two_d = reshape(arange(9), {3, 3}); array x(2.0);
array arr_three_d = reshape(arange(18), {2, 3, 3}); CHECK_EQ(norm(x).item<float>(), 2.0f);
CHECK_THROWS(norm(x, 0));
CHECK(array_equal(norm(arr_one_d), array(sqrt(1 + 4 + 9))).item<bool>()); x = array({1, 2, 3});
CHECK(array_equal(norm(arr_one_d, {0}, false), array(sqrt(1 + 4 + 9))) float expected = std::sqrt(1 + 4 + 9);
.item<bool>()); CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
CHECK_EQ(norm(x, 0, false).item<float>(), doctest::Approx(expected));
CHECK_EQ(norm(x, -1, false).item<float>(), doctest::Approx(expected));
CHECK_EQ(norm(x, -1, true).ndim(), 1);
CHECK_THROWS(norm(x, 1));
x = reshape(arange(9), {3, 3});
expected =
std::sqrt(0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8);
CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
CHECK_EQ(
norm(x, std::vector<int>{0, 1}).item<float>(), doctest::Approx(expected));
CHECK(array_equal( CHECK(array_equal(
norm(arr_two_d, {}, false), norm(x, 0, false),
array(sqrt(
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8)))
.item<bool>());
CHECK(array_equal(
norm(arr_two_d, {0}, false),
array( array(
{sqrt(0 + 3 * 3 + 6 * 6), {std::sqrt(0 + 3 * 3 + 6 * 6),
sqrt(1 + 4 * 4 + 7 * 7), std::sqrt(1 + 4 * 4 + 7 * 7),
sqrt(2 * 2 + 5 * 5 + 8 * 8)})) std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
.item<bool>()); .item<bool>());
CHECK(array_equal( CHECK(allclose(
norm(arr_two_d, {1}, false), norm(x, 1, false),
array( array(
{sqrt(0 + 1 + 2 * 2), {std::sqrt(0 + 1 + 2 * 2),
sqrt(3 * 3 + 4 * 4 + 5 * 5), std::sqrt(3 * 3 + 4 * 4 + 5 * 5),
sqrt(6 * 6 + 7 * 7 + 8 * 8)})) std::sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
.item<bool>()); .item<bool>());
CHECK(array_equal(
norm(arr_two_d, {0, 1}, false), x = reshape(arange(18), {2, 3, 3});
array(sqrt( CHECK(allclose(
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8))) norm(x, 2, false),
.item<bool>());
CHECK(array_equal(
norm(arr_three_d, {2}, false),
array( array(
{ {
sqrt(0 + 1 + 2 * 2), std::sqrt(0 + 1 + 2 * 2),
sqrt(3 * 3 + 4 * 4 + 5 * 5), std::sqrt(3 * 3 + 4 * 4 + 5 * 5),
sqrt(6 * 6 + 7 * 7 + 8 * 8), std::sqrt(6 * 6 + 7 * 7 + 8 * 8),
sqrt(9 * 9 + 10 * 10 + 11 * 11), std::sqrt(9 * 9 + 10 * 10 + 11 * 11),
sqrt(12 * 12 + 13 * 13 + 14 * 14), std::sqrt(12 * 12 + 13 * 13 + 14 * 14),
sqrt(15 * 15 + 16 * 16 + 17 * 17), std::sqrt(15 * 15 + 16 * 16 + 17 * 17),
}, },
{2, 3})) {2, 3}))
.item<bool>()); .item<bool>());
CHECK(array_equal( CHECK(allclose(
norm(arr_three_d, {1}, false), norm(x, std::vector<int>{1, 2}, false),
array( array(
{ {std::sqrt(
sqrt(0 + 3 * 3 + 6 * 6),
sqrt(1 + 4 * 4 + 7 * 7),
sqrt(2 * 2 + 5 * 5 + 8 * 8),
sqrt(9 * 9 + 12 * 12 + 15 * 15),
sqrt(10 * 10 + 13 * 13 + 16 * 16),
sqrt(11 * 11 + 14 * 14 + 17 * 17),
},
{2, 3}))
.item<bool>());
CHECK(array_equal(
norm(arr_three_d, {0}, false),
array(
{
sqrt(0 + 9 * 9),
sqrt(1 + 10 * 10),
sqrt(2 * 2 + 11 * 11),
sqrt(3 * 3 + 12 * 12),
sqrt(4 * 4 + 13 * 13),
sqrt(5 * 5 + 14 * 14),
sqrt(6 * 6 + 15 * 15),
sqrt(7 * 7 + 16 * 16),
sqrt(8 * 8 + 17 * 17),
},
{3, 3}))
.item<bool>());
CHECK(array_equal(
norm(arr_three_d, {1, 2}, false),
array(
{sqrt(
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 +
8 * 8), 8 * 8),
sqrt( std::sqrt(
9 * 9 + 10 * 10 + 11 * 11 + 12 * 12 + 13 * 13 + 14 * 14 + 9 * 9 + 10 * 10 + 11 * 11 + 12 * 12 + 13 * 13 + 14 * 14 +
15 * 15 + 16 * 16 + 17 * 17)}, 15 * 15 + 16 * 16 + 17 * 17)},
{2})) {2}))
.item<bool>()); .item<bool>());
CHECK_THROWS(norm(x, std::vector<int>{0, 1, 2}));
} }
TEST_CASE("[mlx.core.linalg.norm] double ord") { TEST_CASE("[mlx.core.linalg.norm] double ord") {
array arr_one_d({1, 2, 3}); CHECK_THROWS(norm(array(0), 2.0));
array arr_two_d = reshape(arange(9), {3, 3});
array arr_three_d = reshape(arange(18), {2, 3, 3});
CHECK(array_equal(norm(arr_one_d, 2.0), array(sqrt(1 + 4 + 9))).item<bool>()); array x({1, 2, 3});
CHECK(array_equal(norm(arr_one_d, 1.0), array(1 + 2 + 3)).item<bool>());
CHECK(array_equal(norm(arr_one_d, 0.0), array(3)).item<bool>());
CHECK(array_equal(norm(arr_one_d, 2.0, {0}, false), array(sqrt(1 + 4 + 9))) float expected = std::sqrt(1 + 4 + 9);
.item<bool>()); CHECK_EQ(norm(x, 2.0).item<float>(), doctest::Approx(expected));
CHECK(array_equal( CHECK_EQ(norm(x, 2.0, 0).item<float>(), doctest::Approx(expected));
norm(arr_two_d, 2.0, {0}, false), CHECK_THROWS(norm(x, 2.0, 1));
expected = 1 + 2 + 3;
CHECK_EQ(norm(x, 1.0).item<float>(), doctest::Approx(expected));
expected = 3;
CHECK_EQ(norm(x, 0.0).item<float>(), doctest::Approx(expected));
expected = 3;
CHECK_EQ(
norm(x, std::numeric_limits<double>::infinity()).item<float>(),
doctest::Approx(expected));
expected = 1;
CHECK_EQ(
norm(x, -std::numeric_limits<double>::infinity()).item<float>(),
doctest::Approx(expected));
x = reshape(arange(9), {3, 3});
CHECK(allclose(
norm(x, 2.0, 0, false),
array( array(
{sqrt(0 + 3 * 3 + 6 * 6), {std::sqrt(0 + 3 * 3 + 6 * 6),
sqrt(1 + 4 * 4 + 7 * 7), std::sqrt(1 + 4 * 4 + 7 * 7),
sqrt(2 * 2 + 5 * 5 + 8 * 8)})) std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
.item<bool>()); .item<bool>());
CHECK(array_equal( CHECK(allclose(
norm(arr_two_d, 2.0, {1}, false), norm(x, 2.0, 1, false),
array( array(
{sqrt(0 + 1 + 2 * 2), {sqrt(0 + 1 + 2 * 2),
sqrt(3 * 3 + 4 * 4 + 5 * 5), sqrt(3 * 3 + 4 * 4 + 5 * 5),
sqrt(6 * 6 + 7 * 7 + 8 * 8)})) sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
.item<bool>()); .item<bool>());
CHECK(array_equal(
norm(arr_three_d, 2.0, {2}, false),
array(
{
sqrt(0 + 1 + 2 * 2),
sqrt(3 * 3 + 4 * 4 + 5 * 5),
sqrt(6 * 6 + 7 * 7 + 8 * 8),
sqrt(9 * 9 + 10 * 10 + 11 * 11),
sqrt(12 * 12 + 13 * 13 + 14 * 14),
sqrt(15 * 15 + 16 * 16 + 17 * 17),
},
{2, 3}))
.item<bool>());
CHECK(array_equal(
norm(arr_three_d, 2.0, {1}, false),
array(
{
sqrt(0 + 3 * 3 + 6 * 6),
sqrt(1 + 4 * 4 + 7 * 7),
sqrt(2 * 2 + 5 * 5 + 8 * 8),
sqrt(9 * 9 + 12 * 12 + 15 * 15),
sqrt(10 * 10 + 13 * 13 + 16 * 16),
sqrt(11 * 11 + 14 * 14 + 17 * 17),
},
{2, 3}))
.item<bool>());
CHECK(array_equal(
norm(arr_three_d, 2.0, {0}, false),
array(
{
sqrt(0 + 9 * 9),
sqrt(1 + 10 * 10),
sqrt(2 * 2 + 11 * 11),
sqrt(3 * 3 + 12 * 12),
sqrt(4 * 4 + 13 * 13),
sqrt(5 * 5 + 14 * 14),
sqrt(6 * 6 + 15 * 15),
sqrt(7 * 7 + 16 * 16),
sqrt(8 * 8 + 17 * 17),
},
{3, 3}))
.item<bool>());
CHECK_EQ(
norm(x, 1.0, std::vector<int>{0, 1}).item<float>(),
doctest::Approx(15.0));
CHECK_EQ(
norm(x, 1.0, std::vector<int>{1, 0}).item<float>(),
doctest::Approx(21.0));
CHECK_EQ(
norm(x, -1.0, std::vector<int>{0, 1}).item<float>(),
doctest::Approx(9.0));
CHECK_EQ(
norm(x, -1.0, std::vector<int>{1, 0}).item<float>(),
doctest::Approx(3.0));
CHECK_EQ(
norm(x, 1.0, std::vector<int>{0, 1}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, 1.0, std::vector<int>{1, 0}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, -1.0, std::vector<int>{0, 1}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, -1.0, std::vector<int>{1, 0}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, -1.0, std::vector<int>{-2, -1}, false).item<float>(),
doctest::Approx(9.0));
CHECK_EQ(
norm(x, 1.0, std::vector<int>{-2, -1}, false).item<float>(),
doctest::Approx(15.0));
x = reshape(arange(18), {2, 3, 3});
CHECK_THROWS(norm(x, 2.0, std::vector{0, 1, 2}));
CHECK(allclose( CHECK(allclose(
norm(arr_three_d, 3.0, {0}), norm(x, 3.0, 0),
array( array(
{9., {9.,
10.00333222, 10.00333222,
@ -179,15 +164,8 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
17.57113899}, 17.57113899},
{3, 3})) {3, 3}))
.item<bool>()); .item<bool>());
CHECK(
allclose(
norm(arr_three_d, 3.0, {1}),
array(
{6.24025147, 7.41685954, 8.6401226, 18., 19.39257164, 20.7915893},
{2, 3}))
.item<bool>());
CHECK(allclose( CHECK(allclose(
norm(arr_three_d, 3.0, {2}), norm(x, 3.0, 2),
array( array(
{2.08008382, {2.08008382,
6., 6.,
@ -197,110 +175,76 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
23.13593104}, 23.13593104},
{2, 3})) {2, 3}))
.item<bool>()); .item<bool>());
CHECK(allclose( CHECK(
norm(arr_three_d, 0.0, {0}), allclose(
array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3})) norm(x, 0.0, 0), array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3}))
.item<bool>());
CHECK(allclose(norm(x, 0.0, 1), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
.item<bool>());
CHECK(allclose(norm(x, 0.0, 2), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
.item<bool>()); .item<bool>());
CHECK(
allclose(
norm(arr_three_d, 0.0, {1}), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
.item<bool>());
CHECK(
allclose(
norm(arr_three_d, 0.0, {2}), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
.item<bool>());
CHECK(allclose( CHECK(allclose(
norm(arr_three_d, 1.0, {0}), norm(x, 1.0, 0),
array({9., 11., 13., 15., 17., 19., 21., 23., 25.}, {3, 3})) array({9., 11., 13., 15., 17., 19., 21., 23., 25.}, {3, 3}))
.item<bool>()); .item<bool>());
CHECK(allclose( CHECK(allclose(norm(x, 1.0, 1), array({9., 12., 15., 36., 39., 42.}, {2, 3}))
norm(arr_three_d, 1.0, {1}),
array({9., 12., 15., 36., 39., 42.}, {2, 3}))
.item<bool>()); .item<bool>());
CHECK(allclose( CHECK(allclose(norm(x, 1.0, 2), array({3., 12., 21., 30., 39., 48.}, {2, 3}))
norm(arr_three_d, 1.0, {2}),
array({3., 12., 21., 30., 39., 48.}, {2, 3}))
.item<bool>()); .item<bool>());
CHECK(allclose(norm(arr_two_d, 1.0, {0, 1}), array({15.0})).item<bool>()); CHECK(allclose(norm(x, 1.0, std::vector<int>{0, 1}), array({21., 23., 25.}))
CHECK(allclose(norm(arr_two_d, 1.0, {1, 0}), array({21.0})).item<bool>());
CHECK(allclose(norm(arr_two_d, -1.0, {0, 1}), array({9.0})).item<bool>());
CHECK(allclose(norm(arr_two_d, -1.0, {1, 0}), array({3.0})).item<bool>());
CHECK(allclose(norm(arr_two_d, 1.0, {0, 1}, true), array({15.0}, {1, 1}))
.item<bool>()); .item<bool>());
CHECK(allclose(norm(arr_two_d, 1.0, {1, 0}, true), array({21.0}, {1, 1})) CHECK(allclose(norm(x, 1.0, std::vector<int>{1, 2}), array({15., 42.}))
.item<bool>()); .item<bool>());
CHECK(allclose(norm(arr_two_d, -1.0, {0, 1}, true), array({9.0}, {1, 1})) CHECK(allclose(norm(x, -1.0, std::vector<int>{0, 1}), array({9., 11., 13.}))
.item<bool>()); .item<bool>());
CHECK(allclose(norm(arr_two_d, -1.0, {1, 0}, true), array({3.0}, {1, 1})) CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9., 36.}))
.item<bool>()); .item<bool>());
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 0}), array({9., 12., 15.}))
CHECK(array_equal(norm(arr_two_d, -1.0, {-2, -1}, false), array(9.0))
.item<bool>()); .item<bool>());
CHECK(array_equal(norm(arr_two_d, 1.0, {-2, -1}, false), array(15.0)) CHECK(allclose(norm(x, -1.0, std::vector<int>{2, 1}), array({3, 30}))
.item<bool>()); .item<bool>());
// CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9, 36}))
CHECK(allclose(norm(arr_three_d, 1.0, {0, 1}), array({21., 23., 25.}))
.item<bool>()); .item<bool>());
CHECK(
allclose(norm(arr_three_d, 1.0, {1, 2}), array({15., 42.})).item<bool>());
CHECK(allclose(norm(arr_three_d, -1.0, {0, 1}), array({9., 11., 13.}))
.item<bool>());
CHECK(
allclose(norm(arr_three_d, -1.0, {1, 2}), array({9., 36.})).item<bool>());
CHECK(allclose(norm(arr_three_d, -1.0, {1, 0}), array({9., 12., 15.}))
.item<bool>());
CHECK(allclose(norm(arr_three_d, -1.0, {2, 1}), array({3, 30})).item<bool>());
CHECK(allclose(norm(arr_three_d, -1.0, {1, 2}), array({9, 36})).item<bool>());
} }
TEST_CASE("[mlx.core.linalg.norm] string ord") { TEST_CASE("[mlx.core.linalg.norm] string ord") {
array arr_one_d({1, 2, 3}); array x({1, 2, 3});
array arr_two_d = reshape(arange(9), {3, 3}); CHECK_THROWS(norm(x, "fro"));
array arr_three_d = reshape(arange(18), {2, 3, 3});
CHECK(allclose(norm(arr_one_d, "inf", {}), array({3.0})).item<bool>()); x = reshape(arange(9), {3, 3});
CHECK(allclose(norm(arr_one_d, "-inf", {}), array({1.0})).item<bool>()); CHECK_THROWS(norm(x, "bad ord"));
CHECK(allclose(norm(arr_two_d, "f", {0, 1}), array({14.2828568570857})) CHECK_EQ(
.item<bool>()); norm(x, "f", std::vector<int>{0, 1}).item<float>(),
CHECK(allclose(norm(arr_two_d, "fro", {0, 1}), array({14.2828568570857})) doctest::Approx(14.2828568570857));
.item<bool>()); CHECK_EQ(
CHECK(allclose(norm(arr_two_d, "inf", {0, 1}), array({21.0})).item<bool>()); norm(x, "fro", std::vector<int>{0, 1}).item<float>(),
CHECK(allclose(norm(arr_two_d, "-inf", {0, 1}), array({3.0})).item<bool>()); doctest::Approx(14.2828568570857));
x = reshape(arange(18), {2, 3, 3});
CHECK(allclose( CHECK(allclose(
norm(arr_three_d, "fro", {0, 1}), norm(x, "fro", std::vector<int>{0, 1}),
array({22.24859546, 24.31049156, 26.43860813})) array({22.24859546, 24.31049156, 26.43860813}))
.item<bool>()); .item<bool>());
CHECK(allclose( CHECK(allclose(
norm(arr_three_d, "fro", {1, 2}), array({14.28285686, 39.7617907})) norm(x, "fro", std::vector<int>{1, 2}),
array({14.28285686, 39.7617907}))
.item<bool>()); .item<bool>());
CHECK(allclose( CHECK(allclose(
norm(arr_three_d, "f", {0, 1}), norm(x, "f", std::vector<int>{0, 1}),
array({22.24859546, 24.31049156, 26.43860813})) array({22.24859546, 24.31049156, 26.43860813}))
.item<bool>()); .item<bool>());
CHECK(allclose( CHECK(allclose(
norm(arr_three_d, "f", {1, 0}), norm(x, "f", std::vector<int>{1, 0}),
array({22.24859546, 24.31049156, 26.43860813})) array({22.24859546, 24.31049156, 26.43860813}))
.item<bool>()); .item<bool>());
CHECK( CHECK(allclose(
allclose(norm(arr_three_d, "f", {1, 2}), array({14.28285686, 39.7617907})) norm(x, "f", std::vector<int>{1, 2}),
.item<bool>()); array({14.28285686, 39.7617907}))
CHECK(
allclose(norm(arr_three_d, "f", {2, 1}), array({14.28285686, 39.7617907}))
.item<bool>());
CHECK(allclose(norm(arr_three_d, "inf", {0, 1}), array({36., 39., 42.}))
.item<bool>()); .item<bool>());
CHECK(allclose(norm(arr_three_d, "inf", {1, 2}), array({21., 48.})) CHECK(allclose(
.item<bool>()); norm(x, "f", std::vector<int>{2, 1}),
CHECK(allclose(norm(arr_three_d, "-inf", {0, 1}), array({9., 12., 15.})) array({14.28285686, 39.7617907}))
.item<bool>());
CHECK(allclose(norm(arr_three_d, "-inf", {1, 2}), array({3., 30.}))
.item<bool>());
CHECK(allclose(norm(arr_three_d, "-inf", {1, 0}), array({9., 11., 13.}))
.item<bool>());
CHECK(allclose(norm(arr_three_d, "-inf", {2, 1}), array({9., 36.}))
.item<bool>()); .item<bool>());
} }