mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
fixed a bug with frobenius norm of a complex-valued matrix
This commit is contained in:
parent
67e319488c
commit
e87c2d4af3
@ -4,6 +4,7 @@
|
|||||||
#include <ostream>
|
#include <ostream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mlx/dtype.h"
|
||||||
#include "mlx/linalg.h"
|
#include "mlx/linalg.h"
|
||||||
|
|
||||||
namespace mlx::core::linalg {
|
namespace mlx::core::linalg {
|
||||||
@ -87,7 +88,10 @@ inline array matrix_norm(
|
|||||||
bool keepdims,
|
bool keepdims,
|
||||||
StreamOrDevice s) {
|
StreamOrDevice s) {
|
||||||
if (ord == "f" || ord == "fro") {
|
if (ord == "f" || ord == "fro") {
|
||||||
return sqrt(sum(square(a, s), axis, keepdims, s), s);
|
if (is_complex(a.dtype()))
|
||||||
|
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
|
||||||
|
else
|
||||||
|
return sqrt(sum(square(a, s), axis, keepdims, s), s);
|
||||||
} else if (ord == "nuc") {
|
} else if (ord == "nuc") {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[linalg::norm] Nuclear norm not yet implemented.");
|
"[linalg::norm] Nuclear norm not yet implemented.");
|
||||||
|
Loading…
Reference in New Issue
Block a user