mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
complex for vector too
This commit is contained in:
parent
e87c2d4af3
commit
673af67c92
@ -13,6 +13,18 @@ Dtype at_least_float(const Dtype& d) {
|
||||
return is_floating_point(d) ? d : promote_types(d, float32);
|
||||
}
|
||||
|
||||
inline array l2_norm(
|
||||
const array& a,
|
||||
const std::vector<int>& axis,
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
if (is_complex(a.dtype())) {
|
||||
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
|
||||
} else {
|
||||
return sqrt(sum(square(a, s), axis, keepdims, s), s);
|
||||
}
|
||||
}
|
||||
|
||||
inline array vector_norm(
|
||||
const array& a,
|
||||
const double ord,
|
||||
@ -25,7 +37,7 @@ inline array vector_norm(
|
||||
} else if (ord == 1.0) {
|
||||
return astype(sum(abs(a, s), axis, keepdims, s), dtype, s);
|
||||
} else if (ord == 2.0) {
|
||||
return sqrt(sum(square(a, s), axis, keepdims, s), s);
|
||||
return l2_norm(a, axis, keepdims, s);
|
||||
} else if (ord == std::numeric_limits<double>::infinity()) {
|
||||
return astype(max(abs(a, s), axis, keepdims, s), dtype, s);
|
||||
} else if (ord == -std::numeric_limits<double>::infinity()) {
|
||||
@ -88,10 +100,7 @@ inline array matrix_norm(
|
||||
bool keepdims,
|
||||
StreamOrDevice s) {
|
||||
if (ord == "f" || ord == "fro") {
|
||||
if (is_complex(a.dtype()))
|
||||
return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
|
||||
else
|
||||
return sqrt(sum(square(a, s), axis, keepdims, s), s);
|
||||
return l2_norm(a, axis, keepdims, s);
|
||||
} else if (ord == "nuc") {
|
||||
throw std::runtime_error(
|
||||
"[linalg::norm] Nuclear norm not yet implemented.");
|
||||
@ -115,7 +124,7 @@ array norm(
|
||||
throw std::invalid_argument(
|
||||
"[linalg::norm] Received too many axes for norm.");
|
||||
}
|
||||
return sqrt(sum(square(a, s), axis.value(), keepdims, s), s);
|
||||
return l2_norm(a, axis.value(), keepdims, s);
|
||||
}
|
||||
|
||||
array norm(
|
||||
|
@ -62,6 +62,33 @@ class TestLinalg(mlx_tests.MLXTestCase):
|
||||
with self.subTest(shape=shape, keepdims=keepdims):
|
||||
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
|
||||
|
||||
def test_complex_norm(self):
|
||||
for shape in [(3,), (2, 3), (2, 3, 3)]:
|
||||
x_np = np.random.uniform(size=shape).astype(
|
||||
np.float32
|
||||
) + 1j * np.random.uniform(size=shape).astype(np.float32)
|
||||
x_mx = mx.array(x_np)
|
||||
out_np = np.linalg.norm(x_np)
|
||||
out_mx = mx.linalg.norm(x_mx)
|
||||
with self.subTest(shape=shape):
|
||||
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
|
||||
for num_axes in range(1, len(shape)):
|
||||
for axis in itertools.combinations(range(len(shape)), num_axes):
|
||||
out_np = np.linalg.norm(x_np, axis=axis)
|
||||
out_mx = mx.linalg.norm(x_mx, axis=axis)
|
||||
with self.subTest(shape=shape, axis=axis):
|
||||
self.assertTrue(
|
||||
np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
||||
)
|
||||
|
||||
x_np = np.random.uniform(size=(4, 4)).astype(
|
||||
np.float32
|
||||
) + 1j * np.random.uniform(size=(4, 4)).astype(np.float32)
|
||||
x_mx = mx.array(x_np)
|
||||
out_np = np.linalg.norm(x_np, ord="fro")
|
||||
out_mx = mx.linalg.norm(x_mx, ord="fro")
|
||||
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user