Adds nuclear norm support (#1894)

* adjust norm unit test tolerance
Abe Leininger authored on 2025-03-04 15:26:02 -06:00, committed by GitHub
parent 9680f72cca, commit 3835a428c5
11 changed files with 260 additions and 55 deletions

View File

@@ -8,7 +8,7 @@
namespace mlx::core {
template <typename T>
-void svd_impl(const array& a, array& u, array& s, array& vt) {
+void svd_impl(const array& a, T* u_data, T* s_data, T* vt_data) {
// Lapack uses the column-major convention. To avoid having to transpose
// the input and then transpose the outputs, we swap the indices/sizes of the
// matrices and take advantage of the following identity (see
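For reference, the identity this comment points at is the standard transpose rule for the SVD (stated here for completeness):

A = U \Sigma V^{\mathsf T} \iff A^{\mathsf T} = V \Sigma^{\mathsf T} U^{\mathsf T}

Since LAPACK reads the row-major M×N buffer as the N×M column-major matrix Aᵀ, the factor it reports as U is actually V (i.e. Vᵀᵀ) and the factor it reports as Vᵀ is actually Uᵀ, which is why the u/vt destination pointers are swapped in the call below.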
@@ -35,13 +35,8 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
array in(a.shape(), a.dtype(), nullptr, {});
copy(a, in, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
-// Allocate outputs.
-u.set_data(allocator::malloc_or_wait(u.nbytes()));
-s.set_data(allocator::malloc_or_wait(s.nbytes()));
-vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
-static constexpr auto job_u = "V";
-static constexpr auto job_vt = "V";
+auto job_u = (u_data && vt_data) ? "V" : "N";
+auto job_vt = (u_data && vt_data) ? "V" : "N";
static constexpr auto range = "A";
// Will contain the number of singular values after the call has returned.
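The job flags above follow LAPACK's ?gesvdx convention; paraphrasing its documentation (not part of this diff) to explain why "N" makes the no-vectors path cheap:

jobu  = 'V': the left singular vectors are computed and returned in u
jobu  = 'N': no columns of u (no left singular vectors) are computed
jobvt = 'V': the right singular vectors are computed and returned in vt
jobvt = 'N': no rows of vt (no right singular vectors) are computed
range = 'A': all singular values are found

With compute_uv == false the caller passes null u_data/vt_data, both jobs become "N", and LAPACK fills only the singular values.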
@@ -56,6 +51,7 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
static const int ignored_int = 0;
static const T ignored_float = 0;
+static T ignored_output = 0;
int info;
@@ -109,12 +105,12 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
/* il = */ &ignored_int,
/* iu = */ &ignored_int,
/* ns = */ &ns,
-/* s = */ s.data<T>() + K * i,
+/* s = */ s_data + K * i,
// According to the identity above, lapack will write Vᵀᵀ as U.
-/* u = */ vt.data<T>() + N * N * i,
+/* u = */ vt_data ? vt_data + N * N * i : &ignored_output,
/* ldu = */ &ldu,
// According to the identity above, lapack will write Uᵀ as Vᵀ.
-/* vt = */ u.data<T>() + M * M * i,
+/* vt = */ u_data ? u_data + M * M * i : &ignored_output,
/* ldvt = */ &ldvt,
/* work = */ static_cast<T*>(scratch.buffer.raw_ptr()),
/* lwork = */ &lwork,
@@ -136,15 +132,36 @@ void svd_impl(const array& a, array& u, array& s, array& vt) {
}
}
+template <typename T>
+void compute_svd(const array& a, bool compute_uv, std::vector<array>& outputs) {
+if (compute_uv) {
+array& u = outputs[0];
+array& s = outputs[1];
+array& vt = outputs[2];
+u.set_data(allocator::malloc_or_wait(u.nbytes()));
+s.set_data(allocator::malloc_or_wait(s.nbytes()));
+vt.set_data(allocator::malloc_or_wait(vt.nbytes()));
+svd_impl<T>(a, u.data<T>(), s.data<T>(), vt.data<T>());
+} else {
+array& s = outputs[0];
+s.set_data(allocator::malloc_or_wait(s.nbytes()));
+svd_impl<T>(a, nullptr, s.data<T>(), nullptr);
+}
+}
void SVD::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
switch (inputs[0].dtype()) {
case float32:
-svd_impl<float>(inputs[0], outputs[0], outputs[1], outputs[2]);
+compute_svd<float>(inputs[0], compute_uv_, outputs);
break;
case float64:
-svd_impl<double>(inputs[0], outputs[0], outputs[1], outputs[2]);
+compute_svd<double>(inputs[0], compute_uv_, outputs);
break;
default:
throw std::runtime_error(

View File

@@ -102,8 +102,21 @@ inline array matrix_norm(
dtype,
s);
} else if (ord == 2.0 || ord == -2.0) {
-throw std::runtime_error(
-"[linalg::norm] Singular value norms are not implemented.");
+row_axis = (axis[0] < 0) ? axis[0] + a.ndim() : axis[0];
+col_axis = (axis[1] < 0) ? axis[1] + a.ndim() : axis[1];
+auto a_matrix = (row_axis > col_axis)
+? moveaxis(moveaxis(a, row_axis, -1, s), col_axis, -1, s)
+: moveaxis(moveaxis(a, col_axis, -1, s), row_axis, -2, s);
+a_matrix = svd(a_matrix, false, s).at(0);
+a_matrix = (ord == 2.0) ? max(a_matrix, -1, false, s)
+: min(a_matrix, -1, false, s);
+if (keepdims) {
+std::vector<int> sorted_axes = (row_axis < col_axis)
+? std::vector<int>{row_axis, col_axis}
+: std::vector<int>{col_axis, row_axis};
+a_matrix = expand_dims(a_matrix, sorted_axes, s);
+}
+return astype(a_matrix, dtype, s);
} else {
std::ostringstream msg;
msg << "[linalg::norm] Invalid ord " << ord << " for matrix norm.";
@@ -120,8 +133,19 @@ inline array matrix_norm(
if (ord == "f" || ord == "fro") {
return l2_norm(a, axis, keepdims, s);
} else if (ord == "nuc") {
-throw std::runtime_error(
-"[linalg::norm] Nuclear norm not yet implemented.");
+int row_axis = (axis[0] < 0) ? axis[0] + a.ndim() : axis[0];
+int col_axis = (axis[1] < 0) ? axis[1] + a.ndim() : axis[1];
+auto a_matrix = (row_axis > col_axis)
+? moveaxis(moveaxis(a, row_axis, -1, s), col_axis, -1, s)
+: moveaxis(moveaxis(a, col_axis, -1, s), row_axis, -2, s);
+a_matrix = sum(svd(a_matrix, false, s).at(0), -1, false, s);
+if (keepdims) {
+std::vector<int> sorted_axes = (row_axis < col_axis)
+? std::vector<int>{row_axis, col_axis}
+: std::vector<int>{col_axis, row_axis};
+a_matrix = expand_dims(a_matrix, sorted_axes, s);
+}
+return a_matrix;
} else {
std::ostringstream msg;
msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm.";
@@ -214,7 +238,8 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
return std::make_pair(out[0], out[1]);
}
-std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
+std::vector<array>
+svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
check_cpu_stream(s, "[linalg::svd]");
check_float(a.dtype(), "[linalg::svd]");
@@ -230,14 +255,22 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
const auto n = a.shape(-1);
const auto rank = a.ndim();
-auto u_shape = a.shape();
-u_shape[rank - 2] = m;
-u_shape[rank - 1] = m;
auto s_shape = a.shape();
s_shape.pop_back();
s_shape[rank - 2] = std::min(m, n);
+if (!compute_uv) {
+return {array(
+std::move(s_shape),
+std::move(a.dtype()),
+std::make_shared<SVD>(to_stream(s), compute_uv),
+{a})};
+}
+auto u_shape = a.shape();
+u_shape[rank - 2] = m;
+u_shape[rank - 1] = m;
auto vt_shape = a.shape();
vt_shape[rank - 2] = n;
vt_shape[rank - 1] = n;
@@ -245,7 +278,7 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
return array::make_arrays(
{u_shape, s_shape, vt_shape},
{a.dtype(), a.dtype(), a.dtype()},
-std::make_shared<SVD>(to_stream(s)),
+std::make_shared<SVD>(to_stream(s), compute_uv),
{a});
}
@@ -323,7 +356,7 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) {
int m = a.shape(-2);
int n = a.shape(-1);
int k = std::min(m, n);
-auto outs = linalg::svd(a, s);
+auto outs = linalg::svd(a, true, s);
array U = outs[0];
array S = outs[1];
array V = outs[2];
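pinv is the one internal caller that genuinely needs the full factorization: with A = U \Sigma V^{\mathsf T}, the pseudoinverse is assembled as

A^{+} = V \, \Sigma^{+} \, U^{\mathsf T}

where \Sigma^{+} transposes \Sigma and inverts its nonzero singular values, so compute_uv must be true here.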

View File

@@ -62,7 +62,11 @@ norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
std::pair<array, array> qr(const array& a, StreamOrDevice s = {});
-std::vector<array> svd(const array& a, StreamOrDevice s = {});
+std::vector<array>
+svd(const array& a, bool compute_uv, StreamOrDevice s = {});
+inline std::vector<array> svd(const array& a, StreamOrDevice s = {}) {
+return svd(a, true, s);
+}
array inv(const array& a, StreamOrDevice s = {});
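A minimal sketch of the two svd call forms the header now exposes (a is an assumed float32 matrix):

std::vector<array> usv = linalg::svd(a);       // forwards to svd(a, true, s): returns {U, S, Vt}
std::vector<array> sv = linalg::svd(a, false); // returns {S} only; U and Vt are never allocated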

View File

@@ -4940,7 +4940,8 @@ std::pair<std::vector<array>, std::vector<int>> SVD::vmap(
const std::vector<int>& axes) {
auto ax = axes[0] >= 0 ? 0 : -1;
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
-return {{linalg::svd(a, stream())}, {ax, ax, ax}};
+std::vector<int> new_axes(compute_uv_ ? 3 : 1, ax);
+return {linalg::svd(a, compute_uv_, stream()), std::move(new_axes)};
}
std::pair<std::vector<array>, std::vector<int>> Inverse::vmap(

View File

@@ -2287,7 +2287,8 @@ class QRF : public Primitive {
/* SVD primitive. */
class SVD : public Primitive {
public:
-explicit SVD(Stream stream) : Primitive(stream) {}
+explicit SVD(Stream stream, bool compute_uv)
+: Primitive(stream), compute_uv_(compute_uv) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
@@ -2296,6 +2297,12 @@ class SVD : public Primitive {
DEFINE_VMAP()
DEFINE_PRINT(SVD)
+auto state() const {
+return compute_uv_;
+}
+private:
+bool compute_uv_;
};
/* Matrix inversion primitive. */

View File

@@ -244,7 +244,7 @@ array multivariate_normal(
// Compute the square-root of the covariance matrix, using the SVD
auto covariance = astype(cov, float32, stream);
-auto SVD = linalg::svd(covariance, stream);
+auto SVD = linalg::svd(covariance, true, stream);
auto std = astype(
matmul(
multiply(
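For reference, the square root being assembled from the SVD factors: a covariance matrix is symmetric positive semi-definite, so its left and right singular vectors coincide and

\Sigma_{\text{cov}} = U S V^{\mathsf T} \;\Rightarrow\; \Sigma_{\text{cov}}^{1/2} = U \, \operatorname{diag}(\sqrt{S}) \, V^{\mathsf T}

Samples are then drawn as x = \mu + \Sigma_{\text{cov}}^{1/2} z with z \sim \mathcal{N}(0, I), which matches the matmul/multiply chain being built above.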