linalg.norm (#187)

* implemented vector_norm in cpp

added linalg to mlx

* implemented vector_norm python binding

* renamed vector_norm to norm, implemented norm without provided ord

* completed the implementation of the norm

* added tests

* removed unused import in linalg.cpp

* updated python bindings

* added some tests for python bindings

* handling inf, -inf as numpy does, more extensive tests of compatibility with numpy

* added better docs and examples

* refactored mlx.linalg.norm bindings

* reused existing util for implementation of linalg.norm

* more tests

* fixed a bug with no ord and axis provided

* removed unused imports

* some style and API consistency updates to linalg norm

* remove unused includes

* fix python tests

* fixed a bug with frobenius norm of a complex-valued matrix

* complex for vector too

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Gabrijel Boduljak 2023-12-27 04:42:04 +01:00 committed by GitHub
parent 447bc089b9
commit 6b0d30bb85
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 780 additions and 0 deletions

View File

@ -57,6 +57,7 @@ are the CPU and GPU.
python/random python/random
python/transforms python/transforms
python/fft python/fft
python/linalg
python/nn python/nn
python/optimizers python/optimizers
python/tree_utils python/tree_utils

View File

@ -0,0 +1,11 @@
.. _linalg:
Linear Algebra
==============
.. currentmodule:: mlx.core.linalg
.. autosummary::
:toctree: _autosummary
norm

View File

@ -14,6 +14,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
) )

175
mlx/linalg.cpp Normal file
View File

@ -0,0 +1,175 @@
// Copyright © 2023 Apple Inc.
#include <numeric>
#include <ostream>
#include <vector>
#include "mlx/dtype.h"
#include "mlx/linalg.h"
namespace mlx::core::linalg {
Dtype at_least_float(const Dtype& d) {
return is_floating_point(d) ? d : promote_types(d, float32);
}
inline array l2_norm(
const array& a,
const std::vector<int>& axis,
bool keepdims,
StreamOrDevice s) {
if (is_complex(a.dtype())) {
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
} else {
return sqrt(sum(square(a, s), axis, keepdims, s), s);
}
}
inline array vector_norm(
const array& a,
const double ord,
const std::vector<int>& axis,
bool keepdims,
StreamOrDevice s) {
auto dtype = at_least_float(a.dtype());
if (ord == 0.0) {
return astype(sum(not_equal(a, array(0), s), axis, keepdims, s), dtype, s);
} else if (ord == 1.0) {
return astype(sum(abs(a, s), axis, keepdims, s), dtype, s);
} else if (ord == 2.0) {
return l2_norm(a, axis, keepdims, s);
} else if (ord == std::numeric_limits<double>::infinity()) {
return astype(max(abs(a, s), axis, keepdims, s), dtype, s);
} else if (ord == -std::numeric_limits<double>::infinity()) {
return astype(min(abs(a, s), axis, keepdims, s), dtype, s);
} else {
return power(
sum(power(abs(a, s), array(ord, dtype), s), axis, keepdims, s),
array(1.0 / ord, dtype),
s);
}
}
inline array matrix_norm(
const array& a,
const double ord,
const std::vector<int>& axis,
bool keepdims,
StreamOrDevice s) {
auto dtype = at_least_float(a.dtype());
auto row_axis = axis[0];
auto col_axis = axis[1];
if (ord == -1.0) {
col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0);
return astype(
min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
dtype,
s);
} else if (ord == 1.0) {
col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0);
return astype(
max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
dtype,
s);
} else if (ord == std::numeric_limits<double>::infinity()) {
row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0);
return astype(
max(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s),
dtype,
s);
} else if (ord == -std::numeric_limits<double>::infinity()) {
row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0);
return astype(
min(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s),
dtype,
s);
} else if (ord == 2.0 || ord == -2.0) {
throw std::runtime_error(
"[linalg::norm] Singular value norms are not implemented.");
} else {
std::ostringstream msg;
msg << "[linalg::norm] Invalid ord " << ord << " for matrix norm.";
throw std::invalid_argument(msg.str());
}
}
inline array matrix_norm(
const array& a,
const std::string& ord,
const std::vector<int>& axis,
bool keepdims,
StreamOrDevice s) {
if (ord == "f" || ord == "fro") {
return l2_norm(a, axis, keepdims, s);
} else if (ord == "nuc") {
throw std::runtime_error(
"[linalg::norm] Nuclear norm not yet implemented.");
} else {
std::ostringstream msg;
msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm.";
throw std::invalid_argument(msg.str());
}
}
array norm(
const array& a,
const std::optional<std::vector<int>>& axis /* = std::nullopt */,
bool keepdims /* = false */,
StreamOrDevice s /* = {} */) {
if (!axis) {
return norm(flatten(a, s), std::vector<int>{0}, keepdims, s);
}
if (axis.value().size() > 2) {
throw std::invalid_argument(
"[linalg::norm] Received too many axes for norm.");
}
return l2_norm(a, axis.value(), keepdims, s);
}
array norm(
const array& a,
const double ord,
const std::optional<std::vector<int>>& axis /* = std::nullopt */,
bool keepdims /* = false */,
StreamOrDevice s /* = {} */) {
std::vector<int> ax;
if (!axis) {
ax.resize(a.ndim());
std::iota(ax.begin(), ax.end(), 0);
} else {
ax = axis.value();
}
if (ax.size() == 1) {
return vector_norm(a, ord, ax, keepdims, s);
} else if (ax.size() == 2) {
return matrix_norm(a, ord, ax, keepdims, s);
} else {
throw std::invalid_argument(
"[linalg::norm] Received too many axes for norm.");
}
}
array norm(
const array& a,
const std::string& ord,
const std::optional<std::vector<int>>& axis /* = std::nullopt */,
bool keepdims /* = false */,
StreamOrDevice s /* = {} */) {
std::vector<int> ax;
if (!axis) {
ax.resize(a.ndim());
std::iota(ax.begin(), ax.end(), 0);
} else {
ax = axis.value();
}
if (ax.size() != 2) {
std::ostringstream msg;
msg << "[linalg::norm] Norm '" << ord << "' only supported for matrices,"
<< " but received " << ax.size() << " axis/axes.";
throw std::invalid_argument(msg.str());
}
return matrix_norm(a, ord, ax, keepdims, s);
}
} // namespace mlx::core::linalg

63
mlx/linalg.h Normal file
View File

@ -0,0 +1,63 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <optional>
#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/ops.h"
#include "mlx/stream.h"
namespace mlx::core::linalg {
/**
* Compute vector or matrix norms.
*
* - If axis and ord are both unspecified, computes the 2-norm of flatten(x).
* - If axis is not provided but ord is, then x must be either 1D or 2D.
* - If axis is provided, but ord is not, then the 2-norm (or Frobenius norm
* for matrices) is computed along the given axes. At most 2 axes can be
* specified.
* - If both axis and ord are provided, then the corresponding matrix or vector
* norm is computed. At most 2 axes can be specified.
*/
array norm(
const array& a,
const double ord,
const std::optional<std::vector<int>>& axis = std::nullopt,
bool keepdims = false,
StreamOrDevice s = {});
inline array norm(
const array& a,
const double ord,
int axis,
bool keepdims = false,
StreamOrDevice s = {}) {
return norm(a, ord, std::vector<int>{axis}, keepdims, s);
}
array norm(
const array& a,
const std::string& ord,
const std::optional<std::vector<int>>& axis = std::nullopt,
bool keepdims = false,
StreamOrDevice s = {});
inline array norm(
const array& a,
const std::string& ord,
int axis,
bool keepdims = false,
StreamOrDevice s = {}) {
return norm(a, ord, std::vector<int>{axis}, keepdims, s);
}
array norm(
const array& a,
const std::optional<std::vector<int>>& axis = std::nullopt,
bool keepdims = false,
StreamOrDevice s = {});
inline array
norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
return norm(a, std::vector<int>{axis}, keepdims, s);
}
} // namespace mlx::core::linalg

View File

@ -6,6 +6,7 @@
#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal.h"
#include "mlx/device.h" #include "mlx/device.h"
#include "mlx/fft.h" #include "mlx/fft.h"
#include "mlx/linalg.h"
#include "mlx/ops.h" #include "mlx/ops.h"
#include "mlx/random.h" #include "mlx/random.h"
#include "mlx/stream.h" #include "mlx/stream.h"

View File

@ -11,6 +11,7 @@ pybind11_add_module(
${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp ${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
) )
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY) if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)

180
python/src/linalg.cpp Normal file
View File

@ -0,0 +1,180 @@
// Copyright © 2023 Apple Inc.
#include <variant>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "mlx/linalg.h"
#include "python/src/load.h"
#include "python/src/utils.h"
namespace py = pybind11;
using namespace py::literals;
using namespace mlx::core;
using namespace mlx::core::linalg;
void init_linalg(py::module_& parent_module) {
py::options options;
options.disable_function_signatures();
auto m = parent_module.def_submodule(
"linalg", "mlx.core.linalg: linear algebra routines.");
m.def(
"norm",
[](const array& a,
const std::variant<std::monostate, int, double, std::string>& ord_,
const std::variant<std::monostate, int, std::vector<int>>& axis_,
const bool keepdims,
const StreamOrDevice stream) {
std::optional<std::vector<int>> axis = std::nullopt;
if (auto pv = std::get_if<int>(&axis_); pv) {
axis = std::vector<int>{*pv};
} else if (auto pv = std::get_if<std::vector<int>>(&axis_); pv) {
axis = *pv;
}
if (std::holds_alternative<std::monostate>(ord_)) {
return norm(a, axis, keepdims, stream);
} else {
if (auto pv = std::get_if<std::string>(&ord_); pv) {
return norm(a, *pv, axis, keepdims, stream);
}
double ord;
if (auto pv = std::get_if<int>(&ord_); pv) {
ord = *pv;
} else {
ord = std::get<double>(ord_);
}
return norm(a, ord, axis, keepdims, stream);
}
},
"a"_a,
py::pos_only(),
"ord"_a = none,
"axis"_a = none,
"keepdims"_a = false,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
Matrix or vector norm.
This function computes vector or matrix norms depending on the value of
the ``ord`` and ``axis`` parameters.
Args:
a (array): Input array. If ``axis`` is ``None``, ``a`` must be 1-D or 2-D,
unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the
2-norm of ``a.flatten`` will be returned.
ord (scalar or str, optional): Order of the norm (see table under ``Notes``).
If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed
along the given ``axis``. Default: ``None``.
axis (int or list(int), optional): If ``axis`` is an integer, it specifies the
axis of ``a`` along which to compute the vector norms. If ``axis`` is a
2-tuple, it specifies the axes that hold 2-D matrices, and the matrix
norms of these matrices are computed. If `axis` is ``None`` then
either a vector norm (when ``a`` is 1-D) or a matrix norm (when ``a`` is
2-D) is returned. Default: ``None``.
keepdims (bool, optional): If ``True``, the axes which are normed over are
left in the result as dimensions with size one. Default ``False``.
Returns:
array: The output containing the norm(s).
Notes:
For values of ``ord < 1``, the result is, strictly speaking, not a
mathematical norm, but it may still be useful for various numerical
purposes.
The following norms can be calculated:
===== ============================ ==========================
ord norm for matrices norm for vectors
===== ============================ ==========================
None Frobenius norm 2-norm
'fro' Frobenius norm --
inf max(sum(abs(x), axis=1)) max(abs(x))
-inf min(sum(abs(x), axis=1)) min(abs(x))
0 -- sum(x != 0)
1 max(sum(abs(x), axis=0)) as below
-1 min(sum(abs(x), axis=0)) as below
2 2-norm (largest sing. value) as below
-2 smallest singular value as below
other -- sum(abs(x)**ord)**(1./ord)
===== ============================ ==========================
.. warning::
Nuclear norm and norms based on singular values are not yet implemented.
The Frobenius norm is given by [1]_:
:math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
The nuclear norm is the sum of the singular values.
Both the Frobenius and nuclear norm orders are only defined for
matrices and raise a ``ValueError`` when ``a.ndim != 2``.
References:
.. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
Examples:
>>> import mlx.core as mx
>>> from mlx.core import linalg as la
>>> a = mx.arange(9) - 4
>>> a
array([-4, -3, -2, ..., 2, 3, 4], dtype=int32)
>>> b = a.reshape((3,3))
>>> b
array([[-4, -3, -2],
[-1, 0, 1],
[ 2, 3, 4]], dtype=int32)
>>> la.norm(a)
array(7.74597, dtype=float32)
>>> la.norm(b)
array(7.74597, dtype=float32)
>>> la.norm(b, 'fro')
array(7.74597, dtype=float32)
>>> la.norm(a, float("inf"))
array(4, dtype=float32)
>>> la.norm(b, float("inf"))
array(9, dtype=float32)
>>> la.norm(a, -float("inf"))
array(0, dtype=float32)
>>> la.norm(b, -float("inf"))
array(2, dtype=float32)
>>> la.norm(a, 1)
array(20, dtype=float32)
>>> la.norm(b, 1)
array(7, dtype=float32)
>>> la.norm(a, -1)
array(0, dtype=float32)
>>> la.norm(b, -1)
array(6, dtype=float32)
>>> la.norm(a, 2)
array(7.74597, dtype=float32)
>>> la.norm(a, 3)
array(5.84804, dtype=float32)
>>> la.norm(a, -3)
array(0, dtype=float32)
>>> c = mx.array([[ 1, 2, 3],
... [-1, 1, 4]])
>>> la.norm(c, axis=0)
array([1.41421, 2.23607, 5], dtype=float32)
>>> la.norm(c, axis=1)
array([3.74166, 4.24264], dtype=float32)
>>> la.norm(c, ord=1, axis=1)
array([6, 6], dtype=float32)
>>> m = mx.arange(8).reshape(2,2,2)
>>> la.norm(m, axis=(1,2))
array([3.74166, 11.225], dtype=float32)
>>> la.norm(m[0, :, :]), LA.norm(m[1, :, :])
(array(3.74166, dtype=float32), array(11.225, dtype=float32))
)pbdoc");
}

View File

@ -15,6 +15,7 @@ void init_ops(py::module_&);
void init_transforms(py::module_&); void init_transforms(py::module_&);
void init_random(py::module_&); void init_random(py::module_&);
void init_fft(py::module_&); void init_fft(py::module_&);
void init_linalg(py::module_&);
PYBIND11_MODULE(core, m) { PYBIND11_MODULE(core, m) {
m.doc() = "mlx: A framework for machine learning on Apple silicon."; m.doc() = "mlx: A framework for machine learning on Apple silicon.";
@ -29,5 +30,6 @@ PYBIND11_MODULE(core, m) {
init_transforms(m); init_transforms(m);
init_random(m); init_random(m);
init_fft(m); init_fft(m);
init_linalg(m);
m.attr("__version__") = TOSTRING(_VERSION_); m.attr("__version__") = TOSTRING(_VERSION_);
} }

View File

@ -0,0 +1,94 @@
# Copyright © 2023 Apple Inc.
import itertools
import math
import unittest
import mlx.core as mx
import mlx_tests
import numpy as np
class TestLinalg(mlx_tests.MLXTestCase):
def test_norm(self):
vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")]
matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]
for shape in [(3,), (2, 3), (2, 3, 3)]:
x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
# Test when at least one axis is provided
for num_axes in range(1, len(shape)):
if num_axes == 1:
ords = vector_ords
else:
ords = matrix_ords
for axis in itertools.combinations(range(len(shape)), num_axes):
for keepdims in [True, False]:
for o in ords:
out_np = np.linalg.norm(
x_np, ord=o, axis=axis, keepdims=keepdims
)
out_mx = mx.linalg.norm(
x_mx, ord=o, axis=axis, keepdims=keepdims
)
with self.subTest(
shape=shape, ord=o, axis=axis, keepdims=keepdims
):
self.assertTrue(
np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
)
# Test only ord provided
for shape in [(3,), (2, 3)]:
x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
for o in [None, 1, -1, float("inf"), -float("inf")]:
for keepdims in [True, False]:
out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims)
out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims)
with self.subTest(shape=shape, ord=o, keepdims=keepdims):
self.assertTrue(
np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
)
# Test no ord and no axis provided
for shape in [(3,), (2, 3), (2, 3, 3)]:
x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
for keepdims in [True, False]:
out_np = np.linalg.norm(x_np, keepdims=keepdims)
out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)
with self.subTest(shape=shape, keepdims=keepdims):
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
def test_complex_norm(self):
for shape in [(3,), (2, 3), (2, 3, 3)]:
x_np = np.random.uniform(size=shape).astype(
np.float32
) + 1j * np.random.uniform(size=shape).astype(np.float32)
x_mx = mx.array(x_np)
out_np = np.linalg.norm(x_np)
out_mx = mx.linalg.norm(x_mx)
with self.subTest(shape=shape):
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
for num_axes in range(1, len(shape)):
for axis in itertools.combinations(range(len(shape)), num_axes):
out_np = np.linalg.norm(x_np, axis=axis)
out_mx = mx.linalg.norm(x_mx, axis=axis)
with self.subTest(shape=shape, axis=axis):
self.assertTrue(
np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
)
x_np = np.random.uniform(size=(4, 4)).astype(
np.float32
) + 1j * np.random.uniform(size=(4, 4)).astype(np.float32)
x_mx = mx.array(x_np)
out_np = np.linalg.norm(x_np, ord="fro")
out_mx = mx.linalg.norm(x_mx, ord="fro")
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
if __name__ == "__main__":
unittest.main()

View File

@ -31,6 +31,7 @@ target_sources(tests PRIVATE
scheduler_tests.cpp scheduler_tests.cpp
utils_tests.cpp utils_tests.cpp
vmap_tests.cpp vmap_tests.cpp
linalg_tests.cpp
${METAL_TEST_SOURCES} ${METAL_TEST_SOURCES}
) )

250
tests/linalg_tests.cpp Normal file
View File

@ -0,0 +1,250 @@
// Copyright © 2023 Apple Inc.
#include "doctest/doctest.h"
#include <cmath>
#include "mlx/mlx.h"
using namespace mlx::core;
using namespace mlx::core::linalg;
TEST_CASE("[mlx.core.linalg.norm] no ord") {
// Zero dimensions
array x(2.0);
CHECK_EQ(norm(x).item<float>(), 2.0f);
CHECK_THROWS(norm(x, 0));
x = array({1, 2, 3});
float expected = std::sqrt(1 + 4 + 9);
CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
CHECK_EQ(norm(x, 0, false).item<float>(), doctest::Approx(expected));
CHECK_EQ(norm(x, -1, false).item<float>(), doctest::Approx(expected));
CHECK_EQ(norm(x, -1, true).ndim(), 1);
CHECK_THROWS(norm(x, 1));
x = reshape(arange(9), {3, 3});
expected =
std::sqrt(0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8);
CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
CHECK_EQ(
norm(x, std::vector<int>{0, 1}).item<float>(), doctest::Approx(expected));
CHECK(array_equal(
norm(x, 0, false),
array(
{std::sqrt(0 + 3 * 3 + 6 * 6),
std::sqrt(1 + 4 * 4 + 7 * 7),
std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
.item<bool>());
CHECK(allclose(
norm(x, 1, false),
array(
{std::sqrt(0 + 1 + 2 * 2),
std::sqrt(3 * 3 + 4 * 4 + 5 * 5),
std::sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
.item<bool>());
x = reshape(arange(18), {2, 3, 3});
CHECK(allclose(
norm(x, 2, false),
array(
{
std::sqrt(0 + 1 + 2 * 2),
std::sqrt(3 * 3 + 4 * 4 + 5 * 5),
std::sqrt(6 * 6 + 7 * 7 + 8 * 8),
std::sqrt(9 * 9 + 10 * 10 + 11 * 11),
std::sqrt(12 * 12 + 13 * 13 + 14 * 14),
std::sqrt(15 * 15 + 16 * 16 + 17 * 17),
},
{2, 3}))
.item<bool>());
CHECK(allclose(
norm(x, std::vector<int>{1, 2}, false),
array(
{std::sqrt(
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 +
8 * 8),
std::sqrt(
9 * 9 + 10 * 10 + 11 * 11 + 12 * 12 + 13 * 13 + 14 * 14 +
15 * 15 + 16 * 16 + 17 * 17)},
{2}))
.item<bool>());
CHECK_THROWS(norm(x, std::vector<int>{0, 1, 2}));
}
TEST_CASE("[mlx.core.linalg.norm] double ord") {
CHECK_THROWS(norm(array(0), 2.0));
array x({1, 2, 3});
float expected = std::sqrt(1 + 4 + 9);
CHECK_EQ(norm(x, 2.0).item<float>(), doctest::Approx(expected));
CHECK_EQ(norm(x, 2.0, 0).item<float>(), doctest::Approx(expected));
CHECK_THROWS(norm(x, 2.0, 1));
expected = 1 + 2 + 3;
CHECK_EQ(norm(x, 1.0).item<float>(), doctest::Approx(expected));
expected = 3;
CHECK_EQ(norm(x, 0.0).item<float>(), doctest::Approx(expected));
expected = 3;
CHECK_EQ(
norm(x, std::numeric_limits<double>::infinity()).item<float>(),
doctest::Approx(expected));
expected = 1;
CHECK_EQ(
norm(x, -std::numeric_limits<double>::infinity()).item<float>(),
doctest::Approx(expected));
x = reshape(arange(9), {3, 3});
CHECK(allclose(
norm(x, 2.0, 0, false),
array(
{std::sqrt(0 + 3 * 3 + 6 * 6),
std::sqrt(1 + 4 * 4 + 7 * 7),
std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
.item<bool>());
CHECK(allclose(
norm(x, 2.0, 1, false),
array(
{sqrt(0 + 1 + 2 * 2),
sqrt(3 * 3 + 4 * 4 + 5 * 5),
sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
.item<bool>());
CHECK_EQ(
norm(x, 1.0, std::vector<int>{0, 1}).item<float>(),
doctest::Approx(15.0));
CHECK_EQ(
norm(x, 1.0, std::vector<int>{1, 0}).item<float>(),
doctest::Approx(21.0));
CHECK_EQ(
norm(x, -1.0, std::vector<int>{0, 1}).item<float>(),
doctest::Approx(9.0));
CHECK_EQ(
norm(x, -1.0, std::vector<int>{1, 0}).item<float>(),
doctest::Approx(3.0));
CHECK_EQ(
norm(x, 1.0, std::vector<int>{0, 1}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, 1.0, std::vector<int>{1, 0}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, -1.0, std::vector<int>{0, 1}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, -1.0, std::vector<int>{1, 0}, true).shape(),
std::vector<int>{1, 1});
CHECK_EQ(
norm(x, -1.0, std::vector<int>{-2, -1}, false).item<float>(),
doctest::Approx(9.0));
CHECK_EQ(
norm(x, 1.0, std::vector<int>{-2, -1}, false).item<float>(),
doctest::Approx(15.0));
x = reshape(arange(18), {2, 3, 3});
CHECK_THROWS(norm(x, 2.0, std::vector{0, 1, 2}));
CHECK(allclose(
norm(x, 3.0, 0),
array(
{9.,
10.00333222,
11.02199456,
12.06217728,
13.12502645,
14.2094363,
15.31340617,
16.43469751,
17.57113899},
{3, 3}))
.item<bool>());
CHECK(allclose(
norm(x, 3.0, 2),
array(
{2.08008382,
6.,
10.23127655,
14.5180117,
18.82291607,
23.13593104},
{2, 3}))
.item<bool>());
CHECK(
allclose(
norm(x, 0.0, 0), array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3}))
.item<bool>());
CHECK(allclose(norm(x, 0.0, 1), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
.item<bool>());
CHECK(allclose(norm(x, 0.0, 2), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
.item<bool>());
CHECK(allclose(
norm(x, 1.0, 0),
array({9., 11., 13., 15., 17., 19., 21., 23., 25.}, {3, 3}))
.item<bool>());
CHECK(allclose(norm(x, 1.0, 1), array({9., 12., 15., 36., 39., 42.}, {2, 3}))
.item<bool>());
CHECK(allclose(norm(x, 1.0, 2), array({3., 12., 21., 30., 39., 48.}, {2, 3}))
.item<bool>());
CHECK(allclose(norm(x, 1.0, std::vector<int>{0, 1}), array({21., 23., 25.}))
.item<bool>());
CHECK(allclose(norm(x, 1.0, std::vector<int>{1, 2}), array({15., 42.}))
.item<bool>());
CHECK(allclose(norm(x, -1.0, std::vector<int>{0, 1}), array({9., 11., 13.}))
.item<bool>());
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9., 36.}))
.item<bool>());
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 0}), array({9., 12., 15.}))
.item<bool>());
CHECK(allclose(norm(x, -1.0, std::vector<int>{2, 1}), array({3, 30}))
.item<bool>());
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9, 36}))
.item<bool>());
}
TEST_CASE("[mlx.core.linalg.norm] string ord") {
array x({1, 2, 3});
CHECK_THROWS(norm(x, "fro"));
x = reshape(arange(9), {3, 3});
CHECK_THROWS(norm(x, "bad ord"));
CHECK_EQ(
norm(x, "f", std::vector<int>{0, 1}).item<float>(),
doctest::Approx(14.2828568570857));
CHECK_EQ(
norm(x, "fro", std::vector<int>{0, 1}).item<float>(),
doctest::Approx(14.2828568570857));
x = reshape(arange(18), {2, 3, 3});
CHECK(allclose(
norm(x, "fro", std::vector<int>{0, 1}),
array({22.24859546, 24.31049156, 26.43860813}))
.item<bool>());
CHECK(allclose(
norm(x, "fro", std::vector<int>{1, 2}),
array({14.28285686, 39.7617907}))
.item<bool>());
CHECK(allclose(
norm(x, "f", std::vector<int>{0, 1}),
array({22.24859546, 24.31049156, 26.43860813}))
.item<bool>());
CHECK(allclose(
norm(x, "f", std::vector<int>{1, 0}),
array({22.24859546, 24.31049156, 26.43860813}))
.item<bool>());
CHECK(allclose(
norm(x, "f", std::vector<int>{1, 2}),
array({14.28285686, 39.7617907}))
.item<bool>());
CHECK(allclose(
norm(x, "f", std::vector<int>{2, 1}),
array({14.28285686, 39.7617907}))
.item<bool>());
}