From 26bb16e768a9fb2dafb21974b630ec103f36f4b6 Mon Sep 17 00:00:00 2001 From: Gabrijel Boduljak Date: Fri, 22 Dec 2023 00:33:36 +0100 Subject: [PATCH] added some tests for python bindings --- python/tests/test_linalg.py | 41 +++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 python/tests/test_linalg.py diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py new file mode 100644 index 000000000..26e9587c5 --- /dev/null +++ b/python/tests/test_linalg.py @@ -0,0 +1,41 @@ +# Copyright © 2023 Apple Inc. + +import itertools +import unittest + +import mlx.core as mx +import mlx_tests +import numpy as np + + +class TestLinalg(mlx_tests.MLXTestCase): + def test_norm(self): + def check_mx_np(a_mx, a_np): + self.assertTrue(np.allclose(a_np, a_mx, atol=1e-5, rtol=1e-6)) + + x_mx = mx.arange(18).reshape((2, 3, 3)) + x_np = np.arange(18).reshape((2, 3, 3)) + + for num_axes in range(1, 3): + for axis in itertools.combinations(range(3), num_axes): + if num_axes == 1: + ords = [None, 0.5, 0, 1, 2, 3, -1, 1] + else: + ords = [None, "fro", -1, 1] + for o in ords: + for keepdims in [True, False]: + if o: + out_np = np.linalg.norm( + x_np, ord=o, axis=axis, keepdims=keepdims + ) + out_mx = mx.linalg.norm( + x_mx, ord=o, axis=axis, keepdims=keepdims + ) + else: + out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims) + out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims) + assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) + + +if __name__ == "__main__": + unittest.main()