MLX
 
Loading...
Searching...
No Matches
linalg.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include <optional>
6
7#include "mlx/array.h"
8#include "mlx/device.h"
9#include "mlx/ops.h"
10#include "mlx/stream.h"
11
13
26 const array& a,
27 const double ord,
28 const std::optional<std::vector<int>>& axis = std::nullopt,
29 bool keepdims = false,
30 StreamOrDevice s = {});
31inline array norm(
32 const array& a,
33 const double ord,
34 int axis,
35 bool keepdims = false,
36 StreamOrDevice s = {}) {
37 return norm(a, ord, std::vector<int>{axis}, keepdims, s);
38}
40 const array& a,
41 const std::string& ord,
42 const std::optional<std::vector<int>>& axis = std::nullopt,
43 bool keepdims = false,
44 StreamOrDevice s = {});
45inline array norm(
46 const array& a,
47 const std::string& ord,
48 int axis,
49 bool keepdims = false,
50 StreamOrDevice s = {}) {
51 return norm(a, ord, std::vector<int>{axis}, keepdims, s);
52}
54 const array& a,
55 const std::optional<std::vector<int>>& axis = std::nullopt,
56 bool keepdims = false,
57 StreamOrDevice s = {});
58inline array
59norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
60 return norm(a, std::vector<int>{axis}, keepdims, s);
61}
62
63std::pair<array, array> qr(const array& a, StreamOrDevice s = {});
64
65std::vector<array> svd(const array& a, StreamOrDevice s = {});
66
67array inv(const array& a, StreamOrDevice s = {});
68
69array tri_inv(const array& a, bool upper = false, StreamOrDevice s = {});
70
71array cholesky(const array& a, bool upper = false, StreamOrDevice s = {});
72
73array pinv(const array& a, StreamOrDevice s = {});
74
75array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});
76
77std::vector<array> lu(const array& a, StreamOrDevice s = {});
78
79std::pair<array, array> lu_factor(const array& a, StreamOrDevice s = {});
80
81array solve(const array& a, const array& b, StreamOrDevice s = {});
82
84 const array& a,
85 const array& b,
86 bool upper = false,
87 StreamOrDevice s = {});
88
93 const array& a,
94 const array& b,
95 int axis = -1,
96 StreamOrDevice s = {});
97
98array eigvalsh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});
99
100std::pair<array, array>
101eigh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});
102
103} // namespace mlx::core::linalg
Definition array.h:24
Definition linalg.h:12
array eigvalsh(const array &a, std::string UPLO="L", StreamOrDevice s={})
std::pair< array, array > eigh(const array &a, std::string UPLO="L", StreamOrDevice s={})
array cholesky(const array &a, bool upper=false, StreamOrDevice s={})
array solve_triangular(const array &a, const array &b, bool upper=false, StreamOrDevice s={})
std::vector< array > svd(const array &a, StreamOrDevice s={})
array solve(const array &a, const array &b, StreamOrDevice s={})
std::vector< array > lu(const array &a, StreamOrDevice s={})
array tri_inv(const array &a, bool upper=false, StreamOrDevice s={})
array norm(const array &a, const double ord, const std::optional< std::vector< int > > &axis=std::nullopt, bool keepdims=false, StreamOrDevice s={})
Compute vector or matrix norms.
array cross(const array &a, const array &b, int axis=-1, StreamOrDevice s={})
Compute the cross product of two arrays along the given axis.
std::pair< array, array > lu_factor(const array &a, StreamOrDevice s={})
array inv(const array &a, StreamOrDevice s={})
array pinv(const array &a, StreamOrDevice s={})
std::pair< array, array > qr(const array &a, StreamOrDevice s={})
array cholesky_inv(const array &a, bool upper=false, StreamOrDevice s={})
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:15