// Copyright © 2023 Apple Inc. #pragma once #include #include "mlx/array.h" #include "mlx/device.h" #include "mlx/ops.h" #include "mlx/stream.h" namespace mlx::core::linalg { /** * Compute vector or matrix norms. * * - If axis and ord are both unspecified, computes the 2-norm of flatten(x). * - If axis is not provided but ord is, then x must be either 1D or 2D. * - If axis is provided, but ord is not, then the 2-norm (or Frobenius norm * for matrices) is computed along the given axes. At most 2 axes can be * specified. * - If both axis and ord are provided, then the corresponding matrix or vector * norm is computed. At most 2 axes can be specified. */ array norm( const array& a, const double ord, const std::optional>& axis = std::nullopt, bool keepdims = false, StreamOrDevice s = {}); inline array norm( const array& a, const double ord, int axis, bool keepdims = false, StreamOrDevice s = {}) { return norm(a, ord, std::vector{axis}, keepdims, s); } array norm( const array& a, const std::string& ord, const std::optional>& axis = std::nullopt, bool keepdims = false, StreamOrDevice s = {}); inline array norm( const array& a, const std::string& ord, int axis, bool keepdims = false, StreamOrDevice s = {}) { return norm(a, ord, std::vector{axis}, keepdims, s); } array norm( const array& a, const std::optional>& axis = std::nullopt, bool keepdims = false, StreamOrDevice s = {}); inline array norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) { return norm(a, std::vector{axis}, keepdims, s); } std::pair qr(const array& a, StreamOrDevice s = {}); std::vector svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */); inline std::vector svd(const array& a, StreamOrDevice s = {}) { return svd(a, true, s); } array inv(const array& a, StreamOrDevice s = {}); array tri_inv(const array& a, bool upper = false, StreamOrDevice s = {}); array cholesky(const array& a, bool upper = false, StreamOrDevice s = {}); array pinv(const array& a, StreamOrDevice s = {}); array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {}); std::vector lu(const array& a, StreamOrDevice s = {}); std::pair lu_factor(const array& a, StreamOrDevice s = {}); array solve(const array& a, const array& b, StreamOrDevice s = {}); array solve_triangular( const array& a, const array& b, bool upper = false, StreamOrDevice s = {}); /** * Compute the cross product of two arrays along the given axis. */ array cross( const array& a, const array& b, int axis = -1, StreamOrDevice s = {}); array eigvalsh(const array& a, std::string UPLO = "L", StreamOrDevice s = {}); std::pair eigh(const array& a, std::string UPLO = "L", StreamOrDevice s = {}); } // namespace mlx::core::linalg