mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
fixed a bug with frobenius norm of a complex-valued matrix
This commit is contained in:
parent
67e319488c
commit
e87c2d4af3
@ -4,6 +4,7 @@
|
||||
#include <ostream>
|
||||
#include <vector>
|
||||
|
||||
#include "mlx/dtype.h"
|
||||
#include "mlx/linalg.h"
|
||||
|
||||
namespace mlx::core::linalg {
|
||||
@ -87,7 +88,10 @@ inline array matrix_norm(
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
if (ord == "f" || ord == "fro") {
|
||||
return sqrt(sum(square(a, s), axis, keepdims, s), s);
|
||||
if (is_complex(a.dtype()))
|
||||
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
|
||||
else
|
||||
return sqrt(sum(square(a, s), axis, keepdims, s), s);
|
||||
} else if (ord == "nuc") {
|
||||
throw std::runtime_error(
|
||||
"[linalg::norm] Nuclear norm not yet implemented.");
|
||||
|
Loading…
Reference in New Issue
Block a user