linalg: route every L2-style norm through a shared l2_norm helper

The vector 2-norm and the norm() path patched below previously computed
sqrt(sum(square(a))) with no complex-dtype branch; only the Frobenius branch
of matrix_norm had one. All three now call l2_norm, which sums squared
magnitudes for complex input. Adds test_complex_norm, which compares the
results against np.linalg.norm.

diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index e2c0e9f3f..7e7264e3f 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -13,6 +13,22 @@ Dtype at_least_float(const Dtype& d) {
   return is_floating_point(d) ? d : promote_types(d, float32);
 }
 
+// Square root of the sum of squared magnitudes of `a` over `axis`.
+// Shared by vector_norm (ord == 2), matrix_norm ("f" / "fro") and norm.
+inline array l2_norm(
+    const array& a,
+    const std::vector<int>& axis,
+    bool keepdims,
+    StreamOrDevice s) {
+  if (is_complex(a.dtype())) {
+    // The sum must run over |z|^2, so take the magnitude before squaring.
+    // Squaring the result of a single abs() avoids building abs() twice.
+    return sqrt(sum(square(abs(a, s), s), axis, keepdims, s), s);
+  } else {
+    return sqrt(sum(square(a, s), axis, keepdims, s), s);
+  }
+}
+
 inline array vector_norm(
     const array& a,
     const double ord,
@@ -25,7 +41,7 @@ inline array vector_norm(
   } else if (ord == 1.0) {
     return astype(sum(abs(a, s), axis, keepdims, s), dtype, s);
   } else if (ord == 2.0) {
-    return sqrt(sum(square(a, s), axis, keepdims, s), s);
+    return l2_norm(a, axis, keepdims, s);
   } else if (ord == std::numeric_limits<double>::infinity()) {
     return astype(max(abs(a, s), axis, keepdims, s), dtype, s);
   } else if (ord == -std::numeric_limits<double>::infinity()) {
@@ -88,10 +104,7 @@ inline array matrix_norm(
     bool keepdims,
     StreamOrDevice s) {
   if (ord == "f" || ord == "fro") {
-    if (is_complex(a.dtype()))
-      return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
-    else
-      return sqrt(sum(square(a, s), axis, keepdims, s), s);
+    return l2_norm(a, axis, keepdims, s);
   } else if (ord == "nuc") {
     throw std::runtime_error(
         "[linalg::norm] Nuclear norm not yet implemented.");
@@ -115,7 +128,7 @@ array norm(
     throw std::invalid_argument(
         "[linalg::norm] Received too many axes for norm.");
   }
-  return sqrt(sum(square(a, s), axis.value(), keepdims, s), s);
+  return l2_norm(a, axis.value(), keepdims, s);
 }
 
 array norm(
diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py
index 08a4510c8..ac86c1e11 100644
--- a/python/tests/test_linalg.py
+++ b/python/tests/test_linalg.py
@@ -62,6 +62,37 @@ class TestLinalg(mlx_tests.MLXTestCase):
                 with self.subTest(shape=shape, keepdims=keepdims):
                     self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
 
+    def test_complex_norm(self):
+        """Complex inputs: mx.linalg.norm must agree with np.linalg.norm."""
+        for shape in [(3,), (2, 3), (2, 3, 3)]:
+            x_np = np.random.uniform(size=shape).astype(
+                np.float32
+            ) + 1j * np.random.uniform(size=shape).astype(np.float32)
+            x_mx = mx.array(x_np)
+            # No axis and no ord.
+            out_np = np.linalg.norm(x_np)
+            out_mx = mx.linalg.norm(x_mx)
+            with self.subTest(shape=shape):
+                self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
+            # Every combination of 1 .. ndim - 1 axes.
+            for num_axes in range(1, len(shape)):
+                for axis in itertools.combinations(range(len(shape)), num_axes):
+                    out_np = np.linalg.norm(x_np, axis=axis)
+                    out_mx = mx.linalg.norm(x_mx, axis=axis)
+                    with self.subTest(shape=shape, axis=axis):
+                        self.assertTrue(
+                            np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
+                        )
+
+        # Explicit Frobenius norm of a 4x4 complex matrix.
+        x_np = np.random.uniform(size=(4, 4)).astype(
+            np.float32
+        ) + 1j * np.random.uniform(size=(4, 4)).astype(np.float32)
+        x_mx = mx.array(x_np)
+        out_np = np.linalg.norm(x_np, ord="fro")
+        out_mx = mx.linalg.norm(x_mx, ord="fro")
+        self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
+
 
 if __name__ == "__main__":
     unittest.main()