diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 9cce6cabb..e2c0e9f3f 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -4,6 +4,7 @@ #include #include +#include "mlx/dtype.h" #include "mlx/linalg.h" namespace mlx::core::linalg { @@ -87,7 +88,10 @@ inline array matrix_norm( bool keepdims, StreamOrDevice s) { if (ord == "f" || ord == "fro") { - return sqrt(sum(square(a, s), axis, keepdims, s), s); + if (is_complex(a.dtype())) + return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s); + else + return sqrt(sum(square(a, s), axis, keepdims, s), s); } else if (ord == "nuc") { throw std::runtime_error( "[linalg::norm] Nuclear norm not yet implemented.");