mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-01 08:38:12 +08:00
@@ -92,6 +92,7 @@ void init_linalg(nb::module_& parent_module) {
|
||||
===== ============================ ==========================
|
||||
None Frobenius norm 2-norm
|
||||
'fro' Frobenius norm --
|
||||
'nuc' nuclear norm --
|
||||
inf max(sum(abs(x), axis=1)) max(abs(x))
|
||||
-inf min(sum(abs(x), axis=1)) min(abs(x))
|
||||
0 -- sum(x != 0)
|
||||
@@ -102,9 +103,6 @@ void init_linalg(nb::module_& parent_module) {
|
||||
other -- sum(abs(x)**ord)**(1./ord)
|
||||
===== ============================ ==========================
|
||||
|
||||
.. warning::
|
||||
Nuclear norm and norms based on singular values are not yet implemented.
|
||||
|
||||
The Frobenius norm is given by [1]_:
|
||||
|
||||
:math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
|
||||
@@ -206,15 +204,22 @@ void init_linalg(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"svd",
|
||||
[](const mx::array& a, mx::StreamOrDevice s /* = {} */) {
|
||||
const auto result = mx::linalg::svd(a, s);
|
||||
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
|
||||
[](const mx::array& a,
|
||||
bool compute_uv /* = true */,
|
||||
mx::StreamOrDevice s /* = {} */) -> nb::object {
|
||||
const auto result = mx::linalg::svd(a, compute_uv, s);
|
||||
if (result.size() == 1) {
|
||||
return nb::cast(result.at(0));
|
||||
} else {
|
||||
return nb::make_tuple(result.at(0), result.at(1), result.at(2));
|
||||
}
|
||||
},
|
||||
"a"_a,
|
||||
"compute_uv"_a = true,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
|
||||
"def svd(a: array, compute_uv: bool = True, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
|
||||
R"pbdoc(
|
||||
The Singular Value Decomposition (SVD) of the input matrix.
|
||||
|
||||
@@ -224,12 +229,15 @@ void init_linalg(nb::module_& parent_module) {
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
compute_uv (bool, optional): If ``True``, return the ``U``, ``S``, and ``Vt`` components.
|
||||
If ``False``, return only the ``S`` array. Default: ``True``.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||
in which case the default stream of the default device is used.
|
||||
|
||||
Returns:
|
||||
tuple(array, array, array): The ``U``, ``S``, and ``Vt`` matrices, such that
|
||||
``A = U @ diag(S) @ Vt``
|
||||
Union[tuple(array, ...), array]:
|
||||
If compute_uv is ``True`` returns the ``U``, ``S``, and ``Vt`` matrices, such that
|
||||
``A = U @ diag(S) @ Vt``. If compute_uv is ``False`` returns singular values array ``S``.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"inv",
|
||||
|
||||
Reference in New Issue
Block a user