mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-24 09:21:16 +08:00)
parent 9680f72cca
commit 3835a428c5
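This commit threads a new compute_uv flag through the SVD stack, from the LAPACK-backed CPU kernel up to the Python bindings, and uses it to implement the singular-value matrix norms (ord=2, ord=-2, and "nuc"). A minimal usage sketch of the API introduced here (the matrix values are illustrative, not from the commit):

import mlx.core as mx

A = mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Full decomposition: U is (m, m), S is (k,), Vt is (n, n), with k = min(m, n).
U, S, Vt = mx.linalg.svd(A, compute_uv=True, stream=mx.cpu)

# Singular values only: skips computing and allocating U and Vt.
S_only = mx.linalg.svd(A, compute_uv=False, stream=mx.cpu)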
@@ -8,7 +8,7 @@
 namespace mlx::core {
 
 template <typename T>
-void svd_impl(const array& a, array& u, array& s, array& vt) {
+void svd_impl(const array& a, T* u_data, T* s_data, T* vt_data) {
   // Lapack uses the column-major convention. To avoid having to transpose
   // the input and then transpose the outputs, we swap the indices/sizes of the
   // matrices and take advantage of the following identity (see
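The identity the comment above relies on: if A = U S Vᵀ, then Aᵀ = V S Uᵀ, so running LAPACK on the row-major buffer reinterpreted as column-major (effectively on Aᵀ) leaves V in the slot where U would go and Uᵀ where Vᵀ would go, with the singular values unchanged. A quick numerical check of that identity from the Python side (a sketch, not part of this change):

import mlx.core as mx

A = mx.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
U, S, Vt = mx.linalg.svd(A, stream=mx.cpu)
Ut, St, Vtt = mx.linalg.svd(A.T, stream=mx.cpu)
# The singular values of A and A.T are identical.
print(mx.allclose(S, St, atol=1e-6))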
@@ -35,13 +35,8 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
   array in(a.shape(), a.dtype(), nullptr, {});
   copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
 
-  // Allocate outputs.
-  u.set_data(allocator::malloc_or_wait(u.nbytes()));
-  s.set_data(allocator::malloc_or_wait(s.nbytes()));
-  vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
-
-  static constexpr auto job_u = "V";
-  static constexpr auto job_vt = "V";
+  auto job_u = (u_data && vt_data) ? "V" : "N";
+  auto job_vt = (u_data && vt_data) ? "V" : "N";
   static constexpr auto range = "A";
 
   // Will contain the number of singular values after the call has returned.
@@ -56,6 +51,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
 
   static const int ignored_int = 0;
   static const T ignored_float = 0;
+  static T ignored_output = 0;
 
   int info;
 
@@ -109,12 +105,12 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
       /* il = */ &ignored_int,
       /* iu = */ &ignored_int,
       /* ns = */ &ns,
-      /* s = */ s.data<T>() + K * i,
+      /* s = */ s_data + K * i,
       // According to the identity above, lapack will write Vᵀᵀ as U.
-      /* u = */ vt.data<T>() + N * N * i,
+      /* u = */ vt_data ? vt_data + N * N * i : &ignored_output,
       /* ldu = */ &ldu,
       // According to the identity above, lapack will write Uᵀ as Vᵀ.
-      /* vt = */ u.data<T>() + M * M * i,
+      /* vt = */ u_data ? u_data + M * M * i : &ignored_output,
       /* ldvt = */ &ldvt,
       /* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
       /* lwork = */ &lwork,
@@ -136,15 +132,36 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
   }
 }
 
+template <typename T>
+void compute_svd(const array& a, bool compute_uv, std::vector<array>& outputs) {
+  if (compute_uv) {
+    array& u = outputs[0];
+    array& s = outputs[1];
+    array& vt = outputs[2];
+
+    u.set_data(allocator::malloc_or_wait(u.nbytes()));
+    s.set_data(allocator::malloc_or_wait(s.nbytes()));
+    vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
+
+    svd_impl<T>(a, u.data<T>(), s.data<T>(), vt.data<T>());
+  } else {
+    array& s = outputs[0];
+
+    s.set_data(allocator::malloc_or_wait(s.nbytes()));
+
+    svd_impl<T>(a, nullptr, s.data<T>(), nullptr);
+  }
+}
+
 void SVD::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
   switch (inputs[0].dtype()) {
     case float32:
-      svd_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2]);
+      compute_svd<float>(inputs[0], compute_uv_, outputs);
      break;
    case float64:
-      svd_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2]);
+      compute_svd<double>(inputs[0], compute_uv_, outputs);
      break;
    default:
      throw std::runtime_error(
@@ -102,8 +102,21 @@ inline array matrix_norm(
         dtype,
         s);
   } else if (ord == 2.0 || ord == -2.0) {
-    throw std::runtime_error(
-        "[linalg::norm] Singular value norms are not implemented.");
+    row_axis = (axis[0] < 0) ? axis[0] + a.ndim() : axis[0];
+    col_axis = (axis[1] < 0) ? axis[1] + a.ndim() : axis[1];
+    auto a_matrix = (row_axis > col_axis)
+        ? moveaxis(moveaxis(a, row_axis, -1, s), col_axis, -1, s)
+        : moveaxis(moveaxis(a, col_axis, -1, s), row_axis, -2, s);
+    a_matrix = svd(a_matrix, false, s).at(0);
+    a_matrix = (ord == 2.0) ? max(a_matrix, -1, false, s)
+                            : min(a_matrix, -1, false, s);
+    if (keepdims) {
+      std::vector<int> sorted_axes = (row_axis < col_axis)
+          ? std::vector<int>{row_axis, col_axis}
+          : std::vector<int>{col_axis, row_axis};
+      a_matrix = expand_dims(a_matrix, sorted_axes, s);
+    }
+    return astype(a_matrix, dtype, s);
   } else {
     std::ostringstream msg;
     msg << "[linalg::norm] Invalid ord " << ord << " for matrix norm.";
@@ -120,8 +133,19 @@ inline array matrix_norm(
   if (ord == "f" || ord == "fro") {
     return l2_norm(a, axis, keepdims, s);
   } else if (ord == "nuc") {
-    throw std::runtime_error(
-        "[linalg::norm] Nuclear norm not yet implemented.");
+    int row_axis = (axis[0] < 0) ? axis[0] + a.ndim() : axis[0];
+    int col_axis = (axis[1] < 0) ? axis[1] + a.ndim() : axis[1];
+    auto a_matrix = (row_axis > col_axis)
+        ? moveaxis(moveaxis(a, row_axis, -1, s), col_axis, -1, s)
+        : moveaxis(moveaxis(a, col_axis, -1, s), row_axis, -2, s);
+    a_matrix = sum(svd(a_matrix, false, s).at(0), -1, false, s);
+    if (keepdims) {
+      std::vector<int> sorted_axes = (row_axis < col_axis)
+          ? std::vector<int>{row_axis, col_axis}
+          : std::vector<int>{col_axis, row_axis};
+      a_matrix = expand_dims(a_matrix, sorted_axes, s);
+    }
+    return a_matrix;
   } else {
     std::ostringstream msg;
     msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm.";
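Both new branches reduce the matrix norm to the singular values obtained from svd(a_matrix, false, s): ord=2 and ord=-2 take the largest and smallest singular value, and "nuc" sums them. The same relationships, stated through the Python API this commit exposes (a sketch; the matrix is illustrative):

import mlx.core as mx

A = mx.array([[1.0, 2.0], [3.0, 4.0]])
S = mx.linalg.svd(A, compute_uv=False, stream=mx.cpu)
print(mx.allclose(mx.linalg.norm(A, 2.0, stream=mx.cpu), mx.max(S)))   # spectral norm
print(mx.allclose(mx.linalg.norm(A, -2.0, stream=mx.cpu), mx.min(S)))  # smallest singular value
print(mx.allclose(mx.linalg.norm(A, "nuc", stream=mx.cpu), mx.sum(S))) # nuclear norm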
@@ -214,7 +238,8 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
   return std::make_pair(out[0], out[1]);
 }
 
-std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
+std::vector<array>
+svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
   check_cpu_stream(s, "[linalg::svd]");
   check_float(a.dtype(), "[linalg::svd]");
 
@@ -230,14 +255,22 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
   const auto n = a.shape(-1);
   const auto rank = a.ndim();
 
-  auto u_shape = a.shape();
-  u_shape[rank - 2] = m;
-  u_shape[rank - 1] = m;
-
   auto s_shape = a.shape();
   s_shape.pop_back();
   s_shape[rank - 2] = std::min(m, n);
 
+  if (!compute_uv) {
+    return {array(
+        std::move(s_shape),
+        std::move(a.dtype()),
+        std::make_shared<SVD>(to_stream(s), compute_uv),
+        {a})};
+  }
+
+  auto u_shape = a.shape();
+  u_shape[rank - 2] = m;
+  u_shape[rank - 1] = m;
+
   auto vt_shape = a.shape();
   vt_shape[rank - 2] = n;
   vt_shape[rank - 1] = n;
@@ -245,7 +278,7 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
   return array::make_arrays(
       {u_shape, s_shape, vt_shape},
       {a.dtype(), a.dtype(), a.dtype()},
-      std::make_shared<SVD>(to_stream(s)),
+      std::make_shared<SVD>(to_stream(s), compute_uv),
      {a});
 }
 
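The shape bookkeeping above, summarized: S drops the last axis and keeps min(m, n) per matrix, while U and Vt are square, and the compute_uv=false path builds only the S array. A sketch for a batched input (the shapes follow from the code; the input itself is illustrative):

import mlx.core as mx

a = mx.zeros((7, 5, 3))  # a batch of 7 matrices with m = 5, n = 3
U, S, Vt = mx.linalg.svd(a, compute_uv=True, stream=mx.cpu)
print(U.shape, S.shape, Vt.shape)  # (7, 5, 5) (7, 3) (7, 3, 3)

S_only = mx.linalg.svd(a, compute_uv=False, stream=mx.cpu)
print(S_only.shape)  # (7, 3)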
@@ -323,7 +356,7 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) {
   int m = a.shape(-2);
   int n = a.shape(-1);
   int k = std::min(m, n);
-  auto outs = linalg::svd(a, s);
+  auto outs = linalg::svd(a, true, s);
   array U = outs[0];
   array S = outs[1];
   array V = outs[2];
@@ -62,7 +62,11 @@ norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
 
 std::pair<array, array> qr(const array& a, StreamOrDevice s = {});
 
-std::vector<array> svd(const array& a, StreamOrDevice s = {});
+std::vector<array>
+svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */);
+inline std::vector<array> svd(const array& a, StreamOrDevice s = {}) {
+  return svd(a, true, s);
+}
 
 array inv(const array& a, StreamOrDevice s = {});
 
@@ -4940,7 +4940,8 @@ std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
     const std::vector<int>& axes) {
   auto ax = axes[0] >= 0 ? 0 : -1;
   auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
-  return {{linalg::svd(a, stream())}, {ax, ax, ax}};
+  std::vector<int> new_axes(compute_uv_ ? 3 : 1, ax);
+  return {linalg::svd(a, compute_uv_, stream()), std::move(new_axes)};
 }
 
 std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(
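Because the primitive now records compute_uv_, vmap can emit the matching number of mapped outputs: three batched arrays for the full decomposition, or one batched array of singular values. From Python this looks like the following (a sketch mirroring the vmap tests later in this commit):

import mlx.core as mx

a = mx.random.uniform(shape=(3, 4, 2))
sv = mx.vmap(lambda x: mx.linalg.svd(x, compute_uv=False, stream=mx.cpu))(a)
print(sv.shape)  # (3, 2): one vector of singular values per stacked matrix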
@@ -2287,7 +2287,8 @@ class QRF : public Primitive {
 /* SVD primitive. */
 class SVD : public Primitive {
  public:
-  explicit SVD(Stream stream) : Primitive(stream) {}
+  explicit SVD(Stream stream, bool compute_uv)
+      : Primitive(stream), compute_uv_(compute_uv) {}
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
@@ -2296,6 +2297,12 @@ class SVD : public Primitive {
 
   DEFINE_VMAP()
   DEFINE_PRINT(SVD)
+  auto state() const {
+    return compute_uv_;
+  }
+
+ private:
+  bool compute_uv_;
 };
 
 /* Matrix inversion primitive. */
@@ -244,7 +244,7 @@ array multivariate_normal(
 
   // Compute the square-root of the covariance matrix, using the SVD
   auto covariance = astype(cov, float32, stream);
-  auto SVD = linalg::svd(covariance, stream);
+  auto SVD = linalg::svd(covariance, true, stream);
   auto std = astype(
       matmul(
           multiply(
@@ -92,6 +92,7 @@ void init_linalg(nb::module_& parent_module) {
             ===== ============================ ==========================
             None  Frobenius norm               2-norm
             'fro' Frobenius norm               --
+            'nuc' nuclear norm                 --
             inf   max(sum(abs(x), axis=1))     max(abs(x))
             -inf  min(sum(abs(x), axis=1))     min(abs(x))
             0     --                           sum(x != 0)
@@ -102,9 +103,6 @@ void init_linalg(nb::module_& parent_module) {
             other --                           sum(abs(x)**ord)**(1./ord)
             ===== ============================ ==========================
 
-            .. warning::
-               Nuclear norm and norms based on singular values are not yet implemented.
-
             The Frobenius norm is given by [1]_:
 
               :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
@@ -206,15 +204,22 @@ void init_linalg(nb::module_& parent_module) {
       )pbdoc");
   m.def(
       "svd",
-      [](const mx::array& a, mx::StreamOrDevice s /* = {} */) {
-        const auto result = mx::linalg::svd(a, s);
-        return nb::make_tuple(result.at(0), result.at(1), result.at(2));
+      [](const mx::array& a,
+         bool compute_uv /* = true */,
+         mx::StreamOrDevice s /* = {} */) -> nb::object {
+        const auto result = mx::linalg::svd(a, compute_uv, s);
+        if (result.size() == 1) {
+          return nb::cast(result.at(0));
+        } else {
+          return nb::make_tuple(result.at(0), result.at(1), result.at(2));
+        }
       },
       "a"_a,
+      "compute_uv"_a = true,
      nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def svd(a: array, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
+          "def svd(a: array, compute_uv: bool = True, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
      R"pbdoc(
         The Singular Value Decomposition (SVD) of the input matrix.
 
@@ -224,12 +229,15 @@ void init_linalg(nb::module_& parent_module) {
 
         Args:
             a (array): Input array.
+            compute_uv (bool, optional): If ``True``, return the ``U``, ``S``, and ``Vt`` components.
+              If ``False``, return only the ``S`` array. Default: ``True``.
             stream (Stream, optional): Stream or device. Defaults to ``None``
               in which case the default stream of the default device is used.
 
         Returns:
-            tuple(array, array, array): The ``U``, ``S``, and ``Vt`` matrices, such that
-              ``A = U @ diag(S) @ Vt``
+            Union[tuple(array, ...), array]:
+              If compute_uv is ``True`` returns the ``U``, ``S``, and ``Vt`` matrices, such that
+              ``A = U @ diag(S) @ Vt``. If compute_uv is ``False`` returns singular values array ``S``.
       )pbdoc");
   m.def(
       "inv",
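The reconstruction identity stated in the docstring can be checked directly; this mirrors the Python test further down in the commit:

import mlx.core as mx

A = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float32)
U, S, Vt = mx.linalg.svd(A, compute_uv=True, stream=mx.cpu)
print(mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, A, rtol=1e-5, atol=1e-7))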
@@ -12,11 +12,11 @@ import numpy as np
 class TestLinalg(mlx_tests.MLXTestCase):
     def test_norm(self):
         vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")]
-        matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]
+        matrix_ords = [None, "fro", "nuc", -1, 1, -2, 2, float("inf"), -float("inf")]
 
         for shape in [(3,), (2, 3), (2, 3, 3)]:
-            x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
-            x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
+            x_mx = mx.arange(1, math.prod(shape) + 1, dtype=mx.float32).reshape(shape)
+            x_np = np.arange(1, math.prod(shape) + 1, dtype=np.float32).reshape(shape)
             # Test when at least one axis is provided
             for num_axes in range(1, len(shape)):
                 if num_axes == 1:
@@ -26,11 +26,14 @@ class TestLinalg(mlx_tests.MLXTestCase):
                 for axis in itertools.combinations(range(len(shape)), num_axes):
                     for keepdims in [True, False]:
                         for o in ords:
+                            stream = (
+                                mx.cpu if o in ["nuc", -2, 2] else mx.default_device()
+                            )
                             out_np = np.linalg.norm(
                                 x_np, ord=o, axis=axis, keepdims=keepdims
                             )
                             out_mx = mx.linalg.norm(
-                                x_mx, ord=o, axis=axis, keepdims=keepdims
+                                x_mx, ord=o, axis=axis, keepdims=keepdims, stream=stream
                             )
                             with self.subTest(
                                 shape=shape, ord=o, axis=axis, keepdims=keepdims
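The test routes the ord values "nuc", -2, and 2 to mx.cpu because the SVD primitive backing them only has a CPU implementation (linalg::svd calls check_cpu_stream); every other ord stays on the default device. A standalone sketch of the same routing:

import mlx.core as mx

x = mx.arange(1, 10, dtype=mx.float32).reshape(3, 3)
# Singular-value-based norms must run on a CPU stream.
print(mx.linalg.norm(x, "nuc", stream=mx.cpu))
# Element-wise and row-sum norms can stay on the default device.
print(mx.linalg.norm(x, "fro"))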
@@ -133,20 +136,38 @@ class TestLinalg(mlx_tests.MLXTestCase):
 
     def test_svd_decomposition(self):
         A = mx.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=mx.float32)
-        U, S, Vt = mx.linalg.svd(A, stream=mx.cpu)
+        U, S, Vt = mx.linalg.svd(A, compute_uv=True, stream=mx.cpu)
         self.assertTrue(
             mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, A, rtol=1e-5, atol=1e-7)
         )
 
+        S = mx.linalg.svd(A, compute_uv=False, stream=mx.cpu)
+        self.assertTrue(
+            mx.allclose(
+                mx.linalg.norm(S), mx.linalg.norm(A, ord="fro"), rtol=1e-5, atol=1e-7
+            )
+        )
+
         # Multiple matrices
         B = A + 10.0
         AB = mx.stack([A, B])
-        Us, Ss, Vts = mx.linalg.svd(AB, stream=mx.cpu)
+        Us, Ss, Vts = mx.linalg.svd(AB, compute_uv=True, stream=mx.cpu)
         for M, U, S, Vt in zip([A, B], Us, Ss, Vts):
             self.assertTrue(
                 mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
             )
 
+        Ss = mx.linalg.svd(AB, compute_uv=False, stream=mx.cpu)
+        for M, S in zip([A, B], Ss):
+            self.assertTrue(
+                mx.allclose(
+                    mx.linalg.norm(S),
+                    mx.linalg.norm(M, ord="fro"),
+                    rtol=1e-5,
+                    atol=1e-7,
+                )
+            )
+
     def test_inverse(self):
         A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)
         A_inv = mx.linalg.inv(A, stream=mx.cpu)
@@ -316,33 +316,56 @@ class TestVmap(mlx_tests.MLXTestCase):
     def test_vmap_svd(self):
         a = mx.random.uniform(shape=(3, 4, 2))
 
-        cpu_svd = lambda x: mx.linalg.svd(x, stream=mx.cpu)
+        cpu_svd_full = lambda x: mx.linalg.svd(x, compute_uv=True, stream=mx.cpu)
+        cpu_svd_singular = lambda x: mx.linalg.svd(x, compute_uv=False, stream=mx.cpu)
 
         # Vmap over the first axis (this is already supported natively by the primitive).
-        Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(0,))(a)
+        Us, Ss, Vts = mx.vmap(cpu_svd_full, in_axes=(0,))(a)
         self.assertEqual(Us.shape, (a.shape[0], a.shape[1], a.shape[1]))
         self.assertEqual(Ss.shape, (a.shape[0], a.shape[2]))
         self.assertEqual(Vts.shape, (a.shape[0], a.shape[2], a.shape[2]))
 
+        Sv = mx.vmap(cpu_svd_singular, in_axes=(0,))(a)
+        self.assertEqual(Sv.shape, (a.shape[0], a.shape[2]))
+
         for i in range(a.shape[0]):
             M = a[i]
             U, S, Vt = Us[i], Ss[i], Vts[i]
             self.assertTrue(
                 mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
             )
+            self.assertTrue(
+                mx.allclose(
+                    mx.linalg.norm(Sv[i]),
+                    mx.linalg.norm(M, ord="fro"),
+                    rtol=1e-5,
+                    atol=1e-7,
+                )
+            )
 
         # Vmap over the second axis.
-        Us, Ss, Vts = mx.vmap(cpu_svd, in_axes=(1,))(a)
+        Us, Ss, Vts = mx.vmap(cpu_svd_full, in_axes=(1,))(a)
         self.assertEqual(Us.shape, (a.shape[1], a.shape[0], a.shape[0]))
         self.assertEqual(Ss.shape, (a.shape[1], a.shape[2]))
         self.assertEqual(Vts.shape, (a.shape[1], a.shape[2], a.shape[2]))
 
+        Sv = mx.vmap(cpu_svd_singular, in_axes=(1,))(a)
+        self.assertEqual(Sv.shape, (a.shape[1], a.shape[2]))
+
         for i in range(a.shape[1]):
             M = a[:, i, :]
             U, S, Vt = Us[i], Ss[i], Vts[i]
             self.assertTrue(
                 mx.allclose(U[:, : len(S)] @ mx.diag(S) @ Vt, M, rtol=1e-5, atol=1e-7)
             )
+            self.assertTrue(
+                mx.allclose(
+                    mx.linalg.norm(Sv[i]),
+                    mx.linalg.norm(M, ord="fro"),
+                    rtol=1e-5,
+                    atol=1e-7,
+                )
+            )
 
     def test_vmap_inverse(self):
         mx.random.seed(42)
@@ -100,7 +100,7 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
       norm(x, -std::numeric_limits<double>::infinity()).item<float>(),
       doctest::Approx(expected));
 
-  x = reshape(arange(9), {3, 3});
+  x = reshape(arange(9, float32), {3, 3});
 
   CHECK(allclose(
       norm(x, 2.0, 0, false),
@@ -129,10 +129,34 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
   CHECK_EQ(
       norm(x, -1.0, std::vector<int>{1, 0}).item<float>(),
       doctest::Approx(3.0));
+  CHECK_EQ(
+      norm(x, 2.0, std::vector<int>{0, 1}, false, Device::cpu).item<float>(),
+      doctest::Approx(14.226707));
+  CHECK_EQ(
+      norm(x, 2.0, std::vector<int>{1, 0}, false, Device::cpu).item<float>(),
+      doctest::Approx(14.226707));
+  CHECK_EQ(
+      norm(x, -2.0, std::vector<int>{0, 1}, false, Device::cpu).item<float>(),
+      doctest::Approx(0.0));
+  CHECK_EQ(
+      norm(x, -2.0, std::vector<int>{1, 0}, false, Device::cpu).item<float>(),
+      doctest::Approx(0.0));
   CHECK_EQ(norm(x, 1.0, std::vector<int>{0, 1}, true).shape(), Shape{1, 1});
   CHECK_EQ(norm(x, 1.0, std::vector<int>{1, 0}, true).shape(), Shape{1, 1});
   CHECK_EQ(norm(x, -1.0, std::vector<int>{0, 1}, true).shape(), Shape{1, 1});
   CHECK_EQ(norm(x, -1.0, std::vector<int>{1, 0}, true).shape(), Shape{1, 1});
+  CHECK_EQ(
+      norm(x, 2.0, std::vector<int>{0, 1}, true, Device::cpu).shape(),
+      Shape{1, 1});
+  CHECK_EQ(
+      norm(x, 2.0, std::vector<int>{1, 0}, true, Device::cpu).shape(),
+      Shape{1, 1});
+  CHECK_EQ(
+      norm(x, -2.0, std::vector<int>{0, 1}, true, Device::cpu).shape(),
+      Shape{1, 1});
+  CHECK_EQ(
+      norm(x, -2.0, std::vector<int>{1, 0}, true, Device::cpu).shape(),
+      Shape{1, 1});
 
   CHECK_EQ(
       norm(x, -1.0, std::vector<int>{-2, -1}, false).item<float>(),
@@ -140,8 +164,14 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
   CHECK_EQ(
       norm(x, 1.0, std::vector<int>{-2, -1}, false).item<float>(),
       doctest::Approx(15.0));
+  CHECK_EQ(
+      norm(x, -2.0, std::vector<int>{-2, -1}, false, Device::cpu).item<float>(),
+      doctest::Approx(0.0));
+  CHECK_EQ(
+      norm(x, 2.0, std::vector<int>{-2, -1}, false, Device::cpu).item<float>(),
+      doctest::Approx(14.226707));
 
-  x = reshape(arange(18), {2, 3, 3});
+  x = reshape(arange(18, float32), {2, 3, 3});
   CHECK_THROWS(norm(x, 2.0, std::vector{0, 1, 2}));
   CHECK(allclose(
       norm(x, 3.0, 0),
@@ -199,13 +229,31 @@ TEST_CASE("[mlx.core.linalg.norm] double ord") {
             .item<bool>());
   CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9, 36}))
             .item<bool>());
+  CHECK(allclose(
+            norm(x, 2.0, std::vector<int>{0, 1}, false, Device::cpu),
+            array({22.045408, 24.155825, 26.318918}))
+            .item<bool>());
+  CHECK(allclose(
+            norm(x, 2.0, std::vector<int>{1, 2}, false, Device::cpu),
+            array({14.226707, 39.759212}))
+            .item<bool>());
+  CHECK(allclose(
+            norm(x, -2.0, std::vector<int>{0, 1}, false, Device::cpu),
+            array({3, 2.7378995, 2.5128777}))
+            .item<bool>());
+  CHECK(allclose(
+            norm(x, -2.0, std::vector<int>{1, 2}, false, Device::cpu),
+            array({4.979028e-16, 7.009628e-16}),
+            /* rtol = */ 1e-5,
+            /* atol = */ 1e-6)
+            .item<bool>());
 }
 
 TEST_CASE("[mlx.core.linalg.norm] string ord") {
   array x({1, 2, 3});
   CHECK_THROWS(norm(x, "fro"));
 
-  x = reshape(arange(9), {3, 3});
+  x = reshape(arange(9, float32), {3, 3});
   CHECK_THROWS(norm(x, "bad ord"));
 
   CHECK_EQ(
@@ -214,8 +262,11 @@ TEST_CASE("[mlx.core.linalg.norm] string ord") {
   CHECK_EQ(
       norm(x, "fro", std::vector<int>{0, 1}).item<float>(),
       doctest::Approx(14.2828568570857));
+  CHECK_EQ(
+      norm(x, "nuc", std::vector<int>{0, 1}, false, Device::cpu).item<float>(),
+      doctest::Approx(15.491934));
 
-  x = reshape(arange(18), {2, 3, 3});
+  x = reshape(arange(18, float32), {2, 3, 3});
   CHECK(allclose(
       norm(x, "fro", std::vector<int>{0, 1}),
       array({22.24859546, 24.31049156, 26.43860813}))
@@ -240,6 +291,18 @@ TEST_CASE("[mlx.core.linalg.norm] string ord") {
       norm(x, "f", std::vector<int>{2, 1}),
       array({14.28285686, 39.7617907}))
             .item<bool>());
+  CHECK(allclose(
+            norm(x, "nuc", std::vector<int>{0, 1}, false, Device::cpu),
+            array({25.045408, 26.893724, 28.831797}))
+            .item<bool>());
+  CHECK(allclose(
+            norm(x, "nuc", std::vector<int>{1, 2}, false, Device::cpu),
+            array({15.491934, 40.211937}))
+            .item<bool>());
+  CHECK(allclose(
+            norm(x, "nuc", std::vector<int>{-2, -1}, false, Device::cpu),
+            array({15.491934, 40.211937}))
+            .item<bool>());
 }
 
 TEST_CASE("test QR factorization") {
@@ -271,7 +334,7 @@ TEST_CASE("test SVD factorization") {
 
   const auto prng_key = random::key(42);
   const auto A = mlx::core::random::normal({5, 4}, prng_key);
-  const auto outs = linalg::svd(A, Device::cpu);
+  const auto outs = linalg::svd(A, true, Device::cpu);
   CHECK_EQ(outs.size(), 3);
 
   const auto& U = outs[0];
@@ -291,6 +354,15 @@ TEST_CASE("test SVD factorization") {
   CHECK_EQ(U.dtype(), float32);
   CHECK_EQ(S.dtype(), float32);
   CHECK_EQ(Vt.dtype(), float32);
+
+  // Test singular values
+  const auto& outs_sv = linalg::svd(A, false, Device::cpu);
+  const auto SV = outs_sv[0];
+
+  CHECK_EQ(SV.shape(), Shape{4});
+  CHECK_EQ(SV.dtype(), float32);
+
+  CHECK(allclose(norm(SV), norm(A, "fro")).item<bool>());
 }
 
 TEST_CASE("test matrix inversion") {
@@ -466,15 +466,19 @@ TEST_CASE("test vmap scatter") {
 }
 
 TEST_CASE("test vmap SVD") {
-  auto fun = [](std::vector<array> inputs) {
-    return linalg::svd(inputs.at(0), Device::cpu);
+  auto svd_full = [](std::vector<array> inputs) {
+    return linalg::svd(inputs.at(0), true, Device::cpu);
+  };
+
+  auto svd_singular = [](std::vector<array> inputs) {
+    return linalg::svd(inputs.at(0), false, Device::cpu);
   };
 
   auto a = astype(reshape(arange(24), {3, 4, 2}), float32);
 
   // vmap over the second axis.
   {
-    auto out = vmap(fun, /* in_axes = */ {1})({a});
+    auto out = vmap(svd_full, /* in_axes = */ {1})({a});
     const auto& U = out.at(0);
     const auto& S = out.at(1);
     const auto& Vt = out.at(2);
@@ -486,7 +490,7 @@ TEST_CASE("test vmap SVD") {
 
   // vmap over the third axis.
   {
-    auto out = vmap(fun, /* in_axes = */ {2})({a});
+    auto out = vmap(svd_full, /* in_axes = */ {2})({a});
     const auto& U = out.at(0);
     const auto& S = out.at(1);
     const auto& Vt = out.at(2);
@@ -495,6 +499,21 @@ TEST_CASE("test vmap SVD") {
     CHECK_EQ(S.shape(), Shape{a.shape(2), a.shape(0)});
     CHECK_EQ(Vt.shape(), Shape{a.shape(2), a.shape(1), a.shape(1)});
   }
+
+  // test singular values
+  {
+    auto out = vmap(svd_singular, /* in_axes = */ {1})({a});
+    const auto& S = out.at(0);
+
+    CHECK_EQ(S.shape(), Shape{a.shape(1), a.shape(2)});
+  }
+
+  {
+    auto out = vmap(svd_singular, /* in_axes = */ {2})({a});
+    const auto& S = out.at(0);
+
+    CHECK_EQ(S.shape(), Shape{a.shape(2), a.shape(0)});
+  }
 }
 
 TEST_CASE("test vmap dynamic slices") {