MLX
Loading...
Searching...
No Matches
lapack.h
Go to the documentation of this file.
1
// Copyright © 2023-2024 Apple Inc.
2
3
#pragma once
4
5
// Visual Studio builds must map the LAPACK complex types onto std::complex.
// These definitions precede the BLAS/LAPACK header includes so that those
// headers see LAPACK_COMPLEX_CUSTOM and use the types defined here. See:
// https://github.com/OpenMathLib/OpenBLAS/blob/develop/docs/install.md
#if defined(_MSC_VER)
#include <complex>
#define LAPACK_COMPLEX_CUSTOM
#define lapack_complex_double std::complex<double>
#define lapack_complex_float std::complex<float>
#endif
13
14
#include <type_traits>
#include <utility>

#ifdef MLX_USE_ACCELERATE
#include <Accelerate/Accelerate.h>
#else
#include <cblas.h>
#include <lapack.h>
#endif
20
21
// Maps a LAPACK routine name to the identifier actually called.
//
// Starting with LAPACK 3.9.1, routines that take `char*` arguments also take
// string-length arguments. Headers that define LAPACK_GLOBAL or LAPACK_NAME
// provide `LAPACK_<name>` entry points (e.g. `LAPACK_sgeqrf`), and those are
// used when available. Otherwise the conventional trailing-underscore symbol
// (e.g. `sgeqrf_`) is called directly. OpenCV applies the same workaround:
// https://github.com/opencv/opencv/blob/1eb061f89de0fb85c4c75a2deeb0f61a961a63ad/cmake/OpenCVFindLAPACK.cmake#L57
#if !defined(LAPACK_GLOBAL) && !defined(LAPACK_NAME)
#define MLX_LAPACK_FUNC(f) f##_
#else
#define MLX_LAPACK_FUNC(f) LAPACK_##f
#endif
34
35
// Defines a function template `FUNC<T>(args...)` that dispatches on the
// element type T to the LAPACK routine with the matching precision prefix:
//   FUNC<float>(...)  -> MLX_LAPACK_FUNC(s<FUNC>)(...)
//   FUNC<double>(...) -> MLX_LAPACK_FUNC(d<FUNC>)(...)
// For example, `potrf<float>(...)` calls the `spotrf` routine.
//
// Arguments are perfect-forwarded unchanged, so callers must pass exactly what
// the underlying LAPACK routine expects. Any other T is rejected at compile
// time instead of silently compiling to a call that does nothing (and leaves
// outputs such as `info` unwritten).
//
// Note: no `//` comments may appear inside the macro body below, because a
// trailing backslash would splice the next line into the comment.
#define INSTANTIATE_LAPACK_TYPES(FUNC)                                \
  template <typename T, typename... Args>                             \
  void FUNC(Args&&... args) {                                         \
    static_assert(                                                    \
        std::is_same_v<T, float> || std::is_same_v<T, double>,        \
        "LAPACK wrapper '" #FUNC "' only supports float and double"); \
    if constexpr (std::is_same_v<T, float>) {                         \
      MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...);          \
    } else {                                                          \
      MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...);          \
    }                                                                 \
  }
44
45
// Generate the float/double dispatch templates for each LAPACK routine used
// by the CPU backend, grouped by the factorization they belong to.

// QR: Householder factorization, then explicit formation of Q.
INSTANTIATE_LAPACK_TYPES(geqrf)
INSTANTIATE_LAPACK_TYPES(orgqr)

// Symmetric eigendecomposition (divide-and-conquer driver).
INSTANTIATE_LAPACK_TYPES(syevd)

// Cholesky factorization.
INSTANTIATE_LAPACK_TYPES(potrf)

// Singular value decomposition (selected singular values/vectors).
INSTANTIATE_LAPACK_TYPES(gesvdx)

// LU factorization and the inverse computed from its factors.
INSTANTIATE_LAPACK_TYPES(getrf)
INSTANTIATE_LAPACK_TYPES(getri)

// Inverse of a triangular matrix.
INSTANTIATE_LAPACK_TYPES(trtri)
lapack.h
syevd
void syevd(Args... args)
Definition
lapack.h:47
getri
void getri(Args... args)
Definition
lapack.h:51
getrf
void getrf(Args... args)
Definition
lapack.h:50
trtri
void trtri(Args... args)
Definition
lapack.h:52
INSTANTIATE_LAPACK_TYPES
#define INSTANTIATE_LAPACK_TYPES(FUNC)
Definition
lapack.h:35
potrf
void potrf(Args... args)
Definition
lapack.h:48
orgqr
void orgqr(Args... args)
Definition
lapack.h:46
geqrf
void geqrf(Args... args)
Definition
lapack.h:45
gesvdx
void gesvdx(Args... args)
Definition
lapack.h:49
mlx
backend
cpu
lapack.h
Generated by
1.13.2