MLX
 
Loading...
Searching...
No Matches
fft.h
Go to the documentation of this file.
1// Copyright © 2023 Apple Inc.
2
3#pragma once
4
5#include <variant>
6
7#include "array.h"
8#include "device.h"
9#include "utils.h"
10
11namespace mlx::core::fft {
12
15 const array& a,
16 const Shape& n,
17 const std::vector<int>& axes,
18 StreamOrDevice s = {});
19array fftn(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});
20array fftn(const array& a, StreamOrDevice s = {});
21
24 const array& a,
25 const Shape& n,
26 const std::vector<int>& axes,
27 StreamOrDevice s = {});
29 const array& a,
30 const std::vector<int>& axes,
31 StreamOrDevice s = {});
32array ifftn(const array& a, StreamOrDevice s = {});
33
35inline array fft(const array& a, int n, int axis, StreamOrDevice s = {}) {
36 return fftn(a, {n}, {axis}, s);
37}
38inline array fft(const array& a, int axis = -1, StreamOrDevice s = {}) {
39 return fftn(a, {axis}, s);
40}
41
43inline array ifft(const array& a, int n, int axis, StreamOrDevice s = {}) {
44 return ifftn(a, {n}, {axis}, s);
45}
46inline array ifft(const array& a, int axis = -1, StreamOrDevice s = {}) {
47 return ifftn(a, {axis}, s);
48}
49
51inline array fft2(
52 const array& a,
53 const Shape& n,
54 const std::vector<int>& axes,
55 StreamOrDevice s = {}) {
56 return fftn(a, n, axes, s);
57}
58inline array fft2(
59 const array& a,
60 const std::vector<int>& axes = {-2, -1},
61 StreamOrDevice s = {}) {
62 return fftn(a, axes, s);
63}
64
66inline array ifft2(
67 const array& a,
68 const Shape& n,
69 const std::vector<int>& axes,
70 StreamOrDevice s = {}) {
71 return ifftn(a, n, axes, s);
72}
73inline array ifft2(
74 const array& a,
75 const std::vector<int>& axes = {-2, -1},
76 StreamOrDevice s = {}) {
77 return ifftn(a, axes, s);
78}
79
82 const array& a,
83 const Shape& n,
84 const std::vector<int>& axes,
85 StreamOrDevice s = {});
87 const array& a,
88 const std::vector<int>& axes,
89 StreamOrDevice s = {});
90array rfftn(const array& a, StreamOrDevice s = {});
91
94 const array& a,
95 const Shape& n,
96 const std::vector<int>& axes,
97 StreamOrDevice s = {});
99 const array& a,
100 const std::vector<int>& axes,
101 StreamOrDevice s = {});
103
105inline array rfft(const array& a, int n, int axis, StreamOrDevice s = {}) {
106 return rfftn(a, {n}, {axis}, s);
107}
108inline array rfft(const array& a, int axis = -1, StreamOrDevice s = {}) {
109 return rfftn(a, {axis}, s);
110}
111
112inline array irfft(const array& a, int n, int axis, StreamOrDevice s = {}) {
113 return irfftn(a, {n}, {axis}, s);
114}
115inline array irfft(const array& a, int axis = -1, StreamOrDevice s = {}) {
116 return irfftn(a, {axis}, s);
117}
118
121 const array& a,
122 const Shape& n,
123 const std::vector<int>& axes,
124 StreamOrDevice s = {}) {
125 return rfftn(a, n, axes, s);
126}
128 const array& a,
129 const std::vector<int>& axes = {-2, -1},
130 StreamOrDevice s = {}) {
131 return rfftn(a, axes, s);
132}
133
136 const array& a,
137 const Shape& n,
138 const std::vector<int>& axes,
139 StreamOrDevice s = {}) {
140 return irfftn(a, n, axes, s);
141}
143 const array& a,
144 const std::vector<int>& axes = {-2, -1},
145 StreamOrDevice s = {}) {
146 return irfftn(a, axes, s);
147}
148
149} // namespace mlx::core::fft
Definition array.h:24
Definition fft.h:11
array fftn(const array &a, const Shape &n, const std::vector< int > &axes, StreamOrDevice s={})
Compute the n-dimensional Fourier Transform.
array irfftn(const array &a, const Shape &n, const std::vector< int > &axes, StreamOrDevice s={})
Compute the n-dimensional inverse of rfftn.
array fft2(const array &a, const Shape &n, const std::vector< int > &axes, StreamOrDevice s={})
Compute the two-dimensional Fourier Transform.
Definition fft.h:51
array ifft(const array &a, int n, int axis, StreamOrDevice s={})
Compute the one-dimensional inverse Fourier Transform.
Definition fft.h:43
array rfft2(const array &a, const Shape &n, const std::vector< int > &axes, StreamOrDevice s={})
Compute the two-dimensional Fourier Transform on a real input.
Definition fft.h:120
array ifft2(const array &a, const Shape &n, const std::vector< int > &axes, StreamOrDevice s={})
Compute the two-dimensional inverse Fourier Transform.
Definition fft.h:66
array rfft(const array &a, int n, int axis, StreamOrDevice s={})
Compute the one-dimensional Fourier Transform on a real input.
Definition fft.h:105
array irfft(const array &a, int n, int axis, StreamOrDevice s={})
Compute the one-dimensional inverse of rfft.
Definition fft.h:112
array rfftn(const array &a, const Shape &n, const std::vector< int > &axes, StreamOrDevice s={})
Compute the n-dimensional Fourier Transform on a real input.
array fft(const array &a, int n, int axis, StreamOrDevice s={})
Compute the one-dimensional Fourier Transform.
Definition fft.h:35
array irfft2(const array &a, const Shape &n, const std::vector< int > &axes, StreamOrDevice s={})
Compute the two-dimensional inverse of rfft2.
Definition fft.h:135
array ifftn(const array &a, const Shape &n, const std::vector< int > &axes, StreamOrDevice s={})
Compute the n-dimensional inverse Fourier Transform.
std::vector< ShapeElem > Shape
Definition array.h:21
std::variant< std::monostate, Stream, Device > StreamOrDevice
Definition utils.h:15