From 67e319488cefd339a9c7e74b001672a90485e13b Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 26 Dec 2023 14:47:56 -0800 Subject: [PATCH] fix python tests --- mlx/linalg.cpp | 29 ++++++++----- mlx/linalg.h | 17 ++++---- python/src/linalg.cpp | 8 ++-- python/tests/test_linalg.py | 84 +++++++++++++++---------------------- 4 files changed, 65 insertions(+), 73 deletions(-) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 61c9e8537..9cce6cabb 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -4,9 +4,7 @@ #include #include -#include "mlx/array.h" #include "mlx/linalg.h" -#include "mlx/ops.h" namespace mlx::core::linalg { @@ -48,25 +46,36 @@ inline array matrix_norm( auto dtype = at_least_float(a.dtype()); auto row_axis = axis[0]; auto col_axis = axis[1]; - if (!keepdims && col_axis > row_axis && col_axis > 0) { - col_axis -= 1; - } if (ord == -1.0) { + col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0); return astype( min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s), dtype, s); } else if (ord == 1.0) { + col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0); return astype( max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s), dtype, s); + } else if (ord == std::numeric_limits::infinity()) { + row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0); + return astype( + max(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s), + dtype, + s); + } else if (ord == -std::numeric_limits::infinity()) { + row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0); + return astype( + min(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s), + dtype, + s); } else if (ord == 2.0 || ord == -2.0) { throw std::runtime_error( "[linalg::norm] Singular value norms are not implemented."); } else { std::ostringstream msg; - msg << "[linalg::norm] Invalid ord value " << ord << " for matrix norm"; + msg << "[linalg::norm] Invalid ord " << ord << " for matrix norm."; throw std::invalid_argument(msg.str()); } } @@ -78,13 +87,13 @@ inline array matrix_norm( bool keepdims, StreamOrDevice s) { if (ord == "f" || ord == "fro") { - return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s)); + return sqrt(sum(square(a, s), axis, keepdims, s), s); } else if (ord == "nuc") { throw std::runtime_error( "[linalg::norm] Nuclear norm not yet implemented."); } else { std::ostringstream msg; - msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm"; + msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm."; throw std::invalid_argument(msg.str()); } } @@ -100,7 +109,7 @@ array norm( if (axis.value().size() > 2) { throw std::invalid_argument( - "[linalg::norm] Received too many axes for norm"); + "[linalg::norm] Received too many axes for norm."); } return sqrt(sum(square(a, s), axis.value(), keepdims, s), s); } @@ -124,7 +133,7 @@ array norm( return matrix_norm(a, ord, ax, keepdims, s); } else { throw std::invalid_argument( - "[linalg::norm] Received too many axes for norm"); + "[linalg::norm] Received too many axes for norm."); } } diff --git a/mlx/linalg.h b/mlx/linalg.h index bf3b5e78c..80e484eb5 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -4,21 +4,22 @@ #include -#include "array.h" -#include "device.h" -#include "ops.h" -#include "stream.h" +#include "mlx/array.h" +#include "mlx/device.h" +#include "mlx/ops.h" +#include "mlx/stream.h" namespace mlx::core::linalg { -/* +/** * Compute vector or matrix norms. * * - If axis and ord are both unspecified, computes the 2-norm of flatten(x). * - If axis is not provided but ord is, then x must be either 1D or 2D. - * - If axis is provided, but ord is not, then the 2-norm is computed along the - * given axes. At most 2 axes can be specified. - * - If both axis and ord are provided, then the corresponding matrix of vector + * - If axis is provided, but ord is not, then the 2-norm (or Frobenius norm + * for matrices) is computed along the given axes. At most 2 axes can be + * specified. + * - If both axis and ord are provided, then the corresponding matrix or vector * norm is computed. At most 2 axes can be specified. */ array norm( diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index c193060db..ea5474a70 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -20,8 +20,8 @@ void init_linalg(py::module_& parent_module) { py::options options; options.disable_function_signatures(); - auto m = - parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra."); + auto m = parent_module.def_submodule( + "linalg", "mlx.core.linalg: linear algebra routines."); m.def( "norm", @@ -72,8 +72,8 @@ void init_linalg(py::module_& parent_module) { unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the 2-norm of ``a.flatten`` will be returned. ord (scalar or str, optional): Order of the norm (see table under ``Notes``). - If ``None``, the 2-norm will be computed along the given ``axis``. - Default: ``None``. + If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed + along the given ``axis``. Default: ``None``. axis (int or list(int), optional): If ``axis`` is an integer, it specifies the axis of ``a`` along which to compute the vector norms. If ``axis`` is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index ce1926de0..08a4510c8 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -11,74 +11,56 @@ import numpy as np class TestLinalg(mlx_tests.MLXTestCase): def test_norm(self): - vector_ords = [None, 0.5, 0, 1, 2, 3, -1, 1, float("inf"), -float("inf")] + vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")] matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")] for shape in [(3,), (2, 3), (2, 3, 3)]: - x_mx = mx.arange(math.prod(shape)).reshape(shape) - x_np = np.arange(math.prod(shape)).reshape(shape) + x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape) + x_np = np.arange(1, math.prod(shape) + 1).reshape(shape) # Test when at least one axis is provided for num_axes in range(1, len(shape)): - for axis in itertools.combinations(range(len(shape)), num_axes): - if num_axes == 1: - ords = vector_ords - else: - ords = matrix_ords - for keepdims in [True, False]: - # Test axis provided, no ord provided - out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims) - out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims) - assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) - # Test both ord and axis provided - for o in ords: - for keepdims in [True, False]: - if o: - out_np = np.linalg.norm( - x_np, ord=o, axis=axis, keepdims=keepdims - ) - out_mx = mx.linalg.norm( - x_mx, ord=o, axis=axis, keepdims=keepdims - ) - else: - out_np = np.linalg.norm( - x_np, axis=axis, keepdims=keepdims - ) - out_mx = mx.linalg.norm( - x_mx, axis=axis, keepdims=keepdims - ) - assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) - - # Test only axis provided - for shape in [(3,), (2, 3), (2, 3, 3)]: - x_mx = mx.arange(math.prod(shape)).reshape(shape) - x_np = np.arange(math.prod(shape)).reshape(shape) - - for num_axes in range(1, len(shape)): + if num_axes == 1: + ords = vector_ords + else: + ords = matrix_ords for axis in itertools.combinations(range(len(shape)), num_axes): for keepdims in [True, False]: - out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims) - out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims) - assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) + for o in ords: + out_np = np.linalg.norm( + x_np, ord=o, axis=axis, keepdims=keepdims + ) + out_mx = mx.linalg.norm( + x_mx, ord=o, axis=axis, keepdims=keepdims + ) + with self.subTest( + shape=shape, ord=o, axis=axis, keepdims=keepdims + ): + self.assertTrue( + np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) + ) # Test only ord provided for shape in [(3,), (2, 3)]: - x_mx = mx.arange(math.prod(shape)).reshape(shape) - x_np = np.arange(math.prod(shape)).reshape(shape) + x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape) + x_np = np.arange(1, math.prod(shape) + 1).reshape(shape) for o in [None, 1, -1, float("inf"), -float("inf")]: for keepdims in [True, False]: out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims) out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims) - assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) + with self.subTest(shape=shape, ord=o, keepdims=keepdims): + self.assertTrue( + np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) + ) # Test no ord and no axis provided for shape in [(3,), (2, 3), (2, 3, 3)]: - x_mx = mx.arange(math.prod(shape)).reshape(shape) - x_np = np.arange(math.prod(shape)).reshape(shape) - for o in [None, 1, -1, float("inf"), -float("inf")]: - for keepdims in [True, False]: - out_np = np.linalg.norm(x_np, keepdims=keepdims) - out_mx = mx.linalg.norm(x_mx, keepdims=keepdims) - assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) + x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape) + x_np = np.arange(1, math.prod(shape) + 1).reshape(shape) + for keepdims in [True, False]: + out_np = np.linalg.norm(x_np, keepdims=keepdims) + out_mx = mx.linalg.norm(x_mx, keepdims=keepdims) + with self.subTest(shape=shape, keepdims=keepdims): + self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) if __name__ == "__main__":