mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
linalg.norm (#187)
* implemented vector_norm in cpp added linalg to mlx * implemented vector_norm python binding * renamed vector_norm to norm, implemented norm without provided ord * completed the implementation of the norm * added tests * removed unused import in linalg.cpp * updated python bindings * added some tests for python bindings * handling inf, -inf as numpy does, more extensive tests of compatibility with numpy * added better docs and examples * refactored mlx.linalg.norm bindings * reused existing util for implementation of linalg.norm * more tests * fixed a bug with no ord and axis provided * removed unused imports * some style and API consistency updates to linalg norm * remove unused includes * fix python tests * fixed a bug with frobenius norm of a complex-valued matrix * complex for vector too --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:

committed by
GitHub

parent
447bc089b9
commit
6b0d30bb85
94
python/tests/test_linalg.py
Normal file
94
python/tests/test_linalg.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import itertools
|
||||
import math
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestLinalg(mlx_tests.MLXTestCase):
|
||||
def test_norm(self):
|
||||
vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")]
|
||||
matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]
|
||||
|
||||
for shape in [(3,), (2, 3), (2, 3, 3)]:
|
||||
x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||
x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||
# Test when at least one axis is provided
|
||||
for num_axes in range(1, len(shape)):
|
||||
if num_axes == 1:
|
||||
ords = vector_ords
|
||||
else:
|
||||
ords = matrix_ords
|
||||
for axis in itertools.combinations(range(len(shape)), num_axes):
|
||||
for keepdims in [True, False]:
|
||||
for o in ords:
|
||||
out_np = np.linalg.norm(
|
||||
x_np, ord=o, axis=axis, keepdims=keepdims
|
||||
)
|
||||
out_mx = mx.linalg.norm(
|
||||
x_mx, ord=o, axis=axis, keepdims=keepdims
|
||||
)
|
||||
with self.subTest(
|
||||
shape=shape, ord=o, axis=axis, keepdims=keepdims
|
||||
):
|
||||
self.assertTrue(
|
||||
np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
||||
)
|
||||
|
||||
# Test only ord provided
|
||||
for shape in [(3,), (2, 3)]:
|
||||
x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||
x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||
for o in [None, 1, -1, float("inf"), -float("inf")]:
|
||||
for keepdims in [True, False]:
|
||||
out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims)
|
||||
out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims)
|
||||
with self.subTest(shape=shape, ord=o, keepdims=keepdims):
|
||||
self.assertTrue(
|
||||
np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
||||
)
|
||||
|
||||
# Test no ord and no axis provided
|
||||
for shape in [(3,), (2, 3), (2, 3, 3)]:
|
||||
x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||
x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
|
||||
for keepdims in [True, False]:
|
||||
out_np = np.linalg.norm(x_np, keepdims=keepdims)
|
||||
out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)
|
||||
with self.subTest(shape=shape, keepdims=keepdims):
|
||||
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
|
||||
|
||||
def test_complex_norm(self):
|
||||
for shape in [(3,), (2, 3), (2, 3, 3)]:
|
||||
x_np = np.random.uniform(size=shape).astype(
|
||||
np.float32
|
||||
) + 1j * np.random.uniform(size=shape).astype(np.float32)
|
||||
x_mx = mx.array(x_np)
|
||||
out_np = np.linalg.norm(x_np)
|
||||
out_mx = mx.linalg.norm(x_mx)
|
||||
with self.subTest(shape=shape):
|
||||
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
|
||||
for num_axes in range(1, len(shape)):
|
||||
for axis in itertools.combinations(range(len(shape)), num_axes):
|
||||
out_np = np.linalg.norm(x_np, axis=axis)
|
||||
out_mx = mx.linalg.norm(x_mx, axis=axis)
|
||||
with self.subTest(shape=shape, axis=axis):
|
||||
self.assertTrue(
|
||||
np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
|
||||
)
|
||||
|
||||
x_np = np.random.uniform(size=(4, 4)).astype(
|
||||
np.float32
|
||||
) + 1j * np.random.uniform(size=(4, 4)).astype(np.float32)
|
||||
x_mx = mx.array(x_np)
|
||||
out_np = np.linalg.norm(x_np, ord="fro")
|
||||
out_mx = mx.linalg.norm(x_mx, ord="fro")
|
||||
self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Reference in New Issue
Block a user