# Copyright © 2023 Apple Inc.

import itertools
import math
import unittest

import mlx.core as mx
import mlx_tests
import numpy as np

# Tolerances shared by every MLX-vs-NumPy norm comparison in this module.
_ATOL = 1e-5
_RTOL = 1e-6


class TestLinalg(mlx_tests.MLXTestCase):
    """Check ``mx.linalg.norm`` against ``np.linalg.norm`` as the reference."""

    def _assert_norms_match(self, out_np, out_mx):
        """Assert the MLX result matches the NumPy result in shape and value.

        ``np.allclose`` broadcasts its operands, so on its own it would accept
        e.g. a ``()``-shaped result where ``keepdims=True`` should give
        ``(1, 1)``. The explicit shape check makes the ``keepdims`` cases
        meaningful.
        """
        # tuple() on both sides: the array ``shape`` attribute may be a list
        # or a tuple depending on the library version.
        self.assertEqual(tuple(out_mx.shape), tuple(np.shape(out_np)))
        self.assertTrue(np.allclose(out_np, out_mx, atol=_ATOL, rtol=_RTOL))

    def test_norm(self):
        vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")]
        matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]

        for shape in [(3,), (2, 3), (2, 3, 3)]:
            x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
            x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
            # Test when at least one axis is provided. np.linalg.norm takes
            # either one axis (vector norm) or two axes (matrix norm), so cover
            # 1..min(ndim, 2) axes. This includes reducing over every axis of
            # 1-D and 2-D inputs.
            for num_axes in range(1, min(len(shape), 2) + 1):
                ords = vector_ords if num_axes == 1 else matrix_ords
                for axis in itertools.combinations(range(len(shape)), num_axes):
                    for keepdims, o in itertools.product([True, False], ords):
                        with self.subTest(
                            shape=shape, ord=o, axis=axis, keepdims=keepdims
                        ):
                            out_np = np.linalg.norm(
                                x_np, ord=o, axis=axis, keepdims=keepdims
                            )
                            out_mx = mx.linalg.norm(
                                x_mx, ord=o, axis=axis, keepdims=keepdims
                            )
                            self._assert_norms_match(out_np, out_mx)

        # Test only ord provided (reduction over all axes).
        for shape in [(3,), (2, 3)]:
            x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
            x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
            for o in [None, 1, -1, float("inf"), -float("inf")]:
                for keepdims in [True, False]:
                    with self.subTest(shape=shape, ord=o, keepdims=keepdims):
                        out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims)
                        out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims)
                        self._assert_norms_match(out_np, out_mx)

        # Test no ord and no axis provided.
        for shape in [(3,), (2, 3), (2, 3, 3)]:
            x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
            x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
            for keepdims in [True, False]:
                with self.subTest(shape=shape, keepdims=keepdims):
                    out_np = np.linalg.norm(x_np, keepdims=keepdims)
                    out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)
                    self._assert_norms_match(out_np, out_mx)

    def test_complex_norm(self):
        # Seeded generator so any failure can be reproduced exactly.
        rng = np.random.default_rng(0)

        def random_complex(shape):
            """Return a complex array with float32 real/imag parts in [0, 1)."""
            real = rng.uniform(size=shape).astype(np.float32)
            imag = rng.uniform(size=shape).astype(np.float32)
            return real + 1j * imag

        for shape in [(3,), (2, 3), (2, 3, 3)]:
            x_np = random_complex(shape)
            x_mx = mx.array(x_np)

            with self.subTest(shape=shape):
                out_np = np.linalg.norm(x_np)
                out_mx = mx.linalg.norm(x_mx)
                self._assert_norms_match(out_np, out_mx)

            # One axis (vector norm) or two axes (matrix norm); includes
            # reducing over all axes of 1-D and 2-D inputs.
            for num_axes in range(1, min(len(shape), 2) + 1):
                for axis in itertools.combinations(range(len(shape)), num_axes):
                    with self.subTest(shape=shape, axis=axis):
                        out_np = np.linalg.norm(x_np, axis=axis)
                        out_mx = mx.linalg.norm(x_mx, axis=axis)
                        self._assert_norms_match(out_np, out_mx)

        x_np = random_complex((4, 4))
        x_mx = mx.array(x_np)
        out_np = np.linalg.norm(x_np, ord="fro")
        out_mx = mx.linalg.norm(x_mx, ord="fro")
        self._assert_norms_match(out_np, out_mx)


if __name__ == "__main__":
    unittest.main()