From 6b0d30bb85b46b2914db00065f9d054085760165 Mon Sep 17 00:00:00 2001 From: Gabrijel Boduljak Date: Wed, 27 Dec 2023 04:42:04 +0100 Subject: [PATCH] linalg.norm (#187) * implemented vector_norm in cpp added linalg to mlx * implemented vector_norm python binding * renamed vector_norm to norm, implemented norm without provided ord * completed the implementation of the norm * added tests * removed unused import in linalg.cpp * updated python bindings * added some tests for python bindings * handling inf, -inf as numpy does, more extensive tests of compatibility with numpy * added better docs and examples * refactored mlx.linalg.norm bindings * reused existing util for implementation of linalg.norm * more tests * fixed a bug with no ord and axis provided * removed unused imports * some style and API consistency updates to linalg norm * remove unused includes * fix python tests * fixed a bug with frobenius norm of a complex-valued matrix * complex for vector too --------- Co-authored-by: Awni Hannun --- docs/src/index.rst | 1 + docs/src/python/linalg.rst | 11 ++ mlx/CMakeLists.txt | 1 + mlx/linalg.cpp | 175 +++++++++++++++++++++++++ mlx/linalg.h | 63 +++++++++ mlx/mlx.h | 1 + python/src/CMakeLists.txt | 1 + python/src/linalg.cpp | 180 ++++++++++++++++++++++++++ python/src/mlx.cpp | 2 + python/tests/test_linalg.py | 94 ++++++++++++++ tests/CMakeLists.txt | 1 + tests/linalg_tests.cpp | 250 ++++++++++++++++++++++++++++++++++++ 12 files changed, 780 insertions(+) create mode 100644 docs/src/python/linalg.rst create mode 100644 mlx/linalg.cpp create mode 100644 mlx/linalg.h create mode 100644 python/src/linalg.cpp create mode 100644 python/tests/test_linalg.py create mode 100644 tests/linalg_tests.cpp diff --git a/docs/src/index.rst b/docs/src/index.rst index ac4932f10..207238f37 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -57,6 +57,7 @@ are the CPU and GPU. python/random python/transforms python/fft + python/linalg python/nn python/optimizers python/tree_utils diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst new file mode 100644 index 000000000..27746441e --- /dev/null +++ b/docs/src/python/linalg.rst @@ -0,0 +1,11 @@ +.. _linalg: + +Linear Algebra +============== + +.. currentmodule:: mlx.core.linalg + +.. autosummary:: + :toctree: _autosummary + + norm diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index bd28537f1..e004fc3d9 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -14,6 +14,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h ) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp new file mode 100644 index 000000000..7e7264e3f --- /dev/null +++ b/mlx/linalg.cpp @@ -0,0 +1,175 @@ +// Copyright © 2023 Apple Inc. + +#include +#include +#include + +#include "mlx/dtype.h" +#include "mlx/linalg.h" + +namespace mlx::core::linalg { + +Dtype at_least_float(const Dtype& d) { + return is_floating_point(d) ? d : promote_types(d, float32); +} + +inline array l2_norm( + const array& a, + const std::vector& axis, + bool keepdims, + StreamOrDevice s) { + if (is_complex(a.dtype())) { + return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s); + } else { + return sqrt(sum(square(a, s), axis, keepdims, s), s); + } +} + +inline array vector_norm( + const array& a, + const double ord, + const std::vector& axis, + bool keepdims, + StreamOrDevice s) { + auto dtype = at_least_float(a.dtype()); + if (ord == 0.0) { + return astype(sum(not_equal(a, array(0), s), axis, keepdims, s), dtype, s); + } else if (ord == 1.0) { + return astype(sum(abs(a, s), axis, keepdims, s), dtype, s); + } else if (ord == 2.0) { + return l2_norm(a, axis, keepdims, s); + } else if (ord == std::numeric_limits::infinity()) { + return astype(max(abs(a, s), axis, keepdims, s), dtype, s); + } else if (ord == -std::numeric_limits::infinity()) { + return astype(min(abs(a, s), axis, keepdims, s), dtype, s); + } else { + return power( + sum(power(abs(a, s), array(ord, dtype), s), axis, keepdims, s), + array(1.0 / ord, dtype), + s); + } +} + +inline array matrix_norm( + const array& a, + const double ord, + const std::vector& axis, + bool keepdims, + StreamOrDevice s) { + auto dtype = at_least_float(a.dtype()); + auto row_axis = axis[0]; + auto col_axis = axis[1]; + if (ord == -1.0) { + col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0); + return astype( + min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s), + dtype, + s); + } else if (ord == 1.0) { + col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0); + return astype( + max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s), + dtype, + s); + } else if (ord == std::numeric_limits::infinity()) { + row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0); + return astype( + max(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s), + dtype, + s); + } else if (ord == -std::numeric_limits::infinity()) { + row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0); + return astype( + min(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s), + dtype, + s); + } else if (ord == 2.0 || ord == -2.0) { + throw std::runtime_error( + "[linalg::norm] Singular value norms are not implemented."); + } else { + std::ostringstream msg; + msg << "[linalg::norm] Invalid ord " << ord << " for matrix norm."; + throw std::invalid_argument(msg.str()); + } +} + +inline array matrix_norm( + const array& a, + const std::string& ord, + const std::vector& axis, + bool keepdims, + StreamOrDevice s) { + if (ord == "f" || ord == "fro") { + return l2_norm(a, axis, keepdims, s); + } else if (ord == "nuc") { + throw std::runtime_error( + "[linalg::norm] Nuclear norm not yet implemented."); + } else { + std::ostringstream msg; + msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm."; + throw std::invalid_argument(msg.str()); + } +} + +array norm( + const array& a, + const std::optional>& axis /* = std::nullopt */, + bool keepdims /* = false */, + StreamOrDevice s /* = {} */) { + if (!axis) { + return norm(flatten(a, s), std::vector{0}, keepdims, s); + } + + if (axis.value().size() > 2) { + throw std::invalid_argument( + "[linalg::norm] Received too many axes for norm."); + } + return l2_norm(a, axis.value(), keepdims, s); +} + +array norm( + const array& a, + const double ord, + const std::optional>& axis /* = std::nullopt */, + bool keepdims /* = false */, + StreamOrDevice s /* = {} */) { + std::vector ax; + if (!axis) { + ax.resize(a.ndim()); + std::iota(ax.begin(), ax.end(), 0); + } else { + ax = axis.value(); + } + if (ax.size() == 1) { + return vector_norm(a, ord, ax, keepdims, s); + } else if (ax.size() == 2) { + return matrix_norm(a, ord, ax, keepdims, s); + } else { + throw std::invalid_argument( + "[linalg::norm] Received too many axes for norm."); + } +} + +array norm( + const array& a, + const std::string& ord, + const std::optional>& axis /* = std::nullopt */, + bool keepdims /* = false */, + StreamOrDevice s /* = {} */) { + std::vector ax; + if (!axis) { + ax.resize(a.ndim()); + std::iota(ax.begin(), ax.end(), 0); + } else { + ax = axis.value(); + } + if (ax.size() != 2) { + std::ostringstream msg; + msg << "[linalg::norm] Norm '" << ord << "' only supported for matrices," + << " but received " << ax.size() << " axis/axes."; + throw std::invalid_argument(msg.str()); + } + return matrix_norm(a, ord, ax, keepdims, s); +} + +} // namespace mlx::core::linalg diff --git a/mlx/linalg.h b/mlx/linalg.h new file mode 100644 index 000000000..80e484eb5 --- /dev/null +++ b/mlx/linalg.h @@ -0,0 +1,63 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +#include "mlx/array.h" +#include "mlx/device.h" +#include "mlx/ops.h" +#include "mlx/stream.h" + +namespace mlx::core::linalg { + +/** + * Compute vector or matrix norms. + * + * - If axis and ord are both unspecified, computes the 2-norm of flatten(x). + * - If axis is not provided but ord is, then x must be either 1D or 2D. + * - If axis is provided, but ord is not, then the 2-norm (or Frobenius norm + * for matrices) is computed along the given axes. At most 2 axes can be + * specified. + * - If both axis and ord are provided, then the corresponding matrix or vector + * norm is computed. At most 2 axes can be specified. + */ +array norm( + const array& a, + const double ord, + const std::optional>& axis = std::nullopt, + bool keepdims = false, + StreamOrDevice s = {}); +inline array norm( + const array& a, + const double ord, + int axis, + bool keepdims = false, + StreamOrDevice s = {}) { + return norm(a, ord, std::vector{axis}, keepdims, s); +} +array norm( + const array& a, + const std::string& ord, + const std::optional>& axis = std::nullopt, + bool keepdims = false, + StreamOrDevice s = {}); +inline array norm( + const array& a, + const std::string& ord, + int axis, + bool keepdims = false, + StreamOrDevice s = {}) { + return norm(a, ord, std::vector{axis}, keepdims, s); +} +array norm( + const array& a, + const std::optional>& axis = std::nullopt, + bool keepdims = false, + StreamOrDevice s = {}); +inline array +norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) { + return norm(a, std::vector{axis}, keepdims, s); +} + +} // namespace mlx::core::linalg diff --git a/mlx/mlx.h b/mlx/mlx.h index 102d2dde9..8d785c39f 100644 --- a/mlx/mlx.h +++ b/mlx/mlx.h @@ -6,6 +6,7 @@ #include "mlx/backend/metal/metal.h" #include "mlx/device.h" #include "mlx/fft.h" +#include "mlx/linalg.h" #include "mlx/ops.h" #include "mlx/random.h" #include "mlx/stream.h" diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index 5ab8a50bf..1ad9d207d 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -11,6 +11,7 @@ pybind11_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ) if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY) diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp new file mode 100644 index 000000000..ea5474a70 --- /dev/null +++ b/python/src/linalg.cpp @@ -0,0 +1,180 @@ +// Copyright © 2023 Apple Inc. + +#include + +#include +#include + +#include "mlx/linalg.h" + +#include "python/src/load.h" +#include "python/src/utils.h" + +namespace py = pybind11; +using namespace py::literals; + +using namespace mlx::core; +using namespace mlx::core::linalg; + +void init_linalg(py::module_& parent_module) { + py::options options; + options.disable_function_signatures(); + + auto m = parent_module.def_submodule( + "linalg", "mlx.core.linalg: linear algebra routines."); + + m.def( + "norm", + [](const array& a, + const std::variant& ord_, + const std::variant>& axis_, + const bool keepdims, + const StreamOrDevice stream) { + std::optional> axis = std::nullopt; + if (auto pv = std::get_if(&axis_); pv) { + axis = std::vector{*pv}; + } else if (auto pv = std::get_if>(&axis_); pv) { + axis = *pv; + } + + if (std::holds_alternative(ord_)) { + return norm(a, axis, keepdims, stream); + } else { + if (auto pv = std::get_if(&ord_); pv) { + return norm(a, *pv, axis, keepdims, stream); + } + double ord; + if (auto pv = std::get_if(&ord_); pv) { + ord = *pv; + } else { + ord = std::get(ord_); + } + return norm(a, ord, axis, keepdims, stream); + } + }, + "a"_a, + py::pos_only(), + "ord"_a = none, + "axis"_a = none, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array + + Matrix or vector norm. + + This function computes vector or matrix norms depending on the value of + the ``ord`` and ``axis`` parameters. + + Args: + a (array): Input array. If ``axis`` is ``None``, ``a`` must be 1-D or 2-D, + unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the + 2-norm of ``a.flatten`` will be returned. + ord (scalar or str, optional): Order of the norm (see table under ``Notes``). + If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed + along the given ``axis``. Default: ``None``. + axis (int or list(int), optional): If ``axis`` is an integer, it specifies the + axis of ``a`` along which to compute the vector norms. If ``axis`` is a + 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix + norms of these matrices are computed. If `axis` is ``None`` then + either a vector norm (when ``a`` is 1-D) or a matrix norm (when ``a`` is + 2-D) is returned. Default: ``None``. + keepdims (bool, optional): If ``True``, the axes which are normed over are + left in the result as dimensions with size one. Default ``False``. + + Returns: + array: The output containing the norm(s). + + Notes: + For values of ``ord < 1``, the result is, strictly speaking, not a + mathematical norm, but it may still be useful for various numerical + purposes. + + The following norms can be calculated: + + ===== ============================ ========================== + ord norm for matrices norm for vectors + ===== ============================ ========================== + None Frobenius norm 2-norm + 'fro' Frobenius norm -- + inf max(sum(abs(x), axis=1)) max(abs(x)) + -inf min(sum(abs(x), axis=1)) min(abs(x)) + 0 -- sum(x != 0) + 1 max(sum(abs(x), axis=0)) as below + -1 min(sum(abs(x), axis=0)) as below + 2 2-norm (largest sing. value) as below + -2 smallest singular value as below + other -- sum(abs(x)**ord)**(1./ord) + ===== ============================ ========================== + + .. warning:: + Nuclear norm and norms based on singular values are not yet implemented. + + The Frobenius norm is given by [1]_: + + :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}` + + The nuclear norm is the sum of the singular values. + + Both the Frobenius and nuclear norm orders are only defined for + matrices and raise a ``ValueError`` when ``a.ndim != 2``. + + References: + .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, + Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 + + Examples: + >>> import mlx.core as mx + >>> from mlx.core import linalg as la + >>> a = mx.arange(9) - 4 + >>> a + array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) + >>> b = a.reshape((3,3)) + >>> b + array([[-4, -3, -2], + [-1, 0, 1], + [ 2, 3, 4]], dtype=int32) + >>> la.norm(a) + array(7.74597, dtype=float32) + >>> la.norm(b) + array(7.74597, dtype=float32) + >>> la.norm(b, 'fro') + array(7.74597, dtype=float32) + >>> la.norm(a, float("inf")) + array(4, dtype=float32) + >>> la.norm(b, float("inf")) + array(9, dtype=float32) + >>> la.norm(a, -float("inf")) + array(0, dtype=float32) + >>> la.norm(b, -float("inf")) + array(2, dtype=float32) + >>> la.norm(a, 1) + array(20, dtype=float32) + >>> la.norm(b, 1) + array(7, dtype=float32) + >>> la.norm(a, -1) + array(0, dtype=float32) + >>> la.norm(b, -1) + array(6, dtype=float32) + >>> la.norm(a, 2) + array(7.74597, dtype=float32) + >>> la.norm(a, 3) + array(5.84804, dtype=float32) + >>> la.norm(a, -3) + array(0, dtype=float32) + >>> c = mx.array([[ 1, 2, 3], + ... [-1, 1, 4]]) + >>> la.norm(c, axis=0) + array([1.41421, 2.23607, 5], dtype=float32) + >>> la.norm(c, axis=1) + array([3.74166, 4.24264], dtype=float32) + >>> la.norm(c, ord=1, axis=1) + array([6, 6], dtype=float32) + >>> m = mx.arange(8).reshape(2,2,2) + >>> la.norm(m, axis=(1,2)) + array([3.74166, 11.225], dtype=float32) + >>> la.norm(m[0, :, :]), LA.norm(m[1, :, :]) + (array(3.74166, dtype=float32), array(11.225, dtype=float32)) + )pbdoc"); +} diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp index ebadf767d..d7cf15751 100644 --- a/python/src/mlx.cpp +++ b/python/src/mlx.cpp @@ -15,6 +15,7 @@ void init_ops(py::module_&); void init_transforms(py::module_&); void init_random(py::module_&); void init_fft(py::module_&); +void init_linalg(py::module_&); PYBIND11_MODULE(core, m) { m.doc() = "mlx: A framework for machine learning on Apple silicon."; @@ -29,5 +30,6 @@ PYBIND11_MODULE(core, m) { init_transforms(m); init_random(m); init_fft(m); + init_linalg(m); m.attr("__version__") = TOSTRING(_VERSION_); } diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py new file mode 100644 index 000000000..ac86c1e11 --- /dev/null +++ b/python/tests/test_linalg.py @@ -0,0 +1,94 @@ +# Copyright © 2023 Apple Inc. + +import itertools +import math +import unittest + +import mlx.core as mx +import mlx_tests +import numpy as np + + +class TestLinalg(mlx_tests.MLXTestCase): + def test_norm(self): + vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")] + matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")] + + for shape in [(3,), (2, 3), (2, 3, 3)]: + x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape) + x_np = np.arange(1, math.prod(shape) + 1).reshape(shape) + # Test when at least one axis is provided + for num_axes in range(1, len(shape)): + if num_axes == 1: + ords = vector_ords + else: + ords = matrix_ords + for axis in itertools.combinations(range(len(shape)), num_axes): + for keepdims in [True, False]: + for o in ords: + out_np = np.linalg.norm( + x_np, ord=o, axis=axis, keepdims=keepdims + ) + out_mx = mx.linalg.norm( + x_mx, ord=o, axis=axis, keepdims=keepdims + ) + with self.subTest( + shape=shape, ord=o, axis=axis, keepdims=keepdims + ): + self.assertTrue( + np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) + ) + + # Test only ord provided + for shape in [(3,), (2, 3)]: + x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape) + x_np = np.arange(1, math.prod(shape) + 1).reshape(shape) + for o in [None, 1, -1, float("inf"), -float("inf")]: + for keepdims in [True, False]: + out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims) + out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims) + with self.subTest(shape=shape, ord=o, keepdims=keepdims): + self.assertTrue( + np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) + ) + + # Test no ord and no axis provided + for shape in [(3,), (2, 3), (2, 3, 3)]: + x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape) + x_np = np.arange(1, math.prod(shape) + 1).reshape(shape) + for keepdims in [True, False]: + out_np = np.linalg.norm(x_np, keepdims=keepdims) + out_mx = mx.linalg.norm(x_mx, keepdims=keepdims) + with self.subTest(shape=shape, keepdims=keepdims): + self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) + + def test_complex_norm(self): + for shape in [(3,), (2, 3), (2, 3, 3)]: + x_np = np.random.uniform(size=shape).astype( + np.float32 + ) + 1j * np.random.uniform(size=shape).astype(np.float32) + x_mx = mx.array(x_np) + out_np = np.linalg.norm(x_np) + out_mx = mx.linalg.norm(x_mx) + with self.subTest(shape=shape): + self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) + for num_axes in range(1, len(shape)): + for axis in itertools.combinations(range(len(shape)), num_axes): + out_np = np.linalg.norm(x_np, axis=axis) + out_mx = mx.linalg.norm(x_mx, axis=axis) + with self.subTest(shape=shape, axis=axis): + self.assertTrue( + np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) + ) + + x_np = np.random.uniform(size=(4, 4)).astype( + np.float32 + ) + 1j * np.random.uniform(size=(4, 4)).astype(np.float32) + x_mx = mx.array(x_np) + out_np = np.linalg.norm(x_np, ord="fro") + out_mx = mx.linalg.norm(x_mx, ord="fro") + self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 0879aa0f6..dbc499205 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -31,6 +31,7 @@ target_sources(tests PRIVATE scheduler_tests.cpp utils_tests.cpp vmap_tests.cpp + linalg_tests.cpp ${METAL_TEST_SOURCES} ) diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp new file mode 100644 index 000000000..1d8ee43d9 --- /dev/null +++ b/tests/linalg_tests.cpp @@ -0,0 +1,250 @@ +// Copyright © 2023 Apple Inc. + +#include "doctest/doctest.h" + +#include + +#include "mlx/mlx.h" + +using namespace mlx::core; +using namespace mlx::core::linalg; + +TEST_CASE("[mlx.core.linalg.norm] no ord") { + // Zero dimensions + array x(2.0); + CHECK_EQ(norm(x).item(), 2.0f); + CHECK_THROWS(norm(x, 0)); + + x = array({1, 2, 3}); + float expected = std::sqrt(1 + 4 + 9); + CHECK_EQ(norm(x).item(), doctest::Approx(expected)); + CHECK_EQ(norm(x, 0, false).item(), doctest::Approx(expected)); + CHECK_EQ(norm(x, -1, false).item(), doctest::Approx(expected)); + CHECK_EQ(norm(x, -1, true).ndim(), 1); + CHECK_THROWS(norm(x, 1)); + + x = reshape(arange(9), {3, 3}); + expected = + std::sqrt(0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8); + + CHECK_EQ(norm(x).item(), doctest::Approx(expected)); + CHECK_EQ( + norm(x, std::vector{0, 1}).item(), doctest::Approx(expected)); + CHECK(array_equal( + norm(x, 0, false), + array( + {std::sqrt(0 + 3 * 3 + 6 * 6), + std::sqrt(1 + 4 * 4 + 7 * 7), + std::sqrt(2 * 2 + 5 * 5 + 8 * 8)})) + .item()); + CHECK(allclose( + norm(x, 1, false), + array( + {std::sqrt(0 + 1 + 2 * 2), + std::sqrt(3 * 3 + 4 * 4 + 5 * 5), + std::sqrt(6 * 6 + 7 * 7 + 8 * 8)})) + .item()); + + x = reshape(arange(18), {2, 3, 3}); + CHECK(allclose( + norm(x, 2, false), + array( + { + std::sqrt(0 + 1 + 2 * 2), + std::sqrt(3 * 3 + 4 * 4 + 5 * 5), + std::sqrt(6 * 6 + 7 * 7 + 8 * 8), + std::sqrt(9 * 9 + 10 * 10 + 11 * 11), + std::sqrt(12 * 12 + 13 * 13 + 14 * 14), + std::sqrt(15 * 15 + 16 * 16 + 17 * 17), + }, + {2, 3})) + .item()); + CHECK(allclose( + norm(x, std::vector{1, 2}, false), + array( + {std::sqrt( + 0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + + 8 * 8), + std::sqrt( + 9 * 9 + 10 * 10 + 11 * 11 + 12 * 12 + 13 * 13 + 14 * 14 + + 15 * 15 + 16 * 16 + 17 * 17)}, + {2})) + .item()); + CHECK_THROWS(norm(x, std::vector{0, 1, 2})); +} + +TEST_CASE("[mlx.core.linalg.norm] double ord") { + CHECK_THROWS(norm(array(0), 2.0)); + + array x({1, 2, 3}); + + float expected = std::sqrt(1 + 4 + 9); + CHECK_EQ(norm(x, 2.0).item(), doctest::Approx(expected)); + CHECK_EQ(norm(x, 2.0, 0).item(), doctest::Approx(expected)); + CHECK_THROWS(norm(x, 2.0, 1)); + + expected = 1 + 2 + 3; + CHECK_EQ(norm(x, 1.0).item(), doctest::Approx(expected)); + + expected = 3; + CHECK_EQ(norm(x, 0.0).item(), doctest::Approx(expected)); + + expected = 3; + CHECK_EQ( + norm(x, std::numeric_limits::infinity()).item(), + doctest::Approx(expected)); + + expected = 1; + CHECK_EQ( + norm(x, -std::numeric_limits::infinity()).item(), + doctest::Approx(expected)); + + x = reshape(arange(9), {3, 3}); + + CHECK(allclose( + norm(x, 2.0, 0, false), + array( + {std::sqrt(0 + 3 * 3 + 6 * 6), + std::sqrt(1 + 4 * 4 + 7 * 7), + std::sqrt(2 * 2 + 5 * 5 + 8 * 8)})) + .item()); + CHECK(allclose( + norm(x, 2.0, 1, false), + array( + {sqrt(0 + 1 + 2 * 2), + sqrt(3 * 3 + 4 * 4 + 5 * 5), + sqrt(6 * 6 + 7 * 7 + 8 * 8)})) + .item()); + + CHECK_EQ( + norm(x, 1.0, std::vector{0, 1}).item(), + doctest::Approx(15.0)); + CHECK_EQ( + norm(x, 1.0, std::vector{1, 0}).item(), + doctest::Approx(21.0)); + CHECK_EQ( + norm(x, -1.0, std::vector{0, 1}).item(), + doctest::Approx(9.0)); + CHECK_EQ( + norm(x, -1.0, std::vector{1, 0}).item(), + doctest::Approx(3.0)); + CHECK_EQ( + norm(x, 1.0, std::vector{0, 1}, true).shape(), + std::vector{1, 1}); + CHECK_EQ( + norm(x, 1.0, std::vector{1, 0}, true).shape(), + std::vector{1, 1}); + CHECK_EQ( + norm(x, -1.0, std::vector{0, 1}, true).shape(), + std::vector{1, 1}); + CHECK_EQ( + norm(x, -1.0, std::vector{1, 0}, true).shape(), + std::vector{1, 1}); + + CHECK_EQ( + norm(x, -1.0, std::vector{-2, -1}, false).item(), + doctest::Approx(9.0)); + CHECK_EQ( + norm(x, 1.0, std::vector{-2, -1}, false).item(), + doctest::Approx(15.0)); + + x = reshape(arange(18), {2, 3, 3}); + CHECK_THROWS(norm(x, 2.0, std::vector{0, 1, 2})); + CHECK(allclose( + norm(x, 3.0, 0), + array( + {9., + 10.00333222, + 11.02199456, + 12.06217728, + 13.12502645, + 14.2094363, + 15.31340617, + 16.43469751, + 17.57113899}, + {3, 3})) + .item()); + CHECK(allclose( + norm(x, 3.0, 2), + array( + {2.08008382, + 6., + 10.23127655, + 14.5180117, + 18.82291607, + 23.13593104}, + {2, 3})) + .item()); + CHECK( + allclose( + norm(x, 0.0, 0), array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3})) + .item()); + CHECK(allclose(norm(x, 0.0, 1), array({2., 3., 3., 3., 3., 3.}, {2, 3})) + .item()); + CHECK(allclose(norm(x, 0.0, 2), array({2., 3., 3., 3., 3., 3.}, {2, 3})) + .item()); + CHECK(allclose( + norm(x, 1.0, 0), + array({9., 11., 13., 15., 17., 19., 21., 23., 25.}, {3, 3})) + .item()); + CHECK(allclose(norm(x, 1.0, 1), array({9., 12., 15., 36., 39., 42.}, {2, 3})) + .item()); + CHECK(allclose(norm(x, 1.0, 2), array({3., 12., 21., 30., 39., 48.}, {2, 3})) + .item()); + + CHECK(allclose(norm(x, 1.0, std::vector{0, 1}), array({21., 23., 25.})) + .item()); + CHECK(allclose(norm(x, 1.0, std::vector{1, 2}), array({15., 42.})) + .item()); + CHECK(allclose(norm(x, -1.0, std::vector{0, 1}), array({9., 11., 13.})) + .item()); + CHECK(allclose(norm(x, -1.0, std::vector{1, 2}), array({9., 36.})) + .item()); + CHECK(allclose(norm(x, -1.0, std::vector{1, 0}), array({9., 12., 15.})) + .item()); + CHECK(allclose(norm(x, -1.0, std::vector{2, 1}), array({3, 30})) + .item()); + CHECK(allclose(norm(x, -1.0, std::vector{1, 2}), array({9, 36})) + .item()); +} + +TEST_CASE("[mlx.core.linalg.norm] string ord") { + array x({1, 2, 3}); + CHECK_THROWS(norm(x, "fro")); + + x = reshape(arange(9), {3, 3}); + CHECK_THROWS(norm(x, "bad ord")); + + CHECK_EQ( + norm(x, "f", std::vector{0, 1}).item(), + doctest::Approx(14.2828568570857)); + CHECK_EQ( + norm(x, "fro", std::vector{0, 1}).item(), + doctest::Approx(14.2828568570857)); + + x = reshape(arange(18), {2, 3, 3}); + CHECK(allclose( + norm(x, "fro", std::vector{0, 1}), + array({22.24859546, 24.31049156, 26.43860813})) + .item()); + CHECK(allclose( + norm(x, "fro", std::vector{1, 2}), + array({14.28285686, 39.7617907})) + .item()); + CHECK(allclose( + norm(x, "f", std::vector{0, 1}), + array({22.24859546, 24.31049156, 26.43860813})) + .item()); + CHECK(allclose( + norm(x, "f", std::vector{1, 0}), + array({22.24859546, 24.31049156, 26.43860813})) + .item()); + CHECK(allclose( + norm(x, "f", std::vector{1, 2}), + array({14.28285686, 39.7617907})) + .item()); + CHECK(allclose( + norm(x, "f", std::vector{2, 1}), + array({14.28285686, 39.7617907})) + .item()); +}