mlx/mlx/linalg.h
Kashif Rasul 3ddc07e936
Eigenvalues and eigenvectors (#1334)
* initial eigvalsh

* add compute_vectors

* add compute_vectors_

* return a pair

* add eigh to return only eigenvectors

* fixed typo

* merge Eigvalsh and Eigh into a single primitive

* use the same primitive with the flag

* fix primitives

* use MULTI

* fix eval_gpu

* fix declaration

* rename EighPrimitive to Eigh

* tests

* tests

* fix rebase and format

* cleanup lapack

* format

* add cblas.h

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-10-22 12:18:48 -07:00

92 lines
2.5 KiB
C++

// Copyright © 2023 Apple Inc.
#pragma once
#include <optional>
#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/ops.h"
#include "mlx/stream.h"
namespace mlx::core::linalg {

// ---------------------------------------------------------------------------
// Norms
// ---------------------------------------------------------------------------

/**
 * Vector or matrix norm of `a`, with a numeric `ord`.
 *
 * How `axis` and `ord` interact:
 * - Neither given: the 2-norm of flatten(a).
 * - `ord` given, `axis` not: `a` must be 1D or 2D.
 * - `axis` given, `ord` not: the 2-norm (the Frobenius norm for matrices)
 *   along the given axes. At most 2 axes may be given.
 * - Both given: the corresponding vector or matrix norm. At most 2 axes may
 *   be given.
 *
 * Overload pitfall: an `int` overload `norm(a, axis, keepdims, s)` is
 * declared below, so pass `ord` as a floating-point value.
 * - `norm(a, 2)` binds to the `int axis` overload (axis 2), not to ord 2.
 * - `norm(a, 2, 0)` is ambiguous between `(double ord, int axis)` and
 *   `(int axis, bool keepdims)`.
 * Write `norm(a, 2.0)` or `norm(a, 2.0, 0)` instead.
 */
array norm(
    const array& a,
    const double ord,
    const std::optional<std::vector<int>>& axis = std::nullopt,
    bool keepdims = false,
    StreamOrDevice s = {});

/**
 * Single-axis form of the numeric-`ord` norm.
 *
 * Wraps `axis` in a one-element axis list and forwards.
 */
inline array norm(
    const array& a,
    const double ord,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {}) {
  const std::vector<int> axes{axis};
  return norm(a, ord, axes, keepdims, s);
}

/**
 * Vector or matrix norm of `a`, with `ord` given by name.
 *
 * The `axis` rules match the numeric-`ord` overload. Which names are
 * accepted is decided by the implementation, not by this header.
 */
array norm(
    const array& a,
    const std::string& ord,
    const std::optional<std::vector<int>>& axis = std::nullopt,
    bool keepdims = false,
    StreamOrDevice s = {});

/**
 * Single-axis form of the named-`ord` norm.
 *
 * Wraps `axis` in a one-element axis list and forwards.
 */
inline array norm(
    const array& a,
    const std::string& ord,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {}) {
  const std::vector<int> axes{axis};
  return norm(a, ord, axes, keepdims, s);
}

/**
 * Norm of `a` without an `ord`.
 *
 * With no `axis`, this is the 2-norm of flatten(a). With `axis`, it is the
 * 2-norm (Frobenius for matrices) over at most 2 axes.
 */
array norm(
    const array& a,
    const std::optional<std::vector<int>>& axis = std::nullopt,
    bool keepdims = false,
    StreamOrDevice s = {});

/**
 * Single-axis form of the `ord`-less norm.
 *
 * An integer second argument selects this overload, so `norm(a, 2)` means
 * "axis 2", not "ord 2".
 */
inline array
norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
  const std::vector<int> axes{axis};
  return norm(a, axes, keepdims, s);
}

// ---------------------------------------------------------------------------
// Factorizations
// ---------------------------------------------------------------------------

/**
 * QR factorization of `a`.
 *
 * Returns two arrays. NOTE(review): presumably (Q, R) in that order;
 * confirm against the implementation.
 */
std::pair<array, array> qr(const array& a, StreamOrDevice s = {});

/**
 * Singular value decomposition of `a`.
 *
 * Returns the factors as a vector of arrays. NOTE(review): presumably
 * {U, S, Vt}; confirm the count and order against the implementation.
 */
std::vector<array> svd(const array& a, StreamOrDevice s = {});

/**
 * Cholesky factorization of `a`.
 *
 * `upper` defaults to false. NOTE(review): presumably true selects the
 * upper-triangular factor and false the lower one; confirm.
 */
array cholesky(const array& a, bool upper = false, StreamOrDevice s = {});

// ---------------------------------------------------------------------------
// Inverses
// ---------------------------------------------------------------------------

/** Inverse of the matrix `a`. */
array inv(const array& a, StreamOrDevice s = {});

/**
 * Inverse of a triangular matrix `a`.
 *
 * `upper` defaults to false. NOTE(review): presumably it states whether `a`
 * is upper (true) or lower (false) triangular; confirm.
 */
array tri_inv(const array& a, bool upper = false, StreamOrDevice s = {});

/**
 * Pseudo-inverse of `a`.
 *
 * NOTE(review): presumably the Moore-Penrose inverse; confirm.
 */
array pinv(const array& a, StreamOrDevice s = {});

/**
 * Inverse obtained through a Cholesky factorization.
 *
 * NOTE(review): confirm against the implementation whether `a` is the
 * matrix itself or its Cholesky factor, and what `upper` (default false)
 * selects.
 */
array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});

// ---------------------------------------------------------------------------
// Products
// ---------------------------------------------------------------------------

/**
 * Cross product of `a` and `b` along `axis` (default -1).
 */
array cross(
    const array& a,
    const array& b,
    int axis = -1,
    StreamOrDevice s = {});

// ---------------------------------------------------------------------------
// Eigen-decomposition
// ---------------------------------------------------------------------------

/**
 * Eigenvalues of `a`.
 *
 * NOTE(review): the `h` suffix suggests `a` must be symmetric or Hermitian;
 * confirm.
 *
 * `UPLO` (default "L") names the triangle of `a` that is read.
 * NOTE(review): presumably LAPACK-style, "L" for lower and "U" for upper;
 * confirm.
 */
array eigvalsh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});

/**
 * Eigenvalues and eigenvectors of `a`.
 *
 * Uses the same `UPLO` argument as eigvalsh (default "L"). Returns a pair of
 * arrays. NOTE(review): presumably (eigenvalues, eigenvectors); confirm the
 * order.
 */
std::pair<array, array>
eigh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});

} // namespace mlx::core::linalg