mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
linalg.norm (#187)
* implemented vector_norm in cpp added linalg to mlx * implemented vector_norm python binding * renamed vector_norm to norm, implemented norm without provided ord * completed the implementation of the norm * added tests * removed unused import in linalg.cpp * updated python bindings * added some tests for python bindings * handling inf, -inf as numpy does, more extensive tests of compatibility with numpy * added better docs and examples * refactored mlx.linalg.norm bindings * reused existing util for implementation of linalg.norm * more tests * fixed a bug with no ord and axis provided * removed unused imports * some style and API consistency updates to linalg norm * remove unused includes * fix python tests * fixed a bug with frobenius norm of a complex-valued matrix * complex for vector too --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
447bc089b9
commit
6b0d30bb85
@ -57,6 +57,7 @@ are the CPU and GPU.
|
|||||||
python/random
|
python/random
|
||||||
python/transforms
|
python/transforms
|
||||||
python/fft
|
python/fft
|
||||||
|
python/linalg
|
||||||
python/nn
|
python/nn
|
||||||
python/optimizers
|
python/optimizers
|
||||||
python/tree_utils
|
python/tree_utils
|
||||||
|
11
docs/src/python/linalg.rst
Normal file
11
docs/src/python/linalg.rst
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
.. _linalg:
|
||||||
|
|
||||||
|
Linear Algebra
|
||||||
|
==============
|
||||||
|
|
||||||
|
.. currentmodule:: mlx.core.linalg
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: _autosummary
|
||||||
|
|
||||||
|
norm
|
@ -14,6 +14,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
|
${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h
|
||||||
)
|
)
|
||||||
|
|
||||||
|
175
mlx/linalg.cpp
Normal file
175
mlx/linalg.cpp
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#include <numeric>
|
||||||
|
#include <ostream>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mlx/dtype.h"
|
||||||
|
#include "mlx/linalg.h"
|
||||||
|
|
||||||
|
namespace mlx::core::linalg {
|
||||||
|
|
||||||
|
Dtype at_least_float(const Dtype& d) {
|
||||||
|
return is_floating_point(d) ? d : promote_types(d, float32);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline array l2_norm(
|
||||||
|
const array& a,
|
||||||
|
const std::vector<int>& axis,
|
||||||
|
bool keepdims,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
if (is_complex(a.dtype())) {
|
||||||
|
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
|
||||||
|
} else {
|
||||||
|
return sqrt(sum(square(a, s), axis, keepdims, s), s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline array vector_norm(
|
||||||
|
const array& a,
|
||||||
|
const double ord,
|
||||||
|
const std::vector<int>& axis,
|
||||||
|
bool keepdims,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
auto dtype = at_least_float(a.dtype());
|
||||||
|
if (ord == 0.0) {
|
||||||
|
return astype(sum(not_equal(a, array(0), s), axis, keepdims, s), dtype, s);
|
||||||
|
} else if (ord == 1.0) {
|
||||||
|
return astype(sum(abs(a, s), axis, keepdims, s), dtype, s);
|
||||||
|
} else if (ord == 2.0) {
|
||||||
|
return l2_norm(a, axis, keepdims, s);
|
||||||
|
} else if (ord == std::numeric_limits<double>::infinity()) {
|
||||||
|
return astype(max(abs(a, s), axis, keepdims, s), dtype, s);
|
||||||
|
} else if (ord == -std::numeric_limits<double>::infinity()) {
|
||||||
|
return astype(min(abs(a, s), axis, keepdims, s), dtype, s);
|
||||||
|
} else {
|
||||||
|
return power(
|
||||||
|
sum(power(abs(a, s), array(ord, dtype), s), axis, keepdims, s),
|
||||||
|
array(1.0 / ord, dtype),
|
||||||
|
s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline array matrix_norm(
|
||||||
|
const array& a,
|
||||||
|
const double ord,
|
||||||
|
const std::vector<int>& axis,
|
||||||
|
bool keepdims,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
auto dtype = at_least_float(a.dtype());
|
||||||
|
auto row_axis = axis[0];
|
||||||
|
auto col_axis = axis[1];
|
||||||
|
if (ord == -1.0) {
|
||||||
|
col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0);
|
||||||
|
return astype(
|
||||||
|
min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
|
||||||
|
dtype,
|
||||||
|
s);
|
||||||
|
} else if (ord == 1.0) {
|
||||||
|
col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0);
|
||||||
|
return astype(
|
||||||
|
max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
|
||||||
|
dtype,
|
||||||
|
s);
|
||||||
|
} else if (ord == std::numeric_limits<double>::infinity()) {
|
||||||
|
row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0);
|
||||||
|
return astype(
|
||||||
|
max(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s),
|
||||||
|
dtype,
|
||||||
|
s);
|
||||||
|
} else if (ord == -std::numeric_limits<double>::infinity()) {
|
||||||
|
row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0);
|
||||||
|
return astype(
|
||||||
|
min(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s),
|
||||||
|
dtype,
|
||||||
|
s);
|
||||||
|
} else if (ord == 2.0 || ord == -2.0) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[linalg::norm] Singular value norms are not implemented.");
|
||||||
|
} else {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[linalg::norm] Invalid ord " << ord << " for matrix norm.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline array matrix_norm(
|
||||||
|
const array& a,
|
||||||
|
const std::string& ord,
|
||||||
|
const std::vector<int>& axis,
|
||||||
|
bool keepdims,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
if (ord == "f" || ord == "fro") {
|
||||||
|
return l2_norm(a, axis, keepdims, s);
|
||||||
|
} else if (ord == "nuc") {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"[linalg::norm] Nuclear norm not yet implemented.");
|
||||||
|
} else {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
array norm(
|
||||||
|
const array& a,
|
||||||
|
const std::optional<std::vector<int>>& axis /* = std::nullopt */,
|
||||||
|
bool keepdims /* = false */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
if (!axis) {
|
||||||
|
return norm(flatten(a, s), std::vector<int>{0}, keepdims, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (axis.value().size() > 2) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[linalg::norm] Received too many axes for norm.");
|
||||||
|
}
|
||||||
|
return l2_norm(a, axis.value(), keepdims, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array norm(
|
||||||
|
const array& a,
|
||||||
|
const double ord,
|
||||||
|
const std::optional<std::vector<int>>& axis /* = std::nullopt */,
|
||||||
|
bool keepdims /* = false */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
std::vector<int> ax;
|
||||||
|
if (!axis) {
|
||||||
|
ax.resize(a.ndim());
|
||||||
|
std::iota(ax.begin(), ax.end(), 0);
|
||||||
|
} else {
|
||||||
|
ax = axis.value();
|
||||||
|
}
|
||||||
|
if (ax.size() == 1) {
|
||||||
|
return vector_norm(a, ord, ax, keepdims, s);
|
||||||
|
} else if (ax.size() == 2) {
|
||||||
|
return matrix_norm(a, ord, ax, keepdims, s);
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"[linalg::norm] Received too many axes for norm.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
array norm(
|
||||||
|
const array& a,
|
||||||
|
const std::string& ord,
|
||||||
|
const std::optional<std::vector<int>>& axis /* = std::nullopt */,
|
||||||
|
bool keepdims /* = false */,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
std::vector<int> ax;
|
||||||
|
if (!axis) {
|
||||||
|
ax.resize(a.ndim());
|
||||||
|
std::iota(ax.begin(), ax.end(), 0);
|
||||||
|
} else {
|
||||||
|
ax = axis.value();
|
||||||
|
}
|
||||||
|
if (ax.size() != 2) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "[linalg::norm] Norm '" << ord << "' only supported for matrices,"
|
||||||
|
<< " but received " << ax.size() << " axis/axes.";
|
||||||
|
throw std::invalid_argument(msg.str());
|
||||||
|
}
|
||||||
|
return matrix_norm(a, ord, ax, keepdims, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::linalg
|
63
mlx/linalg.h
Normal file
63
mlx/linalg.h
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
#include "mlx/array.h"
|
||||||
|
#include "mlx/device.h"
|
||||||
|
#include "mlx/ops.h"
|
||||||
|
#include "mlx/stream.h"
|
||||||
|
|
||||||
|
namespace mlx::core::linalg {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compute vector or matrix norms.
|
||||||
|
*
|
||||||
|
* - If axis and ord are both unspecified, computes the 2-norm of flatten(x).
|
||||||
|
* - If axis is not provided but ord is, then x must be either 1D or 2D.
|
||||||
|
* - If axis is provided, but ord is not, then the 2-norm (or Frobenius norm
|
||||||
|
* for matrices) is computed along the given axes. At most 2 axes can be
|
||||||
|
* specified.
|
||||||
|
* - If both axis and ord are provided, then the corresponding matrix or vector
|
||||||
|
* norm is computed. At most 2 axes can be specified.
|
||||||
|
*/
|
||||||
|
array norm(
|
||||||
|
const array& a,
|
||||||
|
const double ord,
|
||||||
|
const std::optional<std::vector<int>>& axis = std::nullopt,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
inline array norm(
|
||||||
|
const array& a,
|
||||||
|
const double ord,
|
||||||
|
int axis,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return norm(a, ord, std::vector<int>{axis}, keepdims, s);
|
||||||
|
}
|
||||||
|
array norm(
|
||||||
|
const array& a,
|
||||||
|
const std::string& ord,
|
||||||
|
const std::optional<std::vector<int>>& axis = std::nullopt,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
inline array norm(
|
||||||
|
const array& a,
|
||||||
|
const std::string& ord,
|
||||||
|
int axis,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {}) {
|
||||||
|
return norm(a, ord, std::vector<int>{axis}, keepdims, s);
|
||||||
|
}
|
||||||
|
array norm(
|
||||||
|
const array& a,
|
||||||
|
const std::optional<std::vector<int>>& axis = std::nullopt,
|
||||||
|
bool keepdims = false,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
inline array
|
||||||
|
norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
|
||||||
|
return norm(a, std::vector<int>{axis}, keepdims, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core::linalg
|
@ -6,6 +6,7 @@
|
|||||||
#include "mlx/backend/metal/metal.h"
|
#include "mlx/backend/metal/metal.h"
|
||||||
#include "mlx/device.h"
|
#include "mlx/device.h"
|
||||||
#include "mlx/fft.h"
|
#include "mlx/fft.h"
|
||||||
|
#include "mlx/linalg.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "mlx/random.h"
|
#include "mlx/random.h"
|
||||||
#include "mlx/stream.h"
|
#include "mlx/stream.h"
|
||||||
|
@ -11,6 +11,7 @@ pybind11_add_module(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
|
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
|
||||||
|
180
python/src/linalg.cpp
Normal file
180
python/src/linalg.cpp
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#include <variant>
|
||||||
|
|
||||||
|
#include <pybind11/pybind11.h>
|
||||||
|
#include <pybind11/stl.h>
|
||||||
|
|
||||||
|
#include "mlx/linalg.h"
|
||||||
|
|
||||||
|
#include "python/src/load.h"
|
||||||
|
#include "python/src/utils.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
using namespace py::literals;
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
using namespace mlx::core::linalg;
|
||||||
|
|
||||||
|
void init_linalg(py::module_& parent_module) {
|
||||||
|
py::options options;
|
||||||
|
options.disable_function_signatures();
|
||||||
|
|
||||||
|
auto m = parent_module.def_submodule(
|
||||||
|
"linalg", "mlx.core.linalg: linear algebra routines.");
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"norm",
|
||||||
|
[](const array& a,
|
||||||
|
const std::variant<std::monostate, int, double, std::string>& ord_,
|
||||||
|
const std::variant<std::monostate, int, std::vector<int>>& axis_,
|
||||||
|
const bool keepdims,
|
||||||
|
const StreamOrDevice stream) {
|
||||||
|
std::optional<std::vector<int>> axis = std::nullopt;
|
||||||
|
if (auto pv = std::get_if<int>(&axis_); pv) {
|
||||||
|
axis = std::vector<int>{*pv};
|
||||||
|
} else if (auto pv = std::get_if<std::vector<int>>(&axis_); pv) {
|
||||||
|
axis = *pv;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (std::holds_alternative<std::monostate>(ord_)) {
|
||||||
|
return norm(a, axis, keepdims, stream);
|
||||||
|
} else {
|
||||||
|
if (auto pv = std::get_if<std::string>(&ord_); pv) {
|
||||||
|
return norm(a, *pv, axis, keepdims, stream);
|
||||||
|
}
|
||||||
|
double ord;
|
||||||
|
if (auto pv = std::get_if<int>(&ord_); pv) {
|
||||||
|
ord = *pv;
|
||||||
|
} else {
|
||||||
|
ord = std::get<double>(ord_);
|
||||||
|
}
|
||||||
|
return norm(a, ord, axis, keepdims, stream);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"a"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
"ord"_a = none,
|
||||||
|
"axis"_a = none,
|
||||||
|
"keepdims"_a = false,
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
|
Matrix or vector norm.
|
||||||
|
|
||||||
|
This function computes vector or matrix norms depending on the value of
|
||||||
|
the ``ord`` and ``axis`` parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array. If ``axis`` is ``None``, ``a`` must be 1-D or 2-D,
|
||||||
|
unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the
|
||||||
|
2-norm of ``a.flatten`` will be returned.
|
||||||
|
ord (scalar or str, optional): Order of the norm (see table under ``Notes``).
|
||||||
|
If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed
|
||||||
|
along the given ``axis``. Default: ``None``.
|
||||||
|
axis (int or list(int), optional): If ``axis`` is an integer, it specifies the
|
||||||
|
axis of ``a`` along which to compute the vector norms. If ``axis`` is a
|
||||||
|
2-tuple, it specifies the axes that hold 2-D matrices, and the matrix
|
||||||
|
norms of these matrices are computed. If `axis` is ``None`` then
|
||||||
|
either a vector norm (when ``a`` is 1-D) or a matrix norm (when ``a`` is
|
||||||
|
2-D) is returned. Default: ``None``.
|
||||||
|
keepdims (bool, optional): If ``True``, the axes which are normed over are
|
||||||
|
left in the result as dimensions with size one. Default ``False``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The output containing the norm(s).
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
For values of ``ord < 1``, the result is, strictly speaking, not a
|
||||||
|
mathematical norm, but it may still be useful for various numerical
|
||||||
|
purposes.
|
||||||
|
|
||||||
|
The following norms can be calculated:
|
||||||
|
|
||||||
|
===== ============================ ==========================
|
||||||
|
ord norm for matrices norm for vectors
|
||||||
|
===== ============================ ==========================
|
||||||
|
None Frobenius norm 2-norm
|
||||||
|
'fro' Frobenius norm --
|
||||||
|
inf max(sum(abs(x), axis=1)) max(abs(x))
|
||||||
|
-inf min(sum(abs(x), axis=1)) min(abs(x))
|
||||||
|
0 -- sum(x != 0)
|
||||||
|
1 max(sum(abs(x), axis=0)) as below
|
||||||
|
-1 min(sum(abs(x), axis=0)) as below
|
||||||
|
2 2-norm (largest sing. value) as below
|
||||||
|
-2 smallest singular value as below
|
||||||
|
other -- sum(abs(x)**ord)**(1./ord)
|
||||||
|
===== ============================ ==========================
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
Nuclear norm and norms based on singular values are not yet implemented.
|
||||||
|
|
||||||
|
The Frobenius norm is given by [1]_:
|
||||||
|
|
||||||
|
:math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
|
||||||
|
|
||||||
|
The nuclear norm is the sum of the singular values.
|
||||||
|
|
||||||
|
Both the Frobenius and nuclear norm orders are only defined for
|
||||||
|
matrices and raise a ``ValueError`` when ``a.ndim != 2``.
|
||||||
|
|
||||||
|
References:
|
||||||
|
.. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
|
||||||
|
Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import mlx.core as mx
|
||||||
|
>>> from mlx.core import linalg as la
|
||||||
|
>>> a = mx.arange(9) - 4
|
||||||
|
>>> a
|
||||||
|
array([-4, -3, -2, ..., 2, 3, 4], dtype=int32)
|
||||||
|
>>> b = a.reshape((3,3))
|
||||||
|
>>> b
|
||||||
|
array([[-4, -3, -2],
|
||||||
|
[-1, 0, 1],
|
||||||
|
[ 2, 3, 4]], dtype=int32)
|
||||||
|
>>> la.norm(a)
|
||||||
|
array(7.74597, dtype=float32)
|
||||||
|
>>> la.norm(b)
|
||||||
|
array(7.74597, dtype=float32)
|
||||||
|
>>> la.norm(b, 'fro')
|
||||||
|
array(7.74597, dtype=float32)
|
||||||
|
>>> la.norm(a, float("inf"))
|
||||||
|
array(4, dtype=float32)
|
||||||
|
>>> la.norm(b, float("inf"))
|
||||||
|
array(9, dtype=float32)
|
||||||
|
>>> la.norm(a, -float("inf"))
|
||||||
|
array(0, dtype=float32)
|
||||||
|
>>> la.norm(b, -float("inf"))
|
||||||
|
array(2, dtype=float32)
|
||||||
|
>>> la.norm(a, 1)
|
||||||
|
array(20, dtype=float32)
|
||||||
|
>>> la.norm(b, 1)
|
||||||
|
array(7, dtype=float32)
|
||||||
|
>>> la.norm(a, -1)
|
||||||
|
array(0, dtype=float32)
|
||||||
|
>>> la.norm(b, -1)
|
||||||
|
array(6, dtype=float32)
|
||||||
|
>>> la.norm(a, 2)
|
||||||
|
array(7.74597, dtype=float32)
|
||||||
|
>>> la.norm(a, 3)
|
||||||
|
array(5.84804, dtype=float32)
|
||||||
|
>>> la.norm(a, -3)
|
||||||
|
array(0, dtype=float32)
|
||||||
|
>>> c = mx.array([[ 1, 2, 3],
|
||||||
|
... [-1, 1, 4]])
|
||||||
|
>>> la.norm(c, axis=0)
|
||||||
|
array([1.41421, 2.23607, 5], dtype=float32)
|
||||||
|
>>> la.norm(c, axis=1)
|
||||||
|
array([3.74166, 4.24264], dtype=float32)
|
||||||
|
>>> la.norm(c, ord=1, axis=1)
|
||||||
|
array([6, 6], dtype=float32)
|
||||||
|
>>> m = mx.arange(8).reshape(2,2,2)
|
||||||
|
>>> la.norm(m, axis=(1,2))
|
||||||
|
array([3.74166, 11.225], dtype=float32)
|
||||||
|
>>> la.norm(m[0, :, :]), LA.norm(m[1, :, :])
|
||||||
|
(array(3.74166, dtype=float32), array(11.225, dtype=float32))
|
||||||
|
)pbdoc");
|
||||||
|
}
|
@ -15,6 +15,7 @@ void init_ops(py::module_&);
|
|||||||
void init_transforms(py::module_&);
|
void init_transforms(py::module_&);
|
||||||
void init_random(py::module_&);
|
void init_random(py::module_&);
|
||||||
void init_fft(py::module_&);
|
void init_fft(py::module_&);
|
||||||
|
void init_linalg(py::module_&);
|
||||||
|
|
||||||
PYBIND11_MODULE(core, m) {
|
PYBIND11_MODULE(core, m) {
|
||||||
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
|
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
|
||||||
@ -29,5 +30,6 @@ PYBIND11_MODULE(core, m) {
|
|||||||
init_transforms(m);
|
init_transforms(m);
|
||||||
init_random(m);
|
init_random(m);
|
||||||
init_fft(m);
|
init_fft(m);
|
||||||
|
init_linalg(m);
|
||||||
m.attr("__version__") = TOSTRING(_VERSION_);
|
m.attr("__version__") = TOSTRING(_VERSION_);
|
||||||
}
|
}
|
||||||
|
94
python/tests/test_linalg.py
Normal file
94
python/tests/test_linalg.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
# Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
import math
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx_tests
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class TestLinalg(mlx_tests.MLXTestCase):
|
||||||
|
def test_norm(self):
|
||||||
|
vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")]
|
||||||
|
matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]
|
||||||
|
|
||||||
|
for shape in [(3,), (2, 3), (2, 3, 3)]:
|
||||||
|
x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||||
|
x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||||
|
# Test when at least one axis is provided
|
||||||
|
for num_axes in range(1, len(shape)):
|
||||||
|
if num_axes == 1:
|
||||||
|
ords = vector_ords
|
||||||
|
else:
|
||||||
|
ords = matrix_ords
|
||||||
|
for axis in itertools.combinations(range(len(shape)), num_axes):
|
||||||
|
for keepdims in [True, False]:
|
||||||
|
for o in ords:
|
||||||
|
out_np = np.linalg.norm(
|
||||||
|
x_np, ord=o, axis=axis, keepdims=keepdims
|
||||||
|
)
|
||||||
|
out_mx = mx.linalg.norm(
|
||||||
|
x_mx, ord=o, axis=axis, keepdims=keepdims
|
||||||
|
)
|
||||||
|
with self.subTest(
|
||||||
|
shape=shape, ord=o, axis=axis, keepdims=keepdims
|
||||||
|
):
|
||||||
|
self.assertTrue(
|
||||||
|
np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test only ord provided
|
||||||
|
for shape in [(3,), (2, 3)]:
|
||||||
|
x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||||
|
x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||||
|
for o in [None, 1, -1, float("inf"), -float("inf")]:
|
||||||
|
for keepdims in [True, False]:
|
||||||
|
out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims)
|
||||||
|
out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims)
|
||||||
|
with self.subTest(shape=shape, ord=o, keepdims=keepdims):
|
||||||
|
self.assertTrue(
|
||||||
|
np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test no ord and no axis provided
|
||||||
|
for shape in [(3,), (2, 3), (2, 3, 3)]:
|
||||||
|
x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||||
|
x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||||
|
for keepdims in [True, False]:
|
||||||
|
out_np = np.linalg.norm(x_np, keepdims=keepdims)
|
||||||
|
out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)
|
||||||
|
with self.subTest(shape=shape, keepdims=keepdims):
|
||||||
|
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
|
||||||
|
|
||||||
|
def test_complex_norm(self):
|
||||||
|
for shape in [(3,), (2, 3), (2, 3, 3)]:
|
||||||
|
x_np = np.random.uniform(size=shape).astype(
|
||||||
|
np.float32
|
||||||
|
) + 1j * np.random.uniform(size=shape).astype(np.float32)
|
||||||
|
x_mx = mx.array(x_np)
|
||||||
|
out_np = np.linalg.norm(x_np)
|
||||||
|
out_mx = mx.linalg.norm(x_mx)
|
||||||
|
with self.subTest(shape=shape):
|
||||||
|
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
|
||||||
|
for num_axes in range(1, len(shape)):
|
||||||
|
for axis in itertools.combinations(range(len(shape)), num_axes):
|
||||||
|
out_np = np.linalg.norm(x_np, axis=axis)
|
||||||
|
out_mx = mx.linalg.norm(x_mx, axis=axis)
|
||||||
|
with self.subTest(shape=shape, axis=axis):
|
||||||
|
self.assertTrue(
|
||||||
|
np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
||||||
|
)
|
||||||
|
|
||||||
|
x_np = np.random.uniform(size=(4, 4)).astype(
|
||||||
|
np.float32
|
||||||
|
) + 1j * np.random.uniform(size=(4, 4)).astype(np.float32)
|
||||||
|
x_mx = mx.array(x_np)
|
||||||
|
out_np = np.linalg.norm(x_np, ord="fro")
|
||||||
|
out_mx = mx.linalg.norm(x_mx, ord="fro")
|
||||||
|
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
@ -31,6 +31,7 @@ target_sources(tests PRIVATE
|
|||||||
scheduler_tests.cpp
|
scheduler_tests.cpp
|
||||||
utils_tests.cpp
|
utils_tests.cpp
|
||||||
vmap_tests.cpp
|
vmap_tests.cpp
|
||||||
|
linalg_tests.cpp
|
||||||
${METAL_TEST_SOURCES}
|
${METAL_TEST_SOURCES}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
250
tests/linalg_tests.cpp
Normal file
250
tests/linalg_tests.cpp
Normal file
@ -0,0 +1,250 @@
|
|||||||
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
|
#include "doctest/doctest.h"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
|
using namespace mlx::core;
|
||||||
|
using namespace mlx::core::linalg;
|
||||||
|
|
||||||
|
TEST_CASE("[mlx.core.linalg.norm] no ord") {
|
||||||
|
// Zero dimensions
|
||||||
|
array x(2.0);
|
||||||
|
CHECK_EQ(norm(x).item<float>(), 2.0f);
|
||||||
|
CHECK_THROWS(norm(x, 0));
|
||||||
|
|
||||||
|
x = array({1, 2, 3});
|
||||||
|
float expected = std::sqrt(1 + 4 + 9);
|
||||||
|
CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
|
||||||
|
CHECK_EQ(norm(x, 0, false).item<float>(), doctest::Approx(expected));
|
||||||
|
CHECK_EQ(norm(x, -1, false).item<float>(), doctest::Approx(expected));
|
||||||
|
CHECK_EQ(norm(x, -1, true).ndim(), 1);
|
||||||
|
CHECK_THROWS(norm(x, 1));
|
||||||
|
|
||||||
|
x = reshape(arange(9), {3, 3});
|
||||||
|
expected =
|
||||||
|
std::sqrt(0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8);
|
||||||
|
|
||||||
|
CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, std::vector<int>{0, 1}).item<float>(), doctest::Approx(expected));
|
||||||
|
CHECK(array_equal(
|
||||||
|
norm(x, 0, false),
|
||||||
|
array(
|
||||||
|
{std::sqrt(0 + 3 * 3 + 6 * 6),
|
||||||
|
std::sqrt(1 + 4 * 4 + 7 * 7),
|
||||||
|
std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK(allclose(
|
||||||
|
norm(x, 1, false),
|
||||||
|
array(
|
||||||
|
{std::sqrt(0 + 1 + 2 * 2),
|
||||||
|
std::sqrt(3 * 3 + 4 * 4 + 5 * 5),
|
||||||
|
std::sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
|
||||||
|
.item<bool>());
|
||||||
|
|
||||||
|
x = reshape(arange(18), {2, 3, 3});
|
||||||
|
CHECK(allclose(
|
||||||
|
norm(x, 2, false),
|
||||||
|
array(
|
||||||
|
{
|
||||||
|
std::sqrt(0 + 1 + 2 * 2),
|
||||||
|
std::sqrt(3 * 3 + 4 * 4 + 5 * 5),
|
||||||
|
std::sqrt(6 * 6 + 7 * 7 + 8 * 8),
|
||||||
|
std::sqrt(9 * 9 + 10 * 10 + 11 * 11),
|
||||||
|
std::sqrt(12 * 12 + 13 * 13 + 14 * 14),
|
||||||
|
std::sqrt(15 * 15 + 16 * 16 + 17 * 17),
|
||||||
|
},
|
||||||
|
{2, 3}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK(allclose(
|
||||||
|
norm(x, std::vector<int>{1, 2}, false),
|
||||||
|
array(
|
||||||
|
{std::sqrt(
|
||||||
|
0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 +
|
||||||
|
8 * 8),
|
||||||
|
std::sqrt(
|
||||||
|
9 * 9 + 10 * 10 + 11 * 11 + 12 * 12 + 13 * 13 + 14 * 14 +
|
||||||
|
15 * 15 + 16 * 16 + 17 * 17)},
|
||||||
|
{2}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK_THROWS(norm(x, std::vector<int>{0, 1, 2}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("[mlx.core.linalg.norm] double ord") {
|
||||||
|
CHECK_THROWS(norm(array(0), 2.0));
|
||||||
|
|
||||||
|
array x({1, 2, 3});
|
||||||
|
|
||||||
|
float expected = std::sqrt(1 + 4 + 9);
|
||||||
|
CHECK_EQ(norm(x, 2.0).item<float>(), doctest::Approx(expected));
|
||||||
|
CHECK_EQ(norm(x, 2.0, 0).item<float>(), doctest::Approx(expected));
|
||||||
|
CHECK_THROWS(norm(x, 2.0, 1));
|
||||||
|
|
||||||
|
expected = 1 + 2 + 3;
|
||||||
|
CHECK_EQ(norm(x, 1.0).item<float>(), doctest::Approx(expected));
|
||||||
|
|
||||||
|
expected = 3;
|
||||||
|
CHECK_EQ(norm(x, 0.0).item<float>(), doctest::Approx(expected));
|
||||||
|
|
||||||
|
expected = 3;
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, std::numeric_limits<double>::infinity()).item<float>(),
|
||||||
|
doctest::Approx(expected));
|
||||||
|
|
||||||
|
expected = 1;
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, -std::numeric_limits<double>::infinity()).item<float>(),
|
||||||
|
doctest::Approx(expected));
|
||||||
|
|
||||||
|
x = reshape(arange(9), {3, 3});
|
||||||
|
|
||||||
|
CHECK(allclose(
|
||||||
|
norm(x, 2.0, 0, false),
|
||||||
|
array(
|
||||||
|
{std::sqrt(0 + 3 * 3 + 6 * 6),
|
||||||
|
std::sqrt(1 + 4 * 4 + 7 * 7),
|
||||||
|
std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK(allclose(
|
||||||
|
norm(x, 2.0, 1, false),
|
||||||
|
array(
|
||||||
|
{sqrt(0 + 1 + 2 * 2),
|
||||||
|
sqrt(3 * 3 + 4 * 4 + 5 * 5),
|
||||||
|
sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
|
||||||
|
.item<bool>());
|
||||||
|
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, 1.0, std::vector<int>{0, 1}).item<float>(),
|
||||||
|
doctest::Approx(15.0));
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, 1.0, std::vector<int>{1, 0}).item<float>(),
|
||||||
|
doctest::Approx(21.0));
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, -1.0, std::vector<int>{0, 1}).item<float>(),
|
||||||
|
doctest::Approx(9.0));
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, -1.0, std::vector<int>{1, 0}).item<float>(),
|
||||||
|
doctest::Approx(3.0));
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, 1.0, std::vector<int>{0, 1}, true).shape(),
|
||||||
|
std::vector<int>{1, 1});
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, 1.0, std::vector<int>{1, 0}, true).shape(),
|
||||||
|
std::vector<int>{1, 1});
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, -1.0, std::vector<int>{0, 1}, true).shape(),
|
||||||
|
std::vector<int>{1, 1});
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, -1.0, std::vector<int>{1, 0}, true).shape(),
|
||||||
|
std::vector<int>{1, 1});
|
||||||
|
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, -1.0, std::vector<int>{-2, -1}, false).item<float>(),
|
||||||
|
doctest::Approx(9.0));
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, 1.0, std::vector<int>{-2, -1}, false).item<float>(),
|
||||||
|
doctest::Approx(15.0));
|
||||||
|
|
||||||
|
x = reshape(arange(18), {2, 3, 3});
|
||||||
|
CHECK_THROWS(norm(x, 2.0, std::vector{0, 1, 2}));
|
||||||
|
CHECK(allclose(
|
||||||
|
norm(x, 3.0, 0),
|
||||||
|
array(
|
||||||
|
{9.,
|
||||||
|
10.00333222,
|
||||||
|
11.02199456,
|
||||||
|
12.06217728,
|
||||||
|
13.12502645,
|
||||||
|
14.2094363,
|
||||||
|
15.31340617,
|
||||||
|
16.43469751,
|
||||||
|
17.57113899},
|
||||||
|
{3, 3}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK(allclose(
|
||||||
|
norm(x, 3.0, 2),
|
||||||
|
array(
|
||||||
|
{2.08008382,
|
||||||
|
6.,
|
||||||
|
10.23127655,
|
||||||
|
14.5180117,
|
||||||
|
18.82291607,
|
||||||
|
23.13593104},
|
||||||
|
{2, 3}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK(
|
||||||
|
allclose(
|
||||||
|
norm(x, 0.0, 0), array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK(allclose(norm(x, 0.0, 1), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK(allclose(norm(x, 0.0, 2), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK(allclose(
|
||||||
|
norm(x, 1.0, 0),
|
||||||
|
array({9., 11., 13., 15., 17., 19., 21., 23., 25.}, {3, 3}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK(allclose(norm(x, 1.0, 1), array({9., 12., 15., 36., 39., 42.}, {2, 3}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK(allclose(norm(x, 1.0, 2), array({3., 12., 21., 30., 39., 48.}, {2, 3}))
|
||||||
|
.item<bool>());
|
||||||
|
|
||||||
|
CHECK(allclose(norm(x, 1.0, std::vector<int>{0, 1}), array({21., 23., 25.}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK(allclose(norm(x, 1.0, std::vector<int>{1, 2}), array({15., 42.}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK(allclose(norm(x, -1.0, std::vector<int>{0, 1}), array({9., 11., 13.}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9., 36.}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 0}), array({9., 12., 15.}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK(allclose(norm(x, -1.0, std::vector<int>{2, 1}), array({3, 30}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9, 36}))
|
||||||
|
.item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("[mlx.core.linalg.norm] string ord") {
|
||||||
|
array x({1, 2, 3});
|
||||||
|
CHECK_THROWS(norm(x, "fro"));
|
||||||
|
|
||||||
|
x = reshape(arange(9), {3, 3});
|
||||||
|
CHECK_THROWS(norm(x, "bad ord"));
|
||||||
|
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, "f", std::vector<int>{0, 1}).item<float>(),
|
||||||
|
doctest::Approx(14.2828568570857));
|
||||||
|
CHECK_EQ(
|
||||||
|
norm(x, "fro", std::vector<int>{0, 1}).item<float>(),
|
||||||
|
doctest::Approx(14.2828568570857));
|
||||||
|
|
||||||
|
x = reshape(arange(18), {2, 3, 3});
|
||||||
|
CHECK(allclose(
|
||||||
|
norm(x, "fro", std::vector<int>{0, 1}),
|
||||||
|
array({22.24859546, 24.31049156, 26.43860813}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK(allclose(
|
||||||
|
norm(x, "fro", std::vector<int>{1, 2}),
|
||||||
|
array({14.28285686, 39.7617907}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK(allclose(
|
||||||
|
norm(x, "f", std::vector<int>{0, 1}),
|
||||||
|
array({22.24859546, 24.31049156, 26.43860813}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK(allclose(
|
||||||
|
norm(x, "f", std::vector<int>{1, 0}),
|
||||||
|
array({22.24859546, 24.31049156, 26.43860813}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK(allclose(
|
||||||
|
norm(x, "f", std::vector<int>{1, 2}),
|
||||||
|
array({14.28285686, 39.7617907}))
|
||||||
|
.item<bool>());
|
||||||
|
CHECK(allclose(
|
||||||
|
norm(x, "f", std::vector<int>{2, 1}),
|
||||||
|
array({14.28285686, 39.7617907}))
|
||||||
|
.item<bool>());
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user