From b4f01a8f7dbeabb7406298a345e6dd4869669818 Mon Sep 17 00:00:00 2001 From: paramthakkar123 Date: Thu, 17 Apr 2025 22:27:46 +0530 Subject: [PATCH] Added cpp implementation of stft --- mlx/fft.cpp | 132 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) diff --git a/mlx/fft.cpp b/mlx/fft.cpp index f0d41bf0f..dfb585dfc 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -1,7 +1,9 @@ // Copyright © 2023 Apple Inc. +#include #include #include +#include #include "mlx/fft.h" #include "mlx/ops.h" @@ -190,3 +192,133 @@ array irfftn(const array& a, StreamOrDevice s /* = {} */) { } } // namespace mlx::core::fft + +namespace mlx::core::stft { + +array stft( + const array& x, + int n_fft = 2048, + int hop_length = -1, + int win_length = -1, + const array& window, + bool center = true, + const std::string& pad_mode = "reflect", + bool normalized = false, + bool onesided = true, + StreamOrDevice s /* = {} */) { + if (hop_length == -1) + hop_length = n_fft / 4; + if (win_length == -1) + win_length = n_fft; + + array win = (window.size() == 0) ? ones({win_length}, float32, s) : window; + + if (win_length < n_fft) { + int pad_left = (n_fft - win_length) / 2; + int pad_right = n_fft - win_length - pad_left; + + array left_pad = zeros({pad_left}, float32, s); + array right_pad = zeros({pad_right}, float32, s); + win = concatenate({left_pad, win, right_pad}, 0, s); + } + + array padded_x = x; + if (center) { + int pad_width = n_fft / 2; + + array left_pad = zeros({pad_width}, x.dtype(), s); + array right_pad = zeros({pad_width}, x.dtype(), s); + padded_x = concatenate({left_pad, padded_x, right_pad}, 0, s); + } + + int n_frames = 1 + (padded_x.shape(0) - n_fft) / hop_length; + + std::vector frames; + for (int i = 0; i < n_frames; ++i) { + array frame = + slice(padded_x, {i * hop_length}, {i * hop_length + n_fft}, s); + frames.push_back(multiply(frame, win, s)); + } + + array stacked_frames = stack(frames, 0, s); + + array stft_result = mlx::core::fft::rfftn(stacked_frames, {n_fft}, {-1}, s); + + if (normalized) { + array n_fft_array = full({1}, static_cast(n_fft), float32, s); + stft_result = divide(stft_result, sqrt(n_fft_array, s), s); + } + + if (onesided) { + stft_result = slice(stft_result, {}, {n_fft / 2 + 1}, s); + } + + return stft_result; +} + +array istft( + const array& stft_matrix, + int hop_length = -1, + int win_length = -1, + const array& window, + bool center = true, + int length = -1, + bool normalized = false, + StreamOrDevice s /* = {} */) { + int n_fft = (stft_matrix.shape(-1) - 1) * 2; + if (hop_length == -1) + hop_length = n_fft / 4; + if (win_length == -1) + win_length = n_fft; + + array win = (window.size() == 0) ? ones({win_length}, float32, s) : window; + + if (win_length < n_fft) { + int pad_left = (n_fft - win_length) / 2; + int pad_right = n_fft - win_length - pad_left; + + array left_pad = zeros({pad_left}, float32, s); + array right_pad = zeros({pad_right}, float32, s); + win = concatenate({left_pad, win, right_pad}, 0, s); + } + + array frames = mlx::core::fft::irfftn(stft_matrix, {n_fft}, {-1}, s); + + frames = multiply(frames, win, s); + + int signal_length = (frames.shape(0) - 1) * hop_length + n_fft; + array signal = zeros({signal_length}, float32, s); + array window_sum = zeros({signal_length}, float32, s); + + for (int i = 0; i < frames.shape(0); ++i) { + array frame = reshape(slice(frames, {i}, {i + 1}, s), {n_fft}, s); + array signal_slice = + slice(signal, {i * hop_length}, {i * hop_length + n_fft}, s); + array window_slice = + slice(window_sum, {i * hop_length}, {i * hop_length + n_fft}, s); + + signal_slice = add(signal_slice, frame, s); + window_slice = add(window_slice, win, s); + } + + signal = divide(signal, window_sum, s); + + if (center) { + int pad_width = n_fft / 2; + signal = slice(signal, {pad_width}, {signal.shape(0) - pad_width}, s); + } + + if (length > 0) { + if (signal.shape(0) > length) { + signal = slice(signal, {0}, {length}, s); + } else if (signal.shape(0) < length) { + int pad_length = length - signal.shape(0); + array pad_array = zeros({pad_length}, signal.dtype(), s); + signal = concatenate({signal, pad_array}, 0, s); + } + } + + return signal; +} + +} // namespace mlx::core::stft