mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	linalg.norm (#187)
* implemented vector_norm in cpp added linalg to mlx * implemented vector_norm python binding * renamed vector_norm to norm, implemented norm without provided ord * completed the implementation of the norm * added tests * removed unused import in linalg.cpp * updated python bindings * added some tests for python bindings * handling inf, -inf as numpy does, more extensive tests of compatibility with numpy * added better docs and examples * refactored mlx.linalg.norm bindings * reused existing util for implementation of linalg.norm * more tests * fixed a bug with no ord and axis provided * removed unused imports * some style and API consistency updates to linalg norm * remove unused includes * fix python tests * fixed a bug with frobenius norm of a complex-valued matrix * complex for vector too --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		 Gabrijel Boduljak
					Gabrijel Boduljak
				
			
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			 GitHub
						GitHub
					
				
			
						parent
						
							447bc089b9
						
					
				
				
					commit
					6b0d30bb85
				
			| @@ -11,6 +11,7 @@ pybind11_add_module( | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp | ||||
| ) | ||||
|  | ||||
| if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY) | ||||
|   | ||||
							
								
								
									
										180
									
								
								python/src/linalg.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										180
									
								
								python/src/linalg.cpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,180 @@ | ||||
| // Copyright © 2023 Apple Inc. | ||||
|  | ||||
| #include <variant> | ||||
|  | ||||
| #include <pybind11/pybind11.h> | ||||
| #include <pybind11/stl.h> | ||||
|  | ||||
| #include "mlx/linalg.h" | ||||
|  | ||||
| #include "python/src/load.h" | ||||
| #include "python/src/utils.h" | ||||
|  | ||||
| namespace py = pybind11; | ||||
| using namespace py::literals; | ||||
|  | ||||
| using namespace mlx::core; | ||||
| using namespace mlx::core::linalg; | ||||
|  | ||||
| void init_linalg(py::module_& parent_module) { | ||||
|   py::options options; | ||||
|   options.disable_function_signatures(); | ||||
|  | ||||
|   auto m = parent_module.def_submodule( | ||||
|       "linalg", "mlx.core.linalg: linear algebra routines."); | ||||
|  | ||||
|   m.def( | ||||
|       "norm", | ||||
|       [](const array& a, | ||||
|          const std::variant<std::monostate, int, double, std::string>& ord_, | ||||
|          const std::variant<std::monostate, int, std::vector<int>>& axis_, | ||||
|          const bool keepdims, | ||||
|          const StreamOrDevice stream) { | ||||
|         std::optional<std::vector<int>> axis = std::nullopt; | ||||
|         if (auto pv = std::get_if<int>(&axis_); pv) { | ||||
|           axis = std::vector<int>{*pv}; | ||||
|         } else if (auto pv = std::get_if<std::vector<int>>(&axis_); pv) { | ||||
|           axis = *pv; | ||||
|         } | ||||
|  | ||||
|         if (std::holds_alternative<std::monostate>(ord_)) { | ||||
|           return norm(a, axis, keepdims, stream); | ||||
|         } else { | ||||
|           if (auto pv = std::get_if<std::string>(&ord_); pv) { | ||||
|             return norm(a, *pv, axis, keepdims, stream); | ||||
|           } | ||||
|           double ord; | ||||
|           if (auto pv = std::get_if<int>(&ord_); pv) { | ||||
|             ord = *pv; | ||||
|           } else { | ||||
|             ord = std::get<double>(ord_); | ||||
|           } | ||||
|           return norm(a, ord, axis, keepdims, stream); | ||||
|         } | ||||
|       }, | ||||
|       "a"_a, | ||||
|       py::pos_only(), | ||||
|       "ord"_a = none, | ||||
|       "axis"_a = none, | ||||
|       "keepdims"_a = false, | ||||
|       py::kw_only(), | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array | ||||
|  | ||||
|         Matrix or vector norm. | ||||
|  | ||||
|         This function computes vector or  matrix norms depending on the value of | ||||
|         the ``ord`` and ``axis`` parameters. | ||||
|  | ||||
|         Args: | ||||
|           a (array): Input array.  If ``axis`` is ``None``, ``a`` must be 1-D or 2-D, | ||||
|             unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the | ||||
|             2-norm of ``a.flatten`` will be returned. | ||||
|           ord (scalar or str, optional): Order of the norm (see table under ``Notes``). | ||||
|             If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed | ||||
|             along the given ``axis``.  Default: ``None``. | ||||
|           axis (int or list(int), optional): If ``axis`` is an integer, it specifies the | ||||
|             axis of ``a`` along which to compute the vector norms.  If ``axis`` is a | ||||
|             2-tuple, it specifies the axes that hold 2-D matrices, and the matrix | ||||
|             norms of these matrices are computed. If `axis` is ``None`` then | ||||
|             either a vector norm (when ``a`` is 1-D) or a matrix norm (when ``a`` is | ||||
|             2-D) is returned. Default: ``None``. | ||||
|           keepdims (bool, optional): If ``True``, the axes which are normed over are | ||||
|             left in the result as dimensions with size one. Default ``False``. | ||||
|  | ||||
|         Returns: | ||||
|           array: The output containing the norm(s). | ||||
|  | ||||
|         Notes: | ||||
|           For values of ``ord < 1``, the result is, strictly speaking, not a | ||||
|           mathematical norm, but it may still be useful for various numerical | ||||
|           purposes. | ||||
|  | ||||
|           The following norms can be calculated: | ||||
|  | ||||
|           =====  ============================  ========================== | ||||
|           ord    norm for matrices             norm for vectors | ||||
|           =====  ============================  ========================== | ||||
|           None   Frobenius norm                2-norm | ||||
|           'fro'  Frobenius norm                -- | ||||
|           inf    max(sum(abs(x), axis=1))      max(abs(x)) | ||||
|           -inf   min(sum(abs(x), axis=1))      min(abs(x)) | ||||
|           0      --                            sum(x != 0) | ||||
|           1      max(sum(abs(x), axis=0))      as below | ||||
|           -1     min(sum(abs(x), axis=0))      as below | ||||
|           2      2-norm (largest sing. value)  as below | ||||
|           -2     smallest singular value       as below | ||||
|           other  --                            sum(abs(x)**ord)**(1./ord) | ||||
|           =====  ============================  ========================== | ||||
|  | ||||
|           .. warning:: | ||||
|             Nuclear norm and norms based on singular values are not yet implemented. | ||||
|  | ||||
|           The Frobenius norm is given by [1]_: | ||||
|  | ||||
|               :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}` | ||||
|  | ||||
|           The nuclear norm is the sum of the singular values. | ||||
|  | ||||
|           Both the Frobenius and nuclear norm orders are only defined for | ||||
|           matrices and raise a ``ValueError`` when ``a.ndim != 2``. | ||||
|  | ||||
|         References: | ||||
|           .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, | ||||
|                  Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 | ||||
|  | ||||
|         Examples: | ||||
|           >>> import mlx.core as mx | ||||
|           >>> from mlx.core import linalg as la | ||||
|           >>> a = mx.arange(9) - 4 | ||||
|           >>> a | ||||
|           array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) | ||||
|           >>> b = a.reshape((3,3)) | ||||
|           >>> b | ||||
|           array([[-4, -3, -2], | ||||
|                  [-1,  0,  1], | ||||
|                  [ 2,  3,  4]], dtype=int32) | ||||
|           >>> la.norm(a) | ||||
|           array(7.74597, dtype=float32) | ||||
|           >>> la.norm(b) | ||||
|           array(7.74597, dtype=float32) | ||||
|           >>> la.norm(b, 'fro') | ||||
|           array(7.74597, dtype=float32) | ||||
|           >>> la.norm(a, float("inf")) | ||||
|           array(4, dtype=float32) | ||||
|           >>> la.norm(b, float("inf")) | ||||
|           array(9, dtype=float32) | ||||
|           >>> la.norm(a, -float("inf")) | ||||
|           array(0, dtype=float32) | ||||
|           >>> la.norm(b, -float("inf")) | ||||
|           array(2, dtype=float32) | ||||
|           >>> la.norm(a, 1) | ||||
|           array(20, dtype=float32) | ||||
|           >>> la.norm(b, 1) | ||||
|           array(7, dtype=float32) | ||||
|           >>> la.norm(a, -1) | ||||
|           array(0, dtype=float32) | ||||
|           >>> la.norm(b, -1) | ||||
|           array(6, dtype=float32) | ||||
|           >>> la.norm(a, 2) | ||||
|           array(7.74597, dtype=float32) | ||||
|           >>> la.norm(a, 3) | ||||
|           array(5.84804, dtype=float32) | ||||
|           >>> la.norm(a, -3) | ||||
|           array(0, dtype=float32) | ||||
|           >>> c = mx.array([[ 1, 2, 3], | ||||
|           ...               [-1, 1, 4]]) | ||||
|           >>> la.norm(c, axis=0) | ||||
|           array([1.41421, 2.23607, 5], dtype=float32) | ||||
|           >>> la.norm(c, axis=1) | ||||
|           array([3.74166, 4.24264], dtype=float32) | ||||
|           >>> la.norm(c, ord=1, axis=1) | ||||
|           array([6, 6], dtype=float32) | ||||
|           >>> m = mx.arange(8).reshape(2,2,2) | ||||
|           >>> la.norm(m, axis=(1,2)) | ||||
|           array([3.74166, 11.225], dtype=float32) | ||||
|           >>> la.norm(m[0, :, :]), LA.norm(m[1, :, :]) | ||||
|           (array(3.74166, dtype=float32), array(11.225, dtype=float32)) | ||||
|       )pbdoc"); | ||||
| } | ||||
| @@ -15,6 +15,7 @@ void init_ops(py::module_&); | ||||
| void init_transforms(py::module_&); | ||||
| void init_random(py::module_&); | ||||
| void init_fft(py::module_&); | ||||
| void init_linalg(py::module_&); | ||||
|  | ||||
| PYBIND11_MODULE(core, m) { | ||||
|   m.doc() = "mlx: A framework for machine learning on Apple silicon."; | ||||
| @@ -29,5 +30,6 @@ PYBIND11_MODULE(core, m) { | ||||
|   init_transforms(m); | ||||
|   init_random(m); | ||||
|   init_fft(m); | ||||
|   init_linalg(m); | ||||
|   m.attr("__version__") = TOSTRING(_VERSION_); | ||||
| } | ||||
|   | ||||
							
								
								
									
										94
									
								
								python/tests/test_linalg.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								python/tests/test_linalg.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,94 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
|  | ||||
| import itertools | ||||
| import math | ||||
| import unittest | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx_tests | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
| class TestLinalg(mlx_tests.MLXTestCase): | ||||
|     def test_norm(self): | ||||
|         vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")] | ||||
|         matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")] | ||||
|  | ||||
|         for shape in [(3,), (2, 3), (2, 3, 3)]: | ||||
|             x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape) | ||||
|             x_np = np.arange(1, math.prod(shape) + 1).reshape(shape) | ||||
|             # Test when at least one axis is provided | ||||
|             for num_axes in range(1, len(shape)): | ||||
|                 if num_axes == 1: | ||||
|                     ords = vector_ords | ||||
|                 else: | ||||
|                     ords = matrix_ords | ||||
|                 for axis in itertools.combinations(range(len(shape)), num_axes): | ||||
|                     for keepdims in [True, False]: | ||||
|                         for o in ords: | ||||
|                             out_np = np.linalg.norm( | ||||
|                                 x_np, ord=o, axis=axis, keepdims=keepdims | ||||
|                             ) | ||||
|                             out_mx = mx.linalg.norm( | ||||
|                                 x_mx, ord=o, axis=axis, keepdims=keepdims | ||||
|                             ) | ||||
|                             with self.subTest( | ||||
|                                 shape=shape, ord=o, axis=axis, keepdims=keepdims | ||||
|                             ): | ||||
|                                 self.assertTrue( | ||||
|                                     np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) | ||||
|                                 ) | ||||
|  | ||||
|         # Test only ord provided | ||||
|         for shape in [(3,), (2, 3)]: | ||||
|             x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape) | ||||
|             x_np = np.arange(1, math.prod(shape) + 1).reshape(shape) | ||||
|             for o in [None, 1, -1, float("inf"), -float("inf")]: | ||||
|                 for keepdims in [True, False]: | ||||
|                     out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims) | ||||
|                     out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims) | ||||
|                     with self.subTest(shape=shape, ord=o, keepdims=keepdims): | ||||
|                         self.assertTrue( | ||||
|                             np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) | ||||
|                         ) | ||||
|  | ||||
|         # Test no ord and no axis provided | ||||
|         for shape in [(3,), (2, 3), (2, 3, 3)]: | ||||
|             x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape) | ||||
|             x_np = np.arange(1, math.prod(shape) + 1).reshape(shape) | ||||
|             for keepdims in [True, False]: | ||||
|                 out_np = np.linalg.norm(x_np, keepdims=keepdims) | ||||
|                 out_mx = mx.linalg.norm(x_mx, keepdims=keepdims) | ||||
|                 with self.subTest(shape=shape, keepdims=keepdims): | ||||
|                     self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) | ||||
|  | ||||
|     def test_complex_norm(self): | ||||
|         for shape in [(3,), (2, 3), (2, 3, 3)]: | ||||
|             x_np = np.random.uniform(size=shape).astype( | ||||
|                 np.float32 | ||||
|             ) + 1j * np.random.uniform(size=shape).astype(np.float32) | ||||
|             x_mx = mx.array(x_np) | ||||
|             out_np = np.linalg.norm(x_np) | ||||
|             out_mx = mx.linalg.norm(x_mx) | ||||
|             with self.subTest(shape=shape): | ||||
|                 self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) | ||||
|             for num_axes in range(1, len(shape)): | ||||
|                 for axis in itertools.combinations(range(len(shape)), num_axes): | ||||
|                     out_np = np.linalg.norm(x_np, axis=axis) | ||||
|                     out_mx = mx.linalg.norm(x_mx, axis=axis) | ||||
|                     with self.subTest(shape=shape, axis=axis): | ||||
|                         self.assertTrue( | ||||
|                             np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) | ||||
|                         ) | ||||
|  | ||||
|         x_np = np.random.uniform(size=(4, 4)).astype( | ||||
|             np.float32 | ||||
|         ) + 1j * np.random.uniform(size=(4, 4)).astype(np.float32) | ||||
|         x_mx = mx.array(x_np) | ||||
|         out_np = np.linalg.norm(x_np, ord="fro") | ||||
|         out_mx = mx.linalg.norm(x_mx, ord="fro") | ||||
|         self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
		Reference in New Issue
	
	Block a user