mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	linalg.norm (#187)
* implemented vector_norm in cpp added linalg to mlx * implemented vector_norm python binding * renamed vector_norm to norm, implemented norm without provided ord * completed the implementation of the norm * added tests * removed unused import in linalg.cpp * updated python bindings * added some tests for python bindings * handling inf, -inf as numpy does, more extensive tests of compatibility with numpy * added better docs and examples * refactored mlx.linalg.norm bindings * reused existing util for implementation of linalg.norm * more tests * fixed a bug with no ord and axis provided * removed unused imports * some style and API consistency updates to linalg norm * remove unused includes * fix python tests * fixed a bug with frobenius norm of a complex-valued matrix * complex for vector too --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							447bc089b9
						
					
				
				
					commit
					6b0d30bb85
				
			@@ -11,6 +11,7 @@ pybind11_add_module(
 | 
			
		||||
  ${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp
 | 
			
		||||
  ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp
 | 
			
		||||
  ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp
 | 
			
		||||
  ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										180
									
								
								python/src/linalg.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										180
									
								
								python/src/linalg.cpp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,180 @@
 | 
			
		||||
// Copyright © 2023 Apple Inc.
 | 
			
		||||
 | 
			
		||||
#include <variant>
 | 
			
		||||
 | 
			
		||||
#include <pybind11/pybind11.h>
 | 
			
		||||
#include <pybind11/stl.h>
 | 
			
		||||
 | 
			
		||||
#include "mlx/linalg.h"
 | 
			
		||||
 | 
			
		||||
#include "python/src/load.h"
 | 
			
		||||
#include "python/src/utils.h"
 | 
			
		||||
 | 
			
		||||
namespace py = pybind11;
 | 
			
		||||
using namespace py::literals;
 | 
			
		||||
 | 
			
		||||
using namespace mlx::core;
 | 
			
		||||
using namespace mlx::core::linalg;
 | 
			
		||||
 | 
			
		||||
void init_linalg(py::module_& parent_module) {
 | 
			
		||||
  py::options options;
 | 
			
		||||
  options.disable_function_signatures();
 | 
			
		||||
 | 
			
		||||
  auto m = parent_module.def_submodule(
 | 
			
		||||
      "linalg", "mlx.core.linalg: linear algebra routines.");
 | 
			
		||||
 | 
			
		||||
  m.def(
 | 
			
		||||
      "norm",
 | 
			
		||||
      [](const array& a,
 | 
			
		||||
         const std::variant<std::monostate, int, double, std::string>& ord_,
 | 
			
		||||
         const std::variant<std::monostate, int, std::vector<int>>& axis_,
 | 
			
		||||
         const bool keepdims,
 | 
			
		||||
         const StreamOrDevice stream) {
 | 
			
		||||
        std::optional<std::vector<int>> axis = std::nullopt;
 | 
			
		||||
        if (auto pv = std::get_if<int>(&axis_); pv) {
 | 
			
		||||
          axis = std::vector<int>{*pv};
 | 
			
		||||
        } else if (auto pv = std::get_if<std::vector<int>>(&axis_); pv) {
 | 
			
		||||
          axis = *pv;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (std::holds_alternative<std::monostate>(ord_)) {
 | 
			
		||||
          return norm(a, axis, keepdims, stream);
 | 
			
		||||
        } else {
 | 
			
		||||
          if (auto pv = std::get_if<std::string>(&ord_); pv) {
 | 
			
		||||
            return norm(a, *pv, axis, keepdims, stream);
 | 
			
		||||
          }
 | 
			
		||||
          double ord;
 | 
			
		||||
          if (auto pv = std::get_if<int>(&ord_); pv) {
 | 
			
		||||
            ord = *pv;
 | 
			
		||||
          } else {
 | 
			
		||||
            ord = std::get<double>(ord_);
 | 
			
		||||
          }
 | 
			
		||||
          return norm(a, ord, axis, keepdims, stream);
 | 
			
		||||
        }
 | 
			
		||||
      },
 | 
			
		||||
      "a"_a,
 | 
			
		||||
      py::pos_only(),
 | 
			
		||||
      "ord"_a = none,
 | 
			
		||||
      "axis"_a = none,
 | 
			
		||||
      "keepdims"_a = false,
 | 
			
		||||
      py::kw_only(),
 | 
			
		||||
      "stream"_a = none,
 | 
			
		||||
      R"pbdoc(
 | 
			
		||||
        norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array
 | 
			
		||||
 | 
			
		||||
        Matrix or vector norm.
 | 
			
		||||
 | 
			
		||||
        This function computes vector or  matrix norms depending on the value of
 | 
			
		||||
        the ``ord`` and ``axis`` parameters.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
          a (array): Input array.  If ``axis`` is ``None``, ``a`` must be 1-D or 2-D,
 | 
			
		||||
            unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the
 | 
			
		||||
            2-norm of ``a.flatten`` will be returned.
 | 
			
		||||
          ord (scalar or str, optional): Order of the norm (see table under ``Notes``).
 | 
			
		||||
            If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed
 | 
			
		||||
            along the given ``axis``.  Default: ``None``.
 | 
			
		||||
          axis (int or list(int), optional): If ``axis`` is an integer, it specifies the
 | 
			
		||||
            axis of ``a`` along which to compute the vector norms.  If ``axis`` is a
 | 
			
		||||
            2-tuple, it specifies the axes that hold 2-D matrices, and the matrix
 | 
			
		||||
            norms of these matrices are computed. If `axis` is ``None`` then
 | 
			
		||||
            either a vector norm (when ``a`` is 1-D) or a matrix norm (when ``a`` is
 | 
			
		||||
            2-D) is returned. Default: ``None``.
 | 
			
		||||
          keepdims (bool, optional): If ``True``, the axes which are normed over are
 | 
			
		||||
            left in the result as dimensions with size one. Default ``False``.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
          array: The output containing the norm(s).
 | 
			
		||||
 | 
			
		||||
        Notes:
 | 
			
		||||
          For values of ``ord < 1``, the result is, strictly speaking, not a
 | 
			
		||||
          mathematical norm, but it may still be useful for various numerical
 | 
			
		||||
          purposes.
 | 
			
		||||
 | 
			
		||||
          The following norms can be calculated:
 | 
			
		||||
 | 
			
		||||
          =====  ============================  ==========================
 | 
			
		||||
          ord    norm for matrices             norm for vectors
 | 
			
		||||
          =====  ============================  ==========================
 | 
			
		||||
          None   Frobenius norm                2-norm
 | 
			
		||||
          'fro'  Frobenius norm                --
 | 
			
		||||
          inf    max(sum(abs(x), axis=1))      max(abs(x))
 | 
			
		||||
          -inf   min(sum(abs(x), axis=1))      min(abs(x))
 | 
			
		||||
          0      --                            sum(x != 0)
 | 
			
		||||
          1      max(sum(abs(x), axis=0))      as below
 | 
			
		||||
          -1     min(sum(abs(x), axis=0))      as below
 | 
			
		||||
          2      2-norm (largest sing. value)  as below
 | 
			
		||||
          -2     smallest singular value       as below
 | 
			
		||||
          other  --                            sum(abs(x)**ord)**(1./ord)
 | 
			
		||||
          =====  ============================  ==========================
 | 
			
		||||
 | 
			
		||||
          .. warning::
 | 
			
		||||
            Nuclear norm and norms based on singular values are not yet implemented.
 | 
			
		||||
 | 
			
		||||
          The Frobenius norm is given by [1]_:
 | 
			
		||||
 | 
			
		||||
              :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
 | 
			
		||||
 | 
			
		||||
          The nuclear norm is the sum of the singular values.
 | 
			
		||||
 | 
			
		||||
          Both the Frobenius and nuclear norm orders are only defined for
 | 
			
		||||
          matrices and raise a ``ValueError`` when ``a.ndim != 2``.
 | 
			
		||||
 | 
			
		||||
        References:
 | 
			
		||||
          .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
 | 
			
		||||
                 Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
 | 
			
		||||
 | 
			
		||||
        Examples:
 | 
			
		||||
          >>> import mlx.core as mx
 | 
			
		||||
          >>> from mlx.core import linalg as la
 | 
			
		||||
          >>> a = mx.arange(9) - 4
 | 
			
		||||
          >>> a
 | 
			
		||||
          array([-4, -3, -2, ..., 2, 3, 4], dtype=int32)
 | 
			
		||||
          >>> b = a.reshape((3,3))
 | 
			
		||||
          >>> b
 | 
			
		||||
          array([[-4, -3, -2],
 | 
			
		||||
                 [-1,  0,  1],
 | 
			
		||||
                 [ 2,  3,  4]], dtype=int32)
 | 
			
		||||
          >>> la.norm(a)
 | 
			
		||||
          array(7.74597, dtype=float32)
 | 
			
		||||
          >>> la.norm(b)
 | 
			
		||||
          array(7.74597, dtype=float32)
 | 
			
		||||
          >>> la.norm(b, 'fro')
 | 
			
		||||
          array(7.74597, dtype=float32)
 | 
			
		||||
          >>> la.norm(a, float("inf"))
 | 
			
		||||
          array(4, dtype=float32)
 | 
			
		||||
          >>> la.norm(b, float("inf"))
 | 
			
		||||
          array(9, dtype=float32)
 | 
			
		||||
          >>> la.norm(a, -float("inf"))
 | 
			
		||||
          array(0, dtype=float32)
 | 
			
		||||
          >>> la.norm(b, -float("inf"))
 | 
			
		||||
          array(2, dtype=float32)
 | 
			
		||||
          >>> la.norm(a, 1)
 | 
			
		||||
          array(20, dtype=float32)
 | 
			
		||||
          >>> la.norm(b, 1)
 | 
			
		||||
          array(7, dtype=float32)
 | 
			
		||||
          >>> la.norm(a, -1)
 | 
			
		||||
          array(0, dtype=float32)
 | 
			
		||||
          >>> la.norm(b, -1)
 | 
			
		||||
          array(6, dtype=float32)
 | 
			
		||||
          >>> la.norm(a, 2)
 | 
			
		||||
          array(7.74597, dtype=float32)
 | 
			
		||||
          >>> la.norm(a, 3)
 | 
			
		||||
          array(5.84804, dtype=float32)
 | 
			
		||||
          >>> la.norm(a, -3)
 | 
			
		||||
          array(0, dtype=float32)
 | 
			
		||||
          >>> c = mx.array([[ 1, 2, 3],
 | 
			
		||||
          ...               [-1, 1, 4]])
 | 
			
		||||
          >>> la.norm(c, axis=0)
 | 
			
		||||
          array([1.41421, 2.23607, 5], dtype=float32)
 | 
			
		||||
          >>> la.norm(c, axis=1)
 | 
			
		||||
          array([3.74166, 4.24264], dtype=float32)
 | 
			
		||||
          >>> la.norm(c, ord=1, axis=1)
 | 
			
		||||
          array([6, 6], dtype=float32)
 | 
			
		||||
          >>> m = mx.arange(8).reshape(2,2,2)
 | 
			
		||||
          >>> la.norm(m, axis=(1,2))
 | 
			
		||||
          array([3.74166, 11.225], dtype=float32)
 | 
			
		||||
          >>> la.norm(m[0, :, :]), LA.norm(m[1, :, :])
 | 
			
		||||
          (array(3.74166, dtype=float32), array(11.225, dtype=float32))
 | 
			
		||||
      )pbdoc");
 | 
			
		||||
}
 | 
			
		||||
@@ -15,6 +15,7 @@ void init_ops(py::module_&);
 | 
			
		||||
void init_transforms(py::module_&);
 | 
			
		||||
void init_random(py::module_&);
 | 
			
		||||
void init_fft(py::module_&);
 | 
			
		||||
void init_linalg(py::module_&);
 | 
			
		||||
 | 
			
		||||
PYBIND11_MODULE(core, m) {
 | 
			
		||||
  m.doc() = "mlx: A framework for machine learning on Apple silicon.";
 | 
			
		||||
@@ -29,5 +30,6 @@ PYBIND11_MODULE(core, m) {
 | 
			
		||||
  init_transforms(m);
 | 
			
		||||
  init_random(m);
 | 
			
		||||
  init_fft(m);
 | 
			
		||||
  init_linalg(m);
 | 
			
		||||
  m.attr("__version__") = TOSTRING(_VERSION_);
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user