mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Added bindings and removed python implementation
This commit is contained in:
@@ -459,4 +459,100 @@ void init_fft(nb::module_& parent_module) {
|
||||
Returns:
|
||||
array: The real array containing the inverse of :func:`rfftn`.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"stft",
|
||||
[](const mx::array& x,
|
||||
int n_fft = 2048,
|
||||
int hop_length = -1,
|
||||
int win_length = -1,
|
||||
const mx::array& window = mx::array(),
|
||||
bool center = true,
|
||||
const std::string& pad_mode = "reflect",
|
||||
bool normalized = false,
|
||||
bool onesided = true,
|
||||
mx::StreamOrDevice s = {}) {
|
||||
return mx::stft::stft(
|
||||
x,
|
||||
n_fft,
|
||||
hop_length,
|
||||
win_length,
|
||||
window,
|
||||
center,
|
||||
pad_mode,
|
||||
normalized,
|
||||
onesided,
|
||||
s);
|
||||
},
|
||||
"x"_a,
|
||||
"n_fft"_a = 2048,
|
||||
"hop_length"_a = -1,
|
||||
"win_length"_a = -1,
|
||||
"window"_a = mx::array(),
|
||||
"center"_a = true,
|
||||
"pad_mode"_a = "reflect",
|
||||
"normalized"_a = false,
|
||||
"onesided"_a = true,
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Short-time Fourier Transform (STFT).
|
||||
|
||||
Args:
|
||||
x (array): Input signal.
|
||||
n_fft (int, optional): Number of FFT points. Default is 2048.
|
||||
hop_length (int, optional): Number of samples between successive frames. Default is `n_fft // 4`.
|
||||
win_length (int, optional): Window size. Default is `n_fft`.
|
||||
window (array, optional): Window function. Default is a rectangular window.
|
||||
center (bool, optional): Whether to pad the signal to center the frames. Default is True.
|
||||
pad_mode (str, optional): Padding mode. Default is "reflect".
|
||||
normalized (bool, optional): Whether to normalize the STFT. Default is False.
|
||||
onesided (bool, optional): Whether to return a one-sided STFT. Default is True.
|
||||
|
||||
Returns:
|
||||
array: The STFT of the input signal.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"istft",
|
||||
[](const mx::array& stft_matrix,
|
||||
int hop_length = -1,
|
||||
int win_length = -1,
|
||||
const mx::array& window = mx::array(),
|
||||
bool center = true,
|
||||
int length = -1,
|
||||
bool normalized = false,
|
||||
mx::StreamOrDevice s = {}) {
|
||||
return mx::stft::istft(
|
||||
stft_matrix,
|
||||
hop_length,
|
||||
win_length,
|
||||
window,
|
||||
center,
|
||||
length,
|
||||
normalized,
|
||||
s);
|
||||
},
|
||||
"stft_matrix"_a,
|
||||
"hop_length"_a = -1,
|
||||
"win_length"_a = -1,
|
||||
"window"_a = mx::array(),
|
||||
"center"_a = true,
|
||||
"length"_a = -1,
|
||||
"normalized"_a = false,
|
||||
"stream"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Inverse Short-time Fourier Transform (ISTFT).
|
||||
|
||||
Args:
|
||||
stft_matrix (array): Input STFT matrix.
|
||||
hop_length (int, optional): Number of samples between successive frames. Default is `n_fft // 4`.
|
||||
win_length (int, optional): Window size. Default is `n_fft`.
|
||||
window (array, optional): Window function. Default is a rectangular window.
|
||||
center (bool, optional): Whether the signal was padded to center the frames. Default is True.
|
||||
length (int, optional): Length of the output signal. Default is inferred from the STFT matrix.
|
||||
normalized (bool, optional): Whether the STFT was normalized. Default is False.
|
||||
|
||||
Returns:
|
||||
array: The reconstructed signal.
|
||||
)pbdoc");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user