mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -102,8 +102,21 @@ inline array matrix_norm(
|
||||
dtype,
|
||||
s);
|
||||
} else if (ord == 2.0 || ord == -2.0) {
|
||||
throw std::runtime_error(
|
||||
"[linalg::norm] Singular value norms are not implemented.");
|
||||
row_axis = (axis[0] < 0) ? axis[0] + a.ndim() : axis[0];
|
||||
col_axis = (axis[1] < 0) ? axis[1] + a.ndim() : axis[1];
|
||||
auto a_matrix = (row_axis > col_axis)
|
||||
? moveaxis(moveaxis(a, row_axis, -1, s), col_axis, -1, s)
|
||||
: moveaxis(moveaxis(a, col_axis, -1, s), row_axis, -2, s);
|
||||
a_matrix = svd(a_matrix, false, s).at(0);
|
||||
a_matrix = (ord == 2.0) ? max(a_matrix, -1, false, s)
|
||||
: min(a_matrix, -1, false, s);
|
||||
if (keepdims) {
|
||||
std::vector<int> sorted_axes = (row_axis < col_axis)
|
||||
? std::vector<int>{row_axis, col_axis}
|
||||
: std::vector<int>{col_axis, row_axis};
|
||||
a_matrix = expand_dims(a_matrix, sorted_axes, s);
|
||||
}
|
||||
return astype(a_matrix, dtype, s);
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::norm] Invalid ord " << ord << " for matrix norm.";
|
||||
@@ -120,8 +133,19 @@ inline array matrix_norm(
|
||||
if (ord == "f" || ord == "fro") {
|
||||
return l2_norm(a, axis, keepdims, s);
|
||||
} else if (ord == "nuc") {
|
||||
throw std::runtime_error(
|
||||
"[linalg::norm] Nuclear norm not yet implemented.");
|
||||
int row_axis = (axis[0] < 0) ? axis[0] + a.ndim() : axis[0];
|
||||
int col_axis = (axis[1] < 0) ? axis[1] + a.ndim() : axis[1];
|
||||
auto a_matrix = (row_axis > col_axis)
|
||||
? moveaxis(moveaxis(a, row_axis, -1, s), col_axis, -1, s)
|
||||
: moveaxis(moveaxis(a, col_axis, -1, s), row_axis, -2, s);
|
||||
a_matrix = sum(svd(a_matrix, false, s).at(0), -1, false, s);
|
||||
if (keepdims) {
|
||||
std::vector<int> sorted_axes = (row_axis < col_axis)
|
||||
? std::vector<int>{row_axis, col_axis}
|
||||
: std::vector<int>{col_axis, row_axis};
|
||||
a_matrix = expand_dims(a_matrix, sorted_axes, s);
|
||||
}
|
||||
return a_matrix;
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm.";
|
||||
@@ -214,7 +238,8 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return std::make_pair(out[0], out[1]);
|
||||
}
|
||||
|
||||
std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
|
||||
std::vector<array>
|
||||
svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */) {
|
||||
check_cpu_stream(s, "[linalg::svd]");
|
||||
check_float(a.dtype(), "[linalg::svd]");
|
||||
|
||||
@@ -230,14 +255,22 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
|
||||
const auto n = a.shape(-1);
|
||||
const auto rank = a.ndim();
|
||||
|
||||
auto u_shape = a.shape();
|
||||
u_shape[rank - 2] = m;
|
||||
u_shape[rank - 1] = m;
|
||||
|
||||
auto s_shape = a.shape();
|
||||
s_shape.pop_back();
|
||||
s_shape[rank - 2] = std::min(m, n);
|
||||
|
||||
if (!compute_uv) {
|
||||
return {array(
|
||||
std::move(s_shape),
|
||||
std::move(a.dtype()),
|
||||
std::make_shared<SVD>(to_stream(s), compute_uv),
|
||||
{a})};
|
||||
}
|
||||
|
||||
auto u_shape = a.shape();
|
||||
u_shape[rank - 2] = m;
|
||||
u_shape[rank - 1] = m;
|
||||
|
||||
auto vt_shape = a.shape();
|
||||
vt_shape[rank - 2] = n;
|
||||
vt_shape[rank - 1] = n;
|
||||
@@ -245,7 +278,7 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
|
||||
return array::make_arrays(
|
||||
{u_shape, s_shape, vt_shape},
|
||||
{a.dtype(), a.dtype(), a.dtype()},
|
||||
std::make_shared<SVD>(to_stream(s)),
|
||||
std::make_shared<SVD>(to_stream(s), compute_uv),
|
||||
{a});
|
||||
}
|
||||
|
||||
@@ -323,7 +356,7 @@ array pinv(const array& a, StreamOrDevice s /* = {} */) {
|
||||
int m = a.shape(-2);
|
||||
int n = a.shape(-1);
|
||||
int k = std::min(m, n);
|
||||
auto outs = linalg::svd(a, s);
|
||||
auto outs = linalg::svd(a, true, s);
|
||||
array U = outs[0];
|
||||
array S = outs[1];
|
||||
array V = outs[2];
|
||||
|
||||
Reference in New Issue
Block a user