From f7cea9563df25a7b6d63a78f0b6cec99a4fa2e5c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 26 Dec 2023 10:54:59 -0800 Subject: [PATCH] some style and API consistency updates to linalg norm --- docs/src/index.rst | 1 + docs/src/python/linalg.rst | 4 +- mlx/linalg.cpp | 196 ++++++++++---------- mlx/linalg.h | 42 ++++- mlx/utils.cpp | 7 - mlx/utils.h | 16 -- python/src/fft.cpp | 1 - python/src/linalg.cpp | 293 ++++++++++++++---------------- python/src/utils.h | 15 ++ tests/linalg_tests.cpp | 360 ++++++++++++++++--------------------- 10 files changed, 437 insertions(+), 498 deletions(-) diff --git a/docs/src/index.rst b/docs/src/index.rst index ac4932f10..207238f37 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -57,6 +57,7 @@ are the CPU and GPU. python/random python/transforms python/fft + python/linalg python/nn python/optimizers python/tree_utils diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst index 6c9daa100..27746441e 100644 --- a/docs/src/python/linalg.rst +++ b/docs/src/python/linalg.rst @@ -1,11 +1,11 @@ .. _linalg: Linear Algebra -===== +============== .. currentmodule:: mlx.core.linalg .. autosummary:: :toctree: _autosummary - norm \ No newline at end of file + norm diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 33fa92083..61c9e8537 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -1,47 +1,42 @@ // Copyright © 2023 Apple Inc. -#include -#include +#include +#include #include #include "mlx/array.h" #include "mlx/linalg.h" #include "mlx/ops.h" -#include "mlx/utils.h" namespace mlx::core::linalg { +Dtype at_least_float(const Dtype& d) { + return is_floating_point(d) ? d : promote_types(d, float32); +} + inline array vector_norm( const array& a, const double ord, const std::vector& axis, bool keepdims, StreamOrDevice s) { - if (ord == 0.0) - return sum(a != 0, axis, keepdims, s); - else if (ord == 1.0) - return sum(abs(a, s), axis, keepdims, s); - else if (ord == 2.0) - return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s)); - else + auto dtype = at_least_float(a.dtype()); + if (ord == 0.0) { + return astype(sum(not_equal(a, array(0), s), axis, keepdims, s), dtype, s); + } else if (ord == 1.0) { + return astype(sum(abs(a, s), axis, keepdims, s), dtype, s); + } else if (ord == 2.0) { + return sqrt(sum(square(a, s), axis, keepdims, s), s); + } else if (ord == std::numeric_limits::infinity()) { + return astype(max(abs(a, s), axis, keepdims, s), dtype, s); + } else if (ord == -std::numeric_limits::infinity()) { + return astype(min(abs(a, s), axis, keepdims, s), dtype, s); + } else { return power( - sum(power(abs(a, s), array(ord), s), axis, keepdims, s), - array(1.0 / ord)); -} - -inline array vector_norm( - const array& a, - const std::string& ord, - const std::vector& axis, - bool keepdims, - StreamOrDevice s) { - if (ord == "inf") - return max(abs(a, s), axis, keepdims, s); - else if (ord == "-inf") - return min(abs(a, s), axis, keepdims, s); - std::ostringstream error_stream; - error_stream << "Invalid ord value " << ord; - throw std::invalid_argument(error_stream.str()); + sum(power(abs(a, s), array(ord, dtype), s), axis, keepdims, s), + array(1.0 / ord, dtype), + s); + } } inline array matrix_norm( @@ -50,19 +45,30 @@ inline array matrix_norm( const std::vector& axis, bool keepdims, StreamOrDevice s) { + auto dtype = at_least_float(a.dtype()); auto row_axis = axis[0]; auto col_axis = axis[1]; - if (!keepdims && col_axis > row_axis) + if (!keepdims && col_axis > row_axis && col_axis > 0) { col_axis -= 1; - if (ord == -1.0) - return 
min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s); - if (ord == 1.0) - return max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s); - if (ord == 2.0 || ord == -2.0) - throw std::logic_error("Singular value norms are not implemented."); - std::ostringstream error_stream; - error_stream << "Invalid ord value " << ord << " for matrix norm"; - throw std::invalid_argument(error_stream.str()); + } + if (ord == -1.0) { + return astype( + min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s), + dtype, + s); + } else if (ord == 1.0) { + return astype( + max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s), + dtype, + s); + } else if (ord == 2.0 || ord == -2.0) { + throw std::runtime_error( + "[linalg::norm] Singular value norms are not implemented."); + } else { + std::ostringstream msg; + msg << "[linalg::norm] Invalid ord value " << ord << " for matrix norm"; + throw std::invalid_argument(msg.str()); + } } inline array matrix_norm( @@ -71,85 +77,77 @@ inline array matrix_norm( const std::vector& axis, bool keepdims, StreamOrDevice s) { - if (ord == "f" || ord == "fro") + if (ord == "f" || ord == "fro") { return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s)); - else if (ord == "inf") - return matrix_norm(a, 1.0, {axis[1], axis[0]}, keepdims, s); - else if (ord == "-inf") - return matrix_norm(a, -1.0, {axis[1], axis[0]}, keepdims, s); - if (ord == "nuc") - throw std::logic_error("Nuclear norm is not implemented."); - std::ostringstream error_stream; - error_stream << "Invalid ord value " << ord << " for matrix norm"; - throw std::invalid_argument(error_stream.str()); + } else if (ord == "nuc") { + throw std::runtime_error( + "[linalg::norm] Nuclear norm not yet implemented."); + } else { + std::ostringstream msg; + msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm"; + throw std::invalid_argument(msg.str()); + } } array norm( const array& a, - const std::vector& axis, - bool keepdims, - StreamOrDevice s) { - auto num_axes = axis.size(); + const std::optional>& axis /* = std::nullopt */, + bool keepdims /* = false */, + StreamOrDevice s /* = {} */) { + if (!axis) { + return norm(flatten(a, s), std::vector{0}, keepdims, s); + } - if (num_axes == 0 || num_axes == 1 || num_axes == 2) - return sqrt( - sum(abs(a, s) * abs(a, s), - num_axes ? 
axis : get_reduce_axes({}, a.ndim()),
-            keepdims,
-            s),
-        s);
-
-  std::ostringstream error_stream;
-  error_stream << "Invalid axis values " << axis;
-  throw std::invalid_argument(error_stream.str());
+  if (axis.value().size() > 2) {
+    throw std::invalid_argument(
+        "[linalg::norm] Received too many axes for norm");
+  }
+  return sqrt(sum(square(a, s), axis.value(), keepdims, s), s);
 }
 
 array norm(
     const array& a,
     const double ord,
-    const std::vector<int>& axis,
-    bool keepdims,
-    StreamOrDevice s) {
-  std::vector<int> ax = axis;
-
-  if (axis.empty())
-    ax = get_reduce_axes({}, a.ndim());
-  else
-    ax = normalize_axes(ax, a.ndim());
-
-  auto num_axes = ax.size();
-  if (num_axes == 1)
+    const std::optional<std::vector<int>>& axis /* = std::nullopt */,
+    bool keepdims /* = false */,
+    StreamOrDevice s /* = {} */) {
+  std::vector<int> ax;
+  if (!axis) {
+    ax.resize(a.ndim());
+    std::iota(ax.begin(), ax.end(), 0);
+  } else {
+    ax = axis.value();
+  }
+  if (ax.size() == 1) {
     return vector_norm(a, ord, ax, keepdims, s);
-  else if (num_axes == 2)
+  } else if (ax.size() == 2) {
     return matrix_norm(a, ord, ax, keepdims, s);
-
-  std::ostringstream error_stream;
-  error_stream << "Invalid axis values " << ax;
-  throw std::invalid_argument(error_stream.str());
+  } else {
+    throw std::invalid_argument(
+        "[linalg::norm] Received too many axes for norm");
+  }
 }
 
 array norm(
     const array& a,
     const std::string& ord,
-    const std::vector<int>& axis,
-    bool keepdims,
-    StreamOrDevice s) {
-  std::vector<int> ax = axis;
-
-  if (axis.empty())
-    ax = get_reduce_axes({}, a.ndim());
-  else
-    ax = normalize_axes(ax, a.ndim());
-
-  auto num_axes = ax.size();
-  if (num_axes == 1)
-    return vector_norm(a, ord, ax, keepdims, s);
-  else if (num_axes == 2)
-    return matrix_norm(a, ord, ax, keepdims, s);
-
-  std::ostringstream error_stream;
-  error_stream << "Invalid axis values " << ax;
-  throw std::invalid_argument(error_stream.str());
+    const std::optional<std::vector<int>>& axis /* = std::nullopt */,
+    bool keepdims /* = false */,
+    StreamOrDevice s /* = {} */) {
+  std::vector<int> ax;
+  if (!axis) {
+    ax.resize(a.ndim());
+    std::iota(ax.begin(), ax.end(), 0);
+  } else {
+    ax = axis.value();
+  }
+  if (ax.size() != 2) {
+    std::ostringstream msg;
+    msg << "[linalg::norm] Norm '" << ord << "' only supported for matrices,"
+        << " but received " << ax.size() << " axis/axes.";
+    throw std::invalid_argument(msg.str());
+  }
+  return matrix_norm(a, ord, ax, keepdims, s);
 }
 
-} // namespace mlx::core::linalg
\ No newline at end of file
+} // namespace mlx::core::linalg
diff --git a/mlx/linalg.h b/mlx/linalg.h
index d77ada477..bf3b5e78c 100644
--- a/mlx/linalg.h
+++ b/mlx/linalg.h
@@ -2,27 +2,61 @@
 
 #pragma once
 
+#include <optional>
+
 #include "array.h"
 #include "device.h"
 #include "ops.h"
 #include "stream.h"
 
 namespace mlx::core::linalg {
+
+/*
+ * Compute vector or matrix norms.
+ *
+ * - If axis and ord are both unspecified, computes the 2-norm of flatten(x).
+ * - If axis is not provided but ord is, then x must be either 1D or 2D.
+ * - If axis is provided, but ord is not, then the 2-norm is computed along the
+ *   given axes. At most 2 axes can be specified.
+ * - If both axis and ord are provided, then the corresponding matrix or vector
+ *   norm is computed. At most 2 axes can be specified.
+ */ array norm( const array& a, const double ord, - const std::vector& axis = {}, + const std::optional>& axis = std::nullopt, bool keepdims = false, StreamOrDevice s = {}); +inline array norm( + const array& a, + const double ord, + int axis, + bool keepdims = false, + StreamOrDevice s = {}) { + return norm(a, ord, std::vector{axis}, keepdims, s); +} array norm( const array& a, const std::string& ord, - const std::vector& axis = {}, + const std::optional>& axis = std::nullopt, bool keepdims = false, StreamOrDevice s = {}); +inline array norm( + const array& a, + const std::string& ord, + int axis, + bool keepdims = false, + StreamOrDevice s = {}) { + return norm(a, ord, std::vector{axis}, keepdims, s); +} array norm( const array& a, - const std::vector& axis = {}, + const std::optional>& axis = std::nullopt, bool keepdims = false, StreamOrDevice s = {}); -} // namespace mlx::core::linalg \ No newline at end of file +inline array +norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) { + return norm(a, std::vector{axis}, keepdims, s); +} + +} // namespace mlx::core::linalg diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 932217ad4..1fbc67c8e 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -1,6 +1,5 @@ // Copyright © 2023 Apple Inc. -#include #include #include @@ -74,12 +73,6 @@ int normalize_axis(int axis, int ndim) { } return axis; } -std::vector normalize_axes(const std::vector& axes, int ndim) { - std::vector canonical; - for (int ax : axes) - canonical.push_back(normalize_axis(ax, ndim)); - return canonical; -} std::ostream& operator<<(std::ostream& os, const Device& d) { os << "Device("; diff --git a/mlx/utils.h b/mlx/utils.h index 0b0ae9e93..823b4c872 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -2,7 +2,6 @@ #pragma once -#include #include "array.h" #include "device.h" #include "dtype.h" @@ -25,7 +24,6 @@ bool is_same_shape(const std::vector& arrays); * https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html */ int normalize_axis(int axis, int ndim); -std::vector normalize_axes(const std::vector& axes, int ndim); std::ostream& operator<<(std::ostream& os, const Device& d); std::ostream& operator<<(std::ostream& os, const Stream& s); @@ -43,18 +41,4 @@ inline std::ostream& operator<<(std::ostream& os, const float16_t& v) { inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) { return os << static_cast(v); } - -using IntOrVec = std::variant>; -inline std::vector get_reduce_axes(const IntOrVec& v, int dims) { - std::vector axes; - if (std::holds_alternative(v)) { - axes.resize(dims); - std::iota(axes.begin(), axes.end(), 0); - } else if (auto pv = std::get_if(&v); pv) { - axes.push_back(*pv); - } else { - axes = std::get>(v); - } - return axes; -} } // namespace mlx::core diff --git a/python/src/fft.cpp b/python/src/fft.cpp index 6b3739ae6..42ad37633 100644 --- a/python/src/fft.cpp +++ b/python/src/fft.cpp @@ -7,7 +7,6 @@ #include "mlx/fft.h" #include "mlx/ops.h" -#include "mlx/utils.h" namespace py = pybind11; using namespace py::literals; diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index de5ccfcf3..7bd186d51 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -26,193 +26,164 @@ using namespace mlx::core; using namespace mlx::core::linalg; void init_linalg(py::module_& parent_module) { + py::options options; + options.disable_function_signatures(); + auto m = parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra."); m.def( "norm", [](const array& a, - const 
std::variant& ord, - const std::variant>& axis, + const std::variant& ord_, + const std::variant>& axis_, const bool keepdims, const StreamOrDevice stream) { - return std::visit( - overloaded{ - [&](const double p) { - if (std::isinf((float)p) || std::isinf(p)) { - if (p > 0) { - return norm( - a, - "inf", - std::holds_alternative(axis) - ? std::vector() - : get_reduce_axes(axis, a.ndim()), - keepdims, - stream); - } - return norm( - a, - "-inf", - get_reduce_axes(axis, a.ndim()), - keepdims, - stream); - } - return norm( - a, - p, - std::holds_alternative(axis) - ? std::vector() - : get_reduce_axes(axis, a.ndim()), - keepdims, - stream); - }, - [&](const std::string& p) { - return norm( - a, - p, - std::holds_alternative(axis) - ? std::vector() - : get_reduce_axes(axis, a.ndim()), - keepdims, - stream); - }, - [&](const std::monostate _) { - return norm( - a, - std::holds_alternative(axis) - ? std::vector() - : get_reduce_axes(axis, a.ndim()), - keepdims, - stream); - }}, - ord); + std::optional> axis = std::nullopt; + if (auto pv = std::get_if(&axis_); pv) { + axis = std::vector{*pv}; + } else if (auto pv = std::get_if>(&axis_); pv) { + axis = *pv; + } + + if (std::holds_alternative(ord_)) { + return norm(a, axis, keepdims, stream); + } else { + if (auto pv = std::get_if(&ord_); pv) { + return norm(a, *pv, axis, keepdims, stream); + } + double ord; + if (auto pv = std::get_if(&ord_); pv) { + ord = *pv; + } else { + ord = std::get(ord_); + } + return norm(a, ord, axis, keepdims, stream); + } }, "a"_a, + py::pos_only(), "ord"_a = none, "axis"_a = none, "keepdims"_a = false, + py::kw_only(), "stream"_a = none, R"pbdoc( - Matrix or vector norm. + norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array - This function is able to return matrix or vector norms, - depending on the value of the ``ord`` parameter. + Matrix or vector norm. - Parameters - ---------- - a : array_like - Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord` - is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned. - ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional - Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None. - axis : {None, int, 2-tuple of ints}, optional. - If `axis` is an integer, it specifies the axis of `a` along which to - compute the vector norms. If `axis` is a 2-tuple, it specifies the - axes that hold 2-D matrices, and the matrix norms of these matrices - are computed. If `axis` is None then either a vector norm (when `a` - is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default - is None. - keepdims : bool, optional - If this is set to True, the axes which are normed over are left in the - result as dimensions with size one. With this option the result will - broadcast correctly against the original `a`. + This function computes vector or matrix norms depending on the value of + the ``ord`` and ``axis`` parameters. - Returns - ------- - n : array - Norm of the matrix or vector(s). + Args: + a (array): Input array. If ``axis`` is ``None``, ``a`` must be 1-D or 2-D, + unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the + 2-norm of ``a.flatten`` will be returned. + ord (scalar or str, optional): Order of the norm (see table under ``Notes``). + If ``None``, the 2-norm will be computed along the given ``axis``. + Default: ``None``. 
+ axis (int or list(int), optional): If ``axis`` is an integer, it specifies the + axis of ``a`` along which to compute the vector norms. If ``axis`` is a + 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix + norms of these matrices are computed. If `axis` is ``None`` then + either a vector norm (when ``a`` is 1-D) or a matrix norm (when ``a`` is + 2-D) is returned. Default: ``None``. + keepdims (bool, optional): If ``True``, the axes which are normed over are + left in the result as dimensions with size one. Default ``False``. - Notes - ----- - For values of ``ord < 1``, the result is, strictly speaking, not a - mathematical 'norm', but it may still be useful for various numerical - purposes. + Returns: + array: The output containing the norm(s). - The following norms can be calculated: + Notes: + For values of ``ord < 1``, the result is, strictly speaking, not a + mathematical norm, but it may still be useful for various numerical + purposes. - ===== ============================ ========================== - ord norm for matrices norm for vectors - ===== ============================ ========================== - None Frobenius norm 2-norm - 'fro' Frobenius norm -- - inf max(sum(abs(x), axis=1)) max(abs(x)) - -inf min(sum(abs(x), axis=1)) min(abs(x)) - 0 -- sum(x != 0) - 1 max(sum(abs(x), axis=0)) as below - -1 min(sum(abs(x), axis=0)) as below - 2 2-norm (largest sing. value) as below - -2 smallest singular value as below - other -- sum(abs(x)**ord)**(1./ord) - ===== ============================ ========================== + The following norms can be calculated: - Nuclear norm and norms based on singular values are not yet implemented. + ===== ============================ ========================== + ord norm for matrices norm for vectors + ===== ============================ ========================== + None Frobenius norm 2-norm + 'fro' Frobenius norm -- + inf max(sum(abs(x), axis=1)) max(abs(x)) + -inf min(sum(abs(x), axis=1)) min(abs(x)) + 0 -- sum(x != 0) + 1 max(sum(abs(x), axis=0)) as below + -1 min(sum(abs(x), axis=0)) as below + 2 2-norm (largest sing. value) as below + -2 smallest singular value as below + other -- sum(abs(x)**ord)**(1./ord) + ===== ============================ ========================== - The Frobenius norm is given by [1]_: + .. warning:: + Nuclear norm and norms based on singular values are not yet implemented. - :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}` + The Frobenius norm is given by [1]_: - The nuclear norm is the sum of the singular values. + :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}` - Both the Frobenius and nuclear norm orders are only defined for - matrices and raise a ValueError when ``a.ndim != 2``. + The nuclear norm is the sum of the singular values. - References - ---------- - .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, - Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 + Both the Frobenius and nuclear norm orders are only defined for + matrices and raise a ``ValueError`` when ``a.ndim != 2``. 
- Examples - -------- - >>> import mlx.core as mx - >>> from mlx.core import linalg as LA - >>> a = mx.arange(9) - 4 - >>> a - array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) - >>> b = a.reshape((3,3)) - >>> b - array([[-4, -3, -2], - [-1, 0, 1], - [ 2, 3, 4]], dtype=int32) - >>> LA.norm(a) - array(7.74597, dtype=float32) - >>> LA.norm(b) - array(7.74597, dtype=float32) - >>> LA.norm(b, 'fro') - array(7.74597, dtype=float32) - >>> LA.norm(a, float("inf")) - array(4, dtype=int32) - >>> LA.norm(b, float("inf")) - array(9, dtype=int32) - >>> LA.norm(a, -float("inf")) - array(0, dtype=int32) - >>> LA.norm(b, -float("inf")) - array(2, dtype=int32) - >>> LA.norm(a, 1) - array(20, dtype=int32) - >>> LA.norm(b, 1) - array(7, dtype=int32) - >>> LA.norm(a, -1) - array(0, dtype=float32) - >>> LA.norm(b, -1) - array(6, dtype=int32) - >>> LA.norm(a, 2) - array(7.74597, dtype=float32) - >>> LA.norm(a, 3) - array(5.84804, dtype=float32) - >>> LA.norm(a, -3) - array(0, dtype=float32) - >>> c = mx.array([[ 1, 2, 3], - ... [-1, 1, 4]]) - >>> LA.norm(c, axis=0) - array([1.41421, 2.23607, 5], dtype=float32) - >>> LA.norm(c, axis=1) - array([3.74166, 4.24264], dtype=float32) - >>> LA.norm(c, ord=1, axis=1) - array([6, 6], dtype=int32) - >>> m = mx.arange(8).reshape(2,2,2) - >>> LA.norm(m, axis=(1,2)) - array([3.74166, 11.225], dtype=float32) - >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) - (array(3.74166, dtype=float32), array(11.225, dtype=float32)) - )pbdoc"); + References: + .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, + Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 + + Examples: + >>> import mlx.core as mx + >>> from mlx.core import linalg as la + >>> a = mx.arange(9) - 4 + >>> a + array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) + >>> b = a.reshape((3,3)) + >>> b + array([[-4, -3, -2], + [-1, 0, 1], + [ 2, 3, 4]], dtype=int32) + >>> la.norm(a) + array(7.74597, dtype=float32) + >>> la.norm(b) + array(7.74597, dtype=float32) + >>> la.norm(b, 'fro') + array(7.74597, dtype=float32) + >>> la.norm(a, float("inf")) + array(4, dtype=float32) + >>> la.norm(b, float("inf")) + array(9, dtype=float32) + >>> la.norm(a, -float("inf")) + array(0, dtype=float32) + >>> la.norm(b, -float("inf")) + array(2, dtype=float32) + >>> la.norm(a, 1) + array(20, dtype=float32) + >>> la.norm(b, 1) + array(7, dtype=float32) + >>> la.norm(a, -1) + array(0, dtype=float32) + >>> la.norm(b, -1) + array(6, dtype=float32) + >>> la.norm(a, 2) + array(7.74597, dtype=float32) + >>> la.norm(a, 3) + array(5.84804, dtype=float32) + >>> la.norm(a, -3) + array(0, dtype=float32) + >>> c = mx.array([[ 1, 2, 3], + ... [-1, 1, 4]]) + >>> la.norm(c, axis=0) + array([1.41421, 2.23607, 5], dtype=float32) + >>> la.norm(c, axis=1) + array([3.74166, 4.24264], dtype=float32) + >>> la.norm(c, ord=1, axis=1) + array([6, 6], dtype=float32) + >>> m = mx.arange(8).reshape(2,2,2) + >>> la.norm(m, axis=(1,2)) + array([3.74166, 11.225], dtype=float32) + >>> la.norm(m[0, :, :]), LA.norm(m[1, :, :]) + (array(3.74166, dtype=float32), array(11.225, dtype=float32)) + )pbdoc"); } diff --git a/python/src/utils.h b/python/src/utils.h index 9751b2d6e..5ac878979 100644 --- a/python/src/utils.h +++ b/python/src/utils.h @@ -1,6 +1,7 @@ // Copyright © 2023 Apple Inc. 
#pragma once +#include #include #include @@ -13,10 +14,24 @@ namespace py = pybind11; using namespace mlx::core; +using IntOrVec = std::variant>; using ScalarOrArray = std:: variant, py::object>; static constexpr std::monostate none{}; +inline std::vector get_reduce_axes(const IntOrVec& v, int dims) { + std::vector axes; + if (std::holds_alternative(v)) { + axes.resize(dims); + std::iota(axes.begin(), axes.end(), 0); + } else if (auto pv = std::get_if(&v); pv) { + axes.push_back(*pv); + } else { + axes = std::get>(v); + } + return axes; +} + inline array to_array_with_accessor(py::object obj) { if (py::hasattr(obj, "__mlx_array__")) { return obj.attr("__mlx_array__")().cast(); diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp index 9841f03bf..1d8ee43d9 100644 --- a/tests/linalg_tests.cpp +++ b/tests/linalg_tests.cpp @@ -3,170 +3,155 @@ #include "doctest/doctest.h" #include -#include -#include "mlx/linalg.h" + #include "mlx/mlx.h" using namespace mlx::core; using namespace mlx::core::linalg; TEST_CASE("[mlx.core.linalg.norm] no ord") { - array arr_one_d({1, 2, 3}); - array arr_two_d = reshape(arange(9), {3, 3}); - array arr_three_d = reshape(arange(18), {2, 3, 3}); + // Zero dimensions + array x(2.0); + CHECK_EQ(norm(x).item(), 2.0f); + CHECK_THROWS(norm(x, 0)); - CHECK(array_equal(norm(arr_one_d), array(sqrt(1 + 4 + 9))).item()); - CHECK(array_equal(norm(arr_one_d, {0}, false), array(sqrt(1 + 4 + 9))) - .item()); + x = array({1, 2, 3}); + float expected = std::sqrt(1 + 4 + 9); + CHECK_EQ(norm(x).item(), doctest::Approx(expected)); + CHECK_EQ(norm(x, 0, false).item(), doctest::Approx(expected)); + CHECK_EQ(norm(x, -1, false).item(), doctest::Approx(expected)); + CHECK_EQ(norm(x, -1, true).ndim(), 1); + CHECK_THROWS(norm(x, 1)); + + x = reshape(arange(9), {3, 3}); + expected = + std::sqrt(0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8); + + CHECK_EQ(norm(x).item(), doctest::Approx(expected)); + CHECK_EQ( + norm(x, std::vector{0, 1}).item(), doctest::Approx(expected)); CHECK(array_equal( - norm(arr_two_d, {}, false), - array(sqrt( - 0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8))) - .item()); - CHECK(array_equal( - norm(arr_two_d, {0}, false), + norm(x, 0, false), array( - {sqrt(0 + 3 * 3 + 6 * 6), - sqrt(1 + 4 * 4 + 7 * 7), - sqrt(2 * 2 + 5 * 5 + 8 * 8)})) + {std::sqrt(0 + 3 * 3 + 6 * 6), + std::sqrt(1 + 4 * 4 + 7 * 7), + std::sqrt(2 * 2 + 5 * 5 + 8 * 8)})) .item()); - CHECK(array_equal( - norm(arr_two_d, {1}, false), + CHECK(allclose( + norm(x, 1, false), array( - {sqrt(0 + 1 + 2 * 2), - sqrt(3 * 3 + 4 * 4 + 5 * 5), - sqrt(6 * 6 + 7 * 7 + 8 * 8)})) + {std::sqrt(0 + 1 + 2 * 2), + std::sqrt(3 * 3 + 4 * 4 + 5 * 5), + std::sqrt(6 * 6 + 7 * 7 + 8 * 8)})) .item()); - CHECK(array_equal( - norm(arr_two_d, {0, 1}, false), - array(sqrt( - 0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8))) - .item()); - CHECK(array_equal( - norm(arr_three_d, {2}, false), + + x = reshape(arange(18), {2, 3, 3}); + CHECK(allclose( + norm(x, 2, false), array( { - sqrt(0 + 1 + 2 * 2), - sqrt(3 * 3 + 4 * 4 + 5 * 5), - sqrt(6 * 6 + 7 * 7 + 8 * 8), - sqrt(9 * 9 + 10 * 10 + 11 * 11), - sqrt(12 * 12 + 13 * 13 + 14 * 14), - sqrt(15 * 15 + 16 * 16 + 17 * 17), + std::sqrt(0 + 1 + 2 * 2), + std::sqrt(3 * 3 + 4 * 4 + 5 * 5), + std::sqrt(6 * 6 + 7 * 7 + 8 * 8), + std::sqrt(9 * 9 + 10 * 10 + 11 * 11), + std::sqrt(12 * 12 + 13 * 13 + 14 * 14), + std::sqrt(15 * 15 + 16 * 16 + 17 * 17), }, {2, 3})) .item()); - CHECK(array_equal( - norm(arr_three_d, {1}, false), + CHECK(allclose( + 
norm(x, std::vector{1, 2}, false), array( - { - sqrt(0 + 3 * 3 + 6 * 6), - sqrt(1 + 4 * 4 + 7 * 7), - sqrt(2 * 2 + 5 * 5 + 8 * 8), - sqrt(9 * 9 + 12 * 12 + 15 * 15), - sqrt(10 * 10 + 13 * 13 + 16 * 16), - sqrt(11 * 11 + 14 * 14 + 17 * 17), - }, - {2, 3})) - .item()); - CHECK(array_equal( - norm(arr_three_d, {0}, false), - array( - { - sqrt(0 + 9 * 9), - sqrt(1 + 10 * 10), - sqrt(2 * 2 + 11 * 11), - sqrt(3 * 3 + 12 * 12), - sqrt(4 * 4 + 13 * 13), - sqrt(5 * 5 + 14 * 14), - sqrt(6 * 6 + 15 * 15), - sqrt(7 * 7 + 16 * 16), - sqrt(8 * 8 + 17 * 17), - }, - {3, 3})) - .item()); - CHECK(array_equal( - norm(arr_three_d, {1, 2}, false), - array( - {sqrt( + {std::sqrt( 0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8), - sqrt( + std::sqrt( 9 * 9 + 10 * 10 + 11 * 11 + 12 * 12 + 13 * 13 + 14 * 14 + 15 * 15 + 16 * 16 + 17 * 17)}, {2})) .item()); + CHECK_THROWS(norm(x, std::vector{0, 1, 2})); } TEST_CASE("[mlx.core.linalg.norm] double ord") { - array arr_one_d({1, 2, 3}); - array arr_two_d = reshape(arange(9), {3, 3}); - array arr_three_d = reshape(arange(18), {2, 3, 3}); + CHECK_THROWS(norm(array(0), 2.0)); - CHECK(array_equal(norm(arr_one_d, 2.0), array(sqrt(1 + 4 + 9))).item()); - CHECK(array_equal(norm(arr_one_d, 1.0), array(1 + 2 + 3)).item()); - CHECK(array_equal(norm(arr_one_d, 0.0), array(3)).item()); + array x({1, 2, 3}); - CHECK(array_equal(norm(arr_one_d, 2.0, {0}, false), array(sqrt(1 + 4 + 9))) - .item()); - CHECK(array_equal( - norm(arr_two_d, 2.0, {0}, false), + float expected = std::sqrt(1 + 4 + 9); + CHECK_EQ(norm(x, 2.0).item(), doctest::Approx(expected)); + CHECK_EQ(norm(x, 2.0, 0).item(), doctest::Approx(expected)); + CHECK_THROWS(norm(x, 2.0, 1)); + + expected = 1 + 2 + 3; + CHECK_EQ(norm(x, 1.0).item(), doctest::Approx(expected)); + + expected = 3; + CHECK_EQ(norm(x, 0.0).item(), doctest::Approx(expected)); + + expected = 3; + CHECK_EQ( + norm(x, std::numeric_limits::infinity()).item(), + doctest::Approx(expected)); + + expected = 1; + CHECK_EQ( + norm(x, -std::numeric_limits::infinity()).item(), + doctest::Approx(expected)); + + x = reshape(arange(9), {3, 3}); + + CHECK(allclose( + norm(x, 2.0, 0, false), array( - {sqrt(0 + 3 * 3 + 6 * 6), - sqrt(1 + 4 * 4 + 7 * 7), - sqrt(2 * 2 + 5 * 5 + 8 * 8)})) + {std::sqrt(0 + 3 * 3 + 6 * 6), + std::sqrt(1 + 4 * 4 + 7 * 7), + std::sqrt(2 * 2 + 5 * 5 + 8 * 8)})) .item()); - CHECK(array_equal( - norm(arr_two_d, 2.0, {1}, false), + CHECK(allclose( + norm(x, 2.0, 1, false), array( {sqrt(0 + 1 + 2 * 2), sqrt(3 * 3 + 4 * 4 + 5 * 5), sqrt(6 * 6 + 7 * 7 + 8 * 8)})) .item()); - CHECK(array_equal( - norm(arr_three_d, 2.0, {2}, false), - array( - { - sqrt(0 + 1 + 2 * 2), - sqrt(3 * 3 + 4 * 4 + 5 * 5), - sqrt(6 * 6 + 7 * 7 + 8 * 8), - sqrt(9 * 9 + 10 * 10 + 11 * 11), - sqrt(12 * 12 + 13 * 13 + 14 * 14), - sqrt(15 * 15 + 16 * 16 + 17 * 17), - }, - {2, 3})) - .item()); - CHECK(array_equal( - norm(arr_three_d, 2.0, {1}, false), - array( - { - sqrt(0 + 3 * 3 + 6 * 6), - sqrt(1 + 4 * 4 + 7 * 7), - sqrt(2 * 2 + 5 * 5 + 8 * 8), - sqrt(9 * 9 + 12 * 12 + 15 * 15), - sqrt(10 * 10 + 13 * 13 + 16 * 16), - sqrt(11 * 11 + 14 * 14 + 17 * 17), - }, - {2, 3})) - .item()); - CHECK(array_equal( - norm(arr_three_d, 2.0, {0}, false), - array( - { - sqrt(0 + 9 * 9), - sqrt(1 + 10 * 10), - sqrt(2 * 2 + 11 * 11), - sqrt(3 * 3 + 12 * 12), - sqrt(4 * 4 + 13 * 13), - sqrt(5 * 5 + 14 * 14), - sqrt(6 * 6 + 15 * 15), - sqrt(7 * 7 + 16 * 16), - sqrt(8 * 8 + 17 * 17), - }, - {3, 3})) - .item()); + CHECK_EQ( + norm(x, 1.0, std::vector{0, 1}).item(), + doctest::Approx(15.0)); 
+ CHECK_EQ( + norm(x, 1.0, std::vector{1, 0}).item(), + doctest::Approx(21.0)); + CHECK_EQ( + norm(x, -1.0, std::vector{0, 1}).item(), + doctest::Approx(9.0)); + CHECK_EQ( + norm(x, -1.0, std::vector{1, 0}).item(), + doctest::Approx(3.0)); + CHECK_EQ( + norm(x, 1.0, std::vector{0, 1}, true).shape(), + std::vector{1, 1}); + CHECK_EQ( + norm(x, 1.0, std::vector{1, 0}, true).shape(), + std::vector{1, 1}); + CHECK_EQ( + norm(x, -1.0, std::vector{0, 1}, true).shape(), + std::vector{1, 1}); + CHECK_EQ( + norm(x, -1.0, std::vector{1, 0}, true).shape(), + std::vector{1, 1}); + + CHECK_EQ( + norm(x, -1.0, std::vector{-2, -1}, false).item(), + doctest::Approx(9.0)); + CHECK_EQ( + norm(x, 1.0, std::vector{-2, -1}, false).item(), + doctest::Approx(15.0)); + + x = reshape(arange(18), {2, 3, 3}); + CHECK_THROWS(norm(x, 2.0, std::vector{0, 1, 2})); CHECK(allclose( - norm(arr_three_d, 3.0, {0}), + norm(x, 3.0, 0), array( {9., 10.00333222, @@ -179,15 +164,8 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") { 17.57113899}, {3, 3})) .item()); - CHECK( - allclose( - norm(arr_three_d, 3.0, {1}), - array( - {6.24025147, 7.41685954, 8.6401226, 18., 19.39257164, 20.7915893}, - {2, 3})) - .item()); CHECK(allclose( - norm(arr_three_d, 3.0, {2}), + norm(x, 3.0, 2), array( {2.08008382, 6., @@ -197,110 +175,76 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") { 23.13593104}, {2, 3})) .item()); - CHECK(allclose( - norm(arr_three_d, 0.0, {0}), - array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3})) + CHECK( + allclose( + norm(x, 0.0, 0), array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3})) + .item()); + CHECK(allclose(norm(x, 0.0, 1), array({2., 3., 3., 3., 3., 3.}, {2, 3})) + .item()); + CHECK(allclose(norm(x, 0.0, 2), array({2., 3., 3., 3., 3., 3.}, {2, 3})) .item()); - CHECK( - allclose( - norm(arr_three_d, 0.0, {1}), array({2., 3., 3., 3., 3., 3.}, {2, 3})) - .item()); - CHECK( - allclose( - norm(arr_three_d, 0.0, {2}), array({2., 3., 3., 3., 3., 3.}, {2, 3})) - .item()); CHECK(allclose( - norm(arr_three_d, 1.0, {0}), + norm(x, 1.0, 0), array({9., 11., 13., 15., 17., 19., 21., 23., 25.}, {3, 3})) .item()); - CHECK(allclose( - norm(arr_three_d, 1.0, {1}), - array({9., 12., 15., 36., 39., 42.}, {2, 3})) + CHECK(allclose(norm(x, 1.0, 1), array({9., 12., 15., 36., 39., 42.}, {2, 3})) .item()); - CHECK(allclose( - norm(arr_three_d, 1.0, {2}), - array({3., 12., 21., 30., 39., 48.}, {2, 3})) + CHECK(allclose(norm(x, 1.0, 2), array({3., 12., 21., 30., 39., 48.}, {2, 3})) .item()); - CHECK(allclose(norm(arr_two_d, 1.0, {0, 1}), array({15.0})).item()); - CHECK(allclose(norm(arr_two_d, 1.0, {1, 0}), array({21.0})).item()); - CHECK(allclose(norm(arr_two_d, -1.0, {0, 1}), array({9.0})).item()); - CHECK(allclose(norm(arr_two_d, -1.0, {1, 0}), array({3.0})).item()); - - CHECK(allclose(norm(arr_two_d, 1.0, {0, 1}, true), array({15.0}, {1, 1})) + CHECK(allclose(norm(x, 1.0, std::vector{0, 1}), array({21., 23., 25.})) .item()); - CHECK(allclose(norm(arr_two_d, 1.0, {1, 0}, true), array({21.0}, {1, 1})) + CHECK(allclose(norm(x, 1.0, std::vector{1, 2}), array({15., 42.})) .item()); - CHECK(allclose(norm(arr_two_d, -1.0, {0, 1}, true), array({9.0}, {1, 1})) + CHECK(allclose(norm(x, -1.0, std::vector{0, 1}), array({9., 11., 13.})) .item()); - CHECK(allclose(norm(arr_two_d, -1.0, {1, 0}, true), array({3.0}, {1, 1})) + CHECK(allclose(norm(x, -1.0, std::vector{1, 2}), array({9., 36.})) .item()); - - CHECK(array_equal(norm(arr_two_d, -1.0, {-2, -1}, false), array(9.0)) + CHECK(allclose(norm(x, -1.0, std::vector{1, 0}), array({9., 12., 15.})) 
.item()); - CHECK(array_equal(norm(arr_two_d, 1.0, {-2, -1}, false), array(15.0)) + CHECK(allclose(norm(x, -1.0, std::vector{2, 1}), array({3, 30})) .item()); - // - CHECK(allclose(norm(arr_three_d, 1.0, {0, 1}), array({21., 23., 25.})) + CHECK(allclose(norm(x, -1.0, std::vector{1, 2}), array({9, 36})) .item()); - CHECK( - allclose(norm(arr_three_d, 1.0, {1, 2}), array({15., 42.})).item()); - CHECK(allclose(norm(arr_three_d, -1.0, {0, 1}), array({9., 11., 13.})) - .item()); - CHECK( - allclose(norm(arr_three_d, -1.0, {1, 2}), array({9., 36.})).item()); - CHECK(allclose(norm(arr_three_d, -1.0, {1, 0}), array({9., 12., 15.})) - .item()); - CHECK(allclose(norm(arr_three_d, -1.0, {2, 1}), array({3, 30})).item()); - CHECK(allclose(norm(arr_three_d, -1.0, {1, 2}), array({9, 36})).item()); } TEST_CASE("[mlx.core.linalg.norm] string ord") { - array arr_one_d({1, 2, 3}); - array arr_two_d = reshape(arange(9), {3, 3}); - array arr_three_d = reshape(arange(18), {2, 3, 3}); + array x({1, 2, 3}); + CHECK_THROWS(norm(x, "fro")); - CHECK(allclose(norm(arr_one_d, "inf", {}), array({3.0})).item()); - CHECK(allclose(norm(arr_one_d, "-inf", {}), array({1.0})).item()); + x = reshape(arange(9), {3, 3}); + CHECK_THROWS(norm(x, "bad ord")); - CHECK(allclose(norm(arr_two_d, "f", {0, 1}), array({14.2828568570857})) - .item()); - CHECK(allclose(norm(arr_two_d, "fro", {0, 1}), array({14.2828568570857})) - .item()); - CHECK(allclose(norm(arr_two_d, "inf", {0, 1}), array({21.0})).item()); - CHECK(allclose(norm(arr_two_d, "-inf", {0, 1}), array({3.0})).item()); + CHECK_EQ( + norm(x, "f", std::vector{0, 1}).item(), + doctest::Approx(14.2828568570857)); + CHECK_EQ( + norm(x, "fro", std::vector{0, 1}).item(), + doctest::Approx(14.2828568570857)); + x = reshape(arange(18), {2, 3, 3}); CHECK(allclose( - norm(arr_three_d, "fro", {0, 1}), + norm(x, "fro", std::vector{0, 1}), array({22.24859546, 24.31049156, 26.43860813})) .item()); CHECK(allclose( - norm(arr_three_d, "fro", {1, 2}), array({14.28285686, 39.7617907})) + norm(x, "fro", std::vector{1, 2}), + array({14.28285686, 39.7617907})) .item()); CHECK(allclose( - norm(arr_three_d, "f", {0, 1}), + norm(x, "f", std::vector{0, 1}), array({22.24859546, 24.31049156, 26.43860813})) .item()); CHECK(allclose( - norm(arr_three_d, "f", {1, 0}), + norm(x, "f", std::vector{1, 0}), array({22.24859546, 24.31049156, 26.43860813})) .item()); - CHECK( - allclose(norm(arr_three_d, "f", {1, 2}), array({14.28285686, 39.7617907})) - .item()); - CHECK( - allclose(norm(arr_three_d, "f", {2, 1}), array({14.28285686, 39.7617907})) - .item()); - CHECK(allclose(norm(arr_three_d, "inf", {0, 1}), array({36., 39., 42.})) + CHECK(allclose( + norm(x, "f", std::vector{1, 2}), + array({14.28285686, 39.7617907})) .item()); - CHECK(allclose(norm(arr_three_d, "inf", {1, 2}), array({21., 48.})) + CHECK(allclose( + norm(x, "f", std::vector{2, 1}), + array({14.28285686, 39.7617907})) .item()); - CHECK(allclose(norm(arr_three_d, "-inf", {0, 1}), array({9., 12., 15.})) - .item()); - CHECK(allclose(norm(arr_three_d, "-inf", {1, 2}), array({3., 30.})) - .item()); - CHECK(allclose(norm(arr_three_d, "-inf", {1, 0}), array({9., 11., 13.})) - .item()); - CHECK(allclose(norm(arr_three_d, "-inf", {2, 1}), array({9., 36.})) - .item()); -} \ No newline at end of file +}
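
For readers skimming the patch, here is a minimal C++ usage sketch of the linalg::norm overloads this change adds in mlx/linalg.h. The calls mirror ones already exercised in tests/linalg_tests.cpp; the explicit <int>/<float> template arguments, the main()/iostream scaffolding, and streaming an array to std::cout are illustrative assumptions, not part of the patch.

#include <iostream>
#include <vector>

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // 3x3 matrix [[0,1,2],[3,4,5],[6,7,8]], as used in the tests above.
  array x = reshape(arange(9), {3, 3});

  // No ord, no axis: 2-norm of the flattened array (Frobenius for a matrix).
  array frob = linalg::norm(x);

  // Vector 2-norms down each column (int-axis convenience overload).
  array col_norms = linalg::norm(x, 2.0, 0);

  // Matrix 1-norm (max absolute column sum) over an explicit axis pair.
  array one_norm = linalg::norm(x, 1.0, std::vector<int>{0, 1});

  // Frobenius norm requested by name; only defined for a pair of axes.
  array fro = linalg::norm(x, "fro", std::vector<int>{0, 1});

  std::cout << frob.item<float>() << " "      // ~14.2829
            << one_norm.item<float>() << " "  // 15
            << fro.item<float>() << "\n"      // ~14.2829
            << col_norms << "\n";             // [6.7082, 8.12404, 9.64365]
  return 0;
}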