Added bindings and removed python implementation

This commit is contained in:
paramthakkar123
2025-04-18 01:45:06 +05:30
parent b4f01a8f7d
commit faf76a8133
2 changed files with 96 additions and 151 deletions

View File

@@ -459,4 +459,100 @@ void init_fft(nb::module_& parent_module) {
Returns:
array: The real array containing the inverse of :func:`rfftn`.
)pbdoc");
m.def(
"stft",
[](const mx::array& x,
int n_fft = 2048,
int hop_length = -1,
int win_length = -1,
const mx::array& window = mx::array(),
bool center = true,
const std::string& pad_mode = "reflect",
bool normalized = false,
bool onesided = true,
mx::StreamOrDevice s = {}) {
return mx::stft::stft(
x,
n_fft,
hop_length,
win_length,
window,
center,
pad_mode,
normalized,
onesided,
s);
},
"x"_a,
"n_fft"_a = 2048,
"hop_length"_a = -1,
"win_length"_a = -1,
"window"_a = mx::array(),
"center"_a = true,
"pad_mode"_a = "reflect",
"normalized"_a = false,
"onesided"_a = true,
"stream"_a = nb::none(),
R"pbdoc(
Short-time Fourier Transform (STFT).
Args:
x (array): Input signal.
n_fft (int, optional): Number of FFT points. Default is 2048.
hop_length (int, optional): Number of samples between successive frames. Default is `n_fft // 4`.
win_length (int, optional): Window size. Default is `n_fft`.
window (array, optional): Window function. Default is a rectangular window.
center (bool, optional): Whether to pad the signal to center the frames. Default is True.
pad_mode (str, optional): Padding mode. Default is "reflect".
normalized (bool, optional): Whether to normalize the STFT. Default is False.
onesided (bool, optional): Whether to return a one-sided STFT. Default is True.
Returns:
array: The STFT of the input signal.
)pbdoc");
m.def(
"istft",
[](const mx::array& stft_matrix,
int hop_length = -1,
int win_length = -1,
const mx::array& window = mx::array(),
bool center = true,
int length = -1,
bool normalized = false,
mx::StreamOrDevice s = {}) {
return mx::stft::istft(
stft_matrix,
hop_length,
win_length,
window,
center,
length,
normalized,
s);
},
"stft_matrix"_a,
"hop_length"_a = -1,
"win_length"_a = -1,
"window"_a = mx::array(),
"center"_a = true,
"length"_a = -1,
"normalized"_a = false,
"stream"_a = nb::none(),
R"pbdoc(
Inverse Short-time Fourier Transform (ISTFT).
Args:
stft_matrix (array): Input STFT matrix.
hop_length (int, optional): Number of samples between successive frames. Default is `n_fft // 4`.
win_length (int, optional): Window size. Default is `n_fft`.
window (array, optional): Window function. Default is a rectangular window.
center (bool, optional): Whether the signal was padded to center the frames. Default is True.
length (int, optional): Length of the output signal. Default is inferred from the STFT matrix.
normalized (bool, optional): Whether the STFT was normalized. Default is False.
Returns:
array: The reconstructed signal.
)pbdoc");
}