MLX
 
Loading...
Searching...
No Matches
linalg.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include <optional>
6
7#include "mlx/array.h"
8#include "mlx/device.h"
9#include "mlx/ops.h"
10#include "mlx/stream.h"
11
13
26 const array& a,
27 const double ord,
28 const std::optional<std::vector<int>>& axis = std::nullopt,
29 bool keepdims = false,
30 StreamOrDevice s = {});
31inline array norm(
32 const array& a,
33 const double ord,
34 int axis,
35 bool keepdims = false,
36 StreamOrDevice s = {}) {
37 return norm(a, ord, std::vector<int>{axis}, keepdims, s);
38}
40 const array& a,
41 const std::string& ord,
42 const std::optional<std::vector<int>>& axis = std::nullopt,
43 bool keepdims = false,
44 StreamOrDevice s = {});
45inline array norm(
46 const array& a,
47 const std::string& ord,
48 int axis,
49 bool keepdims = false,
50 StreamOrDevice s = {}) {
51 return norm(a, ord, std::vector<int>{axis}, keepdims, s);
52}
54 const array& a,
55 const std::optional<std::vector<int>>& axis = std::nullopt,
56 bool keepdims = false,
57 StreamOrDevice s = {});
58inline array
59norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
60 return norm(a, std::vector<int>{axis}, keepdims, s);
61}
62
/**
 * Declaration only: takes an input array `a` and an optional stream/device
 * selector `s`, and returns a pair of arrays.
 *
 * NOTE(review): presumably the Q and R factors of a QR factorization of `a`,
 * in that order -- confirm against the implementation.
 */
63std::pair<array, array> qr(const array& a, StreamOrDevice s = {});
64
/**
 * Declaration only: takes an input array `a` and an optional stream/device
 * selector `s`, and returns a vector of arrays.
 *
 * NOTE(review): presumably the components of a singular value decomposition
 * of `a`; their number and order are not visible here -- confirm against the
 * implementation.
 */
65std::vector<array> svd(const array& a, StreamOrDevice s = {});
66
/**
 * Declaration only: takes an input array `a` and an optional stream/device
 * selector `s`, and returns a single array.
 *
 * NOTE(review): presumably the matrix inverse of `a` -- confirm against the
 * implementation.
 */
67array inv(const array& a, StreamOrDevice s = {});
68
/**
 * Declaration only: takes an input array `a`, a flag `upper` (default false)
 * and an optional stream/device selector `s`, and returns a single array.
 *
 * NOTE(review): presumably the inverse of a triangular matrix, with `upper`
 * choosing between the upper and lower triangle -- confirm against the
 * implementation.
 */
69array tri_inv(const array& a, bool upper = false, StreamOrDevice s = {});
70
/**
 * Declaration only: takes an input array `a`, a flag `upper` (default false)
 * and an optional stream/device selector `s`, and returns a single array.
 *
 * NOTE(review): presumably the Cholesky factor of `a`, with `upper` selecting
 * the upper- rather than lower-triangular factor -- confirm against the
 * implementation.
 */
71array cholesky(const array& a, bool upper = false, StreamOrDevice s = {});
72
/**
 * Declaration only: takes an input array `a` and an optional stream/device
 * selector `s`, and returns a single array.
 *
 * NOTE(review): presumably the pseudo-inverse of `a` -- confirm against the
 * implementation.
 */
73array pinv(const array& a, StreamOrDevice s = {});
74
/**
 * Declaration only: takes an input array `a`, a flag `upper` (default false)
 * and an optional stream/device selector `s`, and returns a single array.
 *
 * NOTE(review): presumably computes an inverse from a Cholesky factor passed
 * as `a`, with `upper` stating which triangle that factor occupies -- confirm
 * against the implementation.
 */
75array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});
76
81 const array& a,
82 const array& b,
83 int axis = -1,
84 StreamOrDevice s = {});
85
/**
 * Declaration only: takes an input array `a`, a string `UPLO` (default "L",
 * passed by value) and an optional stream/device selector `s`, and returns a
 * single array.
 *
 * NOTE(review): presumably the eigenvalues of a symmetric/Hermitian matrix,
 * with `UPLO` selecting which triangle of `a` is read -- confirm against the
 * implementation.
 */
86array eigvalsh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});
87
/**
 * Declaration only: takes an input array `a`, a string `UPLO` (default "L",
 * passed by value) and an optional stream/device selector `s`, and returns a
 * pair of arrays.
 *
 * NOTE(review): presumably the eigenvalues and eigenvectors of a
 * symmetric/Hermitian matrix, with `UPLO` selecting which triangle of `a` is
 * read; the order of the pair is not visible here -- confirm against the
 * implementation.
 */
88std::pair<array, array>
89eigh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});
90
91} // namespace mlx::core::linalg
Definition array.h:24
Definition linalg.h:12
array eigvalsh(const array &a, std::string UPLO="L", StreamOrDevice s={})
std::pair< array, array > eigh(const array &a, std::string UPLO="L", StreamOrDevice s={})
array cholesky(const array &a, bool upper=false, StreamOrDevice s={})
std::vector< array > svd(const array &a, StreamOrDevice s={})
array tri_inv(const array &a, bool upper=false, StreamOrDevice s={})
array norm(const array &a, const double ord, const std::optional< std::vector< int > > &axis=std::nullopt, bool keepdims=false, StreamOrDevice s={})
Compute vector or matrix norms.
array cross(const array &a, const array &b, int axis=-1, StreamOrDevice s={})
Compute the cross product of two arrays along the given axis.
array inv(const array &a, StreamOrDevice s={})
array pinv(const array &a, StreamOrDevice s={})
std::pair< array, array > qr(const array &a, StreamOrDevice s={})
array cholesky_inv(const array &a, bool upper=false, StreamOrDevice s={})
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:15