fixed a bug with frobenius norm of a complex-valued matrix

This commit is contained in:
Gabrijel Boduljak 2023-12-27 02:12:33 +01:00 committed by Awni Hannun
parent 67e319488c
commit e87c2d4af3

View File

@ -4,6 +4,7 @@
#include <ostream>
#include <vector>
#include "mlx/dtype.h"
#include "mlx/linalg.h"
namespace mlx::core::linalg {
@ -87,7 +88,10 @@ inline array matrix_norm(
bool keepdims,
StreamOrDevice s) {
if (ord == "f" || ord == "fro") {
return sqrt(sum(square(a, s), axis, keepdims, s), s);
if (is_complex(a.dtype()))
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
else
return sqrt(sum(square(a, s), axis, keepdims, s), s);
} else if (ord == "nuc") {
throw std::runtime_error(
"[linalg::norm] Nuclear norm not yet implemented.");