mlx/python/tests/test_linalg.py
Gabrijel Boduljak 6b0d30bb85
linalg.norm (#187)
* implemented vector_norm in cpp

added linalg to mlx

* implemented vector_norm python binding

* renamed vector_norm to norm; implemented norm when no ord is provided

* completed the implementation of the norm

* added tests

* removed unused import in linalg.cpp

* updated python bindings

* added some tests for python bindings

* handling inf, -inf as numpy does, more extensive tests of compatibility with numpy

* added better docs and examples

* refactored mlx.linalg.norm bindings

* reused existing util for implementation of linalg.norm

* more tests

* fixed a bug when neither ord nor axis is provided

* removed unused imports

* some style and API consistency updates to linalg norm

* remove unused includes

* fix python tests

* fixed a bug in the Frobenius norm of a complex-valued matrix

* handled complex-valued input for vector norms too

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2023-12-26 19:42:04 -08:00

95 lines
4.1 KiB
Python

# Copyright © 2023 Apple Inc.
import itertools
import math
import unittest
import mlx.core as mx
import mlx_tests
import numpy as np
class TestLinalg(mlx_tests.MLXTestCase):
    """Check ``mx.linalg.norm`` against ``np.linalg.norm`` on matching inputs."""

    @staticmethod
    def _arange_inputs(shape):
        """Return an ``(mlx, numpy)`` pair of arrays holding 1..prod(shape).

        Each array is built with its own library's ``arange`` (rather than
        converting one into the other), then reshaped to ``shape``.
        """
        size = math.prod(shape)
        return (
            mx.arange(1, size + 1).reshape(shape),
            np.arange(1, size + 1).reshape(shape),
        )

    @staticmethod
    def _axis_counts(ndim):
        """Return the reduction-axis counts to exercise for an ``ndim``-D input.

        One axis requests a vector norm and two request a matrix norm.
        ``np.linalg.norm`` only accepts one or two axes, so the count is capped
        at 2. Below that cap it runs all the way up to ``ndim``, so a full
        reduction (``axis=(0,)`` in 1-D, ``axis=(0, 1)`` in 2-D) is covered.
        """
        return range(1, min(ndim, 2) + 1)

    def _assert_close_to_numpy(self, out_np, out_mx):
        """Assert the MLX result matches the NumPy reference within tolerance."""
        self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))

    def test_norm(self):
        vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")]
        # Only matrix ords that are compared here; ord=2/-2/"nuc" are not
        # exercised by this test.
        matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]

        # Test when at least one axis is provided.
        for shape in [(3,), (2, 3), (2, 3, 3)]:
            x_mx, x_np = self._arange_inputs(shape)
            for num_axes in self._axis_counts(len(shape)):
                ords = vector_ords if num_axes == 1 else matrix_ords
                axes = itertools.combinations(range(len(shape)), num_axes)
                # Same order as nested loops: axis, then keepdims, then ord.
                for axis, keepdims, o in itertools.product(
                    axes, [True, False], ords
                ):
                    with self.subTest(
                        shape=shape, ord=o, axis=axis, keepdims=keepdims
                    ):
                        out_np = np.linalg.norm(
                            x_np, ord=o, axis=axis, keepdims=keepdims
                        )
                        out_mx = mx.linalg.norm(
                            x_mx, ord=o, axis=axis, keepdims=keepdims
                        )
                        self._assert_close_to_numpy(out_np, out_mx)

        # Test only ord provided (vector norm for 1-D, matrix norm for 2-D).
        for shape in [(3,), (2, 3)]:
            x_mx, x_np = self._arange_inputs(shape)
            for o, keepdims in itertools.product(
                [None, 1, -1, float("inf"), -float("inf")], [True, False]
            ):
                with self.subTest(shape=shape, ord=o, keepdims=keepdims):
                    out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims)
                    out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims)
                    self._assert_close_to_numpy(out_np, out_mx)

        # Test no ord and no axis provided.
        for shape in [(3,), (2, 3), (2, 3, 3)]:
            x_mx, x_np = self._arange_inputs(shape)
            for keepdims in [True, False]:
                with self.subTest(shape=shape, keepdims=keepdims):
                    out_np = np.linalg.norm(x_np, keepdims=keepdims)
                    out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)
                    self._assert_close_to_numpy(out_np, out_mx)

    def test_complex_norm(self):
        # A fixed seed keeps the random inputs (and any failure) reproducible.
        rng = np.random.default_rng(0)

        def random_complex(shape):
            # Real and imaginary parts are each drawn uniformly from [0, 1)
            # and cast to float32 before being combined.
            real = rng.uniform(size=shape).astype(np.float32)
            imag = rng.uniform(size=shape).astype(np.float32)
            return real + 1j * imag

        for shape in [(3,), (2, 3), (2, 3, 3)]:
            x_np = random_complex(shape)
            x_mx = mx.array(x_np)
            with self.subTest(shape=shape):
                self._assert_close_to_numpy(
                    np.linalg.norm(x_np), mx.linalg.norm(x_mx)
                )
            for num_axes in self._axis_counts(len(shape)):
                for axis in itertools.combinations(range(len(shape)), num_axes):
                    with self.subTest(shape=shape, axis=axis):
                        self._assert_close_to_numpy(
                            np.linalg.norm(x_np, axis=axis),
                            mx.linalg.norm(x_mx, axis=axis),
                        )

        # Explicitly requested Frobenius norm of a square complex matrix.
        x_np = random_complex((4, 4))
        x_mx = mx.array(x_np)
        with self.subTest(shape=(4, 4), ord="fro"):
            self._assert_close_to_numpy(
                np.linalg.norm(x_np, ord="fro"), mx.linalg.norm(x_mx, ord="fro")
            )
if __name__ == "__main__":
    # When this file is executed directly, hand control to unittest's runner.
    unittest.main()