complex for vector too

This commit is contained in:
Awni Hannun 2023-12-26 19:40:20 -08:00
parent e87c2d4af3
commit 673af67c92
2 changed files with 42 additions and 6 deletions

View File

@ -13,6 +13,18 @@ Dtype at_least_float(const Dtype& d) {
return is_floating_point(d) ? d : promote_types(d, float32);
}
inline array l2_norm(
const array& a,
const std::vector<int>& axis,
bool keepdims,
StreamOrDevice s) {
if (is_complex(a.dtype())) {
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
} else {
return sqrt(sum(square(a, s), axis, keepdims, s), s);
}
}
inline array vector_norm(
const array& a,
const double ord,
@ -25,7 +37,7 @@ inline array vector_norm(
} else if (ord == 1.0) {
return astype(sum(abs(a, s), axis, keepdims, s), dtype, s);
} else if (ord == 2.0) {
return sqrt(sum(square(a, s), axis, keepdims, s), s);
return l2_norm(a, axis, keepdims, s);
} else if (ord == std::numeric_limits<double>::infinity()) {
return astype(max(abs(a, s), axis, keepdims, s), dtype, s);
} else if (ord == -std::numeric_limits<double>::infinity()) {
@ -88,10 +100,7 @@ inline array matrix_norm(
bool keepdims,
StreamOrDevice s) {
if (ord == "f" || ord == "fro") {
if (is_complex(a.dtype()))
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
else
return sqrt(sum(square(a, s), axis, keepdims, s), s);
return l2_norm(a, axis, keepdims, s);
} else if (ord == "nuc") {
throw std::runtime_error(
"[linalg::norm] Nuclear norm not yet implemented.");
@ -115,7 +124,7 @@ array norm(
throw std::invalid_argument(
"[linalg::norm] Received too many axes for norm.");
}
return sqrt(sum(square(a, s), axis.value(), keepdims, s), s);
return l2_norm(a, axis.value(), keepdims, s);
}
array norm(

View File

@ -62,6 +62,33 @@ class TestLinalg(mlx_tests.MLXTestCase):
with self.subTest(shape=shape, keepdims=keepdims):
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
def test_complex_norm(self):
for shape in [(3,), (2, 3), (2, 3, 3)]:
x_np = np.random.uniform(size=shape).astype(
np.float32
) + 1j * np.random.uniform(size=shape).astype(np.float32)
x_mx = mx.array(x_np)
out_np = np.linalg.norm(x_np)
out_mx = mx.linalg.norm(x_mx)
with self.subTest(shape=shape):
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
for num_axes in range(1, len(shape)):
for axis in itertools.combinations(range(len(shape)), num_axes):
out_np = np.linalg.norm(x_np, axis=axis)
out_mx = mx.linalg.norm(x_mx, axis=axis)
with self.subTest(shape=shape, axis=axis):
self.assertTrue(
np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
)
x_np = np.random.uniform(size=(4, 4)).astype(
np.float32
) + 1j * np.random.uniform(size=(4, 4)).astype(np.float32)
x_mx = mx.array(x_np)
out_np = np.linalg.norm(x_np, ord="fro")
out_mx = mx.linalg.norm(x_mx, ord="fro")
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
if __name__ == "__main__":
unittest.main()