mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Added cpp implementation of stft
This commit is contained in:
parent
91097e6179
commit
b4f01a8f7d
132
mlx/fft.cpp
132
mlx/fft.cpp
@ -1,7 +1,9 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cmath>
#include <numeric>
#include <set>
#include <sstream>
#include <stdexcept>

#include "mlx/fft.h"
#include "mlx/ops.h"
|
||||
@ -190,3 +192,133 @@ array irfftn(const array& a, StreamOrDevice s /* = {} */) {
|
||||
}
|
||||
|
||||
} // namespace mlx::core::fft
|
||||
|
||||
namespace mlx::core::stft {
|
||||
|
||||
/**
 * Short-time Fourier transform of a one-dimensional signal.
 *
 * The signal is optionally padded by `n_fft / 2` samples on each side, cut
 * into frames of `n_fft` samples spaced `hop_length` apart, multiplied by the
 * window and transformed along the last axis. The result has one row per
 * frame.
 *
 * Default values are shown as comments, following the convention already
 * used for `s` in this file; they cannot be repeated in the definition and
 * are expected to be supplied by the declaration.
 *
 * @param x           Input signal, must be 1-D.
 * @param n_fft       Frame / FFT length, must be positive.
 * @param hop_length  Step between frames; -1 selects `n_fft / 4`.
 * @param win_length  Window length; -1 selects `n_fft`. Must not exceed
 *                    `n_fft`; shorter windows are zero-padded on both sides.
 * @param window      1-D window of `win_length` samples; an empty array
 *                    selects a rectangular (all-ones) window.
 * @param center      When true, pad the signal so frames are centred.
 * @param pad_mode    Padding used when `center` is true: "reflect" or
 *                    "constant" (zeros).
 * @param normalized  When true, divide the result by `sqrt(n_fft)`.
 * @param onesided    When true, return the `n_fft / 2 + 1` non-negative
 *                    frequency bins (real FFT); otherwise the full spectrum.
 * @param s           Stream or device on which the operations are scheduled.
 * @return Complex array of shape `(n_frames, n_bins)`.
 * @throws std::invalid_argument on invalid sizes or an unknown `pad_mode`.
 */
array stft(
    const array& x,
    int n_fft /* = 2048 */,
    int hop_length /* = -1 */,
    int win_length /* = -1 */,
    const array& window,
    bool center /* = true */,
    const std::string& pad_mode /* = "reflect" */,
    bool normalized /* = false */,
    bool onesided /* = true */,
    StreamOrDevice s /* = {} */) {
  if (x.ndim() != 1) {
    throw std::invalid_argument("[stft] Input signal must be one-dimensional.");
  }
  if (n_fft <= 0) {
    throw std::invalid_argument("[stft] n_fft must be positive.");
  }
  if (hop_length == -1) {
    hop_length = n_fft / 4;
  }
  if (win_length == -1) {
    win_length = n_fft;
  }
  // hop_length is used as a divisor below; n_fft / 4 is zero for n_fft < 4.
  if (hop_length <= 0) {
    throw std::invalid_argument("[stft] hop_length must be positive.");
  }
  if (win_length <= 0 || win_length > n_fft) {
    throw std::invalid_argument(
        "[stft] win_length must be in the range [1, n_fft].");
  }

  // An empty window means "no windowing".
  array win = (window.size() == 0) ? ones({win_length}, float32, s) : window;
  if (win.ndim() != 1 || win.shape(0) != win_length) {
    throw std::invalid_argument(
        "[stft] window must be one-dimensional with win_length samples.");
  }

  // Centre a short window inside the n_fft-sample frame.
  if (win_length < n_fft) {
    const int pad_left = (n_fft - win_length) / 2;
    const int pad_right = n_fft - win_length - pad_left;
    win = concatenate(
        {zeros({pad_left}, float32, s), win, zeros({pad_right}, float32, s)},
        0,
        s);
  }

  array padded_x = x;
  if (center) {
    const int pad_width = n_fft / 2;
    if (pad_width > 0) {
      if (pad_mode == "reflect") {
        const int n = x.shape(0);
        // Mirroring without repeating the edge sample needs pad_width
        // samples beyond the edge on each side.
        if (n <= pad_width) {
          throw std::invalid_argument(
              "[stft] Reflect padding requires a signal longer than n_fft / 2.");
        }
        // Left pad: x[pad_width], ..., x[1].
        array left_pad = take(x, arange(pad_width, 0, -1, s), 0, s);
        // Right pad: x[n - 2], ..., x[n - 1 - pad_width].
        array right_pad =
            take(x, arange(n - 2, n - 2 - pad_width, -1, s), 0, s);
        padded_x = concatenate({left_pad, x, right_pad}, 0, s);
      } else if (pad_mode == "constant") {
        array left_pad = zeros({pad_width}, x.dtype(), s);
        array right_pad = zeros({pad_width}, x.dtype(), s);
        padded_x = concatenate({left_pad, x, right_pad}, 0, s);
      } else {
        throw std::invalid_argument(
            "[stft] Unsupported pad_mode; expected \"reflect\" or \"constant\".");
      }
    }
  }

  const int padded_length = padded_x.shape(0);
  if (padded_length < n_fft) {
    throw std::invalid_argument(
        "[stft] Signal is shorter than n_fft after padding.");
  }
  const int n_frames = 1 + (padded_length - n_fft) / hop_length;

  std::vector<array> frames;
  frames.reserve(n_frames);
  for (int i = 0; i < n_frames; ++i) {
    const int start = i * hop_length;
    array frame = slice(padded_x, {start}, {start + n_fft}, s);
    frames.push_back(multiply(frame, win, s));
  }

  array stacked_frames = stack(frames, 0, s);

  // rfftn already yields only the n_fft / 2 + 1 non-negative bins, so the
  // one-sided result needs no further slicing; the two-sided result needs
  // the full complex transform.
  array stft_result = onesided
      ? mlx::core::fft::rfftn(stacked_frames, {n_fft}, {-1}, s)
      : mlx::core::fft::fftn(stacked_frames, {n_fft}, {-1}, s);

  if (normalized) {
    array n_fft_array = full({1}, static_cast<float>(n_fft), float32, s);
    stft_result = divide(stft_result, sqrt(n_fft_array, s), s);
  }

  return stft_result;
}
|
||||
|
||||
/**
 * Inverse short-time Fourier transform by windowed overlap-add.
 *
 * Each row of `stft_matrix` is treated as the one-sided spectrum of one
 * frame. Frames are inverse-transformed, multiplied by the window, summed at
 * offsets of `hop_length`, and divided by the accumulated squared-window
 * envelope. The window is squared because it is applied here as well as to
 * the time-domain frames before the forward transform.
 *
 * Default values are shown as comments, following the convention already
 * used for `s` in this file; they cannot be repeated in the definition and
 * are expected to be supplied by the declaration.
 *
 * @param stft_matrix  2-D array of shape `(n_frames, n_fft / 2 + 1)`.
 * @param hop_length   Step between frames; -1 selects `n_fft / 4`.
 * @param win_length   Window length; -1 selects `n_fft`. Must not exceed
 *                     `n_fft`; shorter windows are zero-padded on both sides.
 * @param window       1-D window of `win_length` samples; an empty array
 *                     selects a rectangular (all-ones) window.
 * @param center       When true, drop `n_fft / 2` samples from each end.
 * @param length       When positive, trim or zero-pad the output to exactly
 *                     this many samples.
 * @param normalized   When true, multiply the frames by `sqrt(n_fft)` to undo
 *                     a forward transform that divided by it.
 * @param s            Stream or device on which the operations are scheduled.
 * @return 1-D float32 array holding the reconstructed signal.
 * @throws std::invalid_argument on invalid shapes or sizes.
 */
array istft(
    const array& stft_matrix,
    int hop_length /* = -1 */,
    int win_length /* = -1 */,
    const array& window,
    bool center /* = true */,
    int length /* = -1 */,
    bool normalized /* = false */,
    StreamOrDevice s /* = {} */) {
  if (stft_matrix.ndim() != 2) {
    throw std::invalid_argument(
        "[istft] stft_matrix must be two-dimensional (frames, bins).");
  }
  const int n_frames = stft_matrix.shape(0);
  const int n_fft = (stft_matrix.shape(-1) - 1) * 2;
  if (n_frames <= 0 || n_fft <= 0) {
    throw std::invalid_argument(
        "[istft] stft_matrix must have at least one frame and two bins.");
  }

  if (hop_length == -1) {
    hop_length = n_fft / 4;
  }
  if (win_length == -1) {
    win_length = n_fft;
  }
  if (hop_length <= 0) {
    throw std::invalid_argument("[istft] hop_length must be positive.");
  }
  if (win_length <= 0 || win_length > n_fft) {
    throw std::invalid_argument(
        "[istft] win_length must be in the range [1, n_fft].");
  }

  // An empty window means "no windowing".
  array win = (window.size() == 0) ? ones({win_length}, float32, s) : window;
  if (win.ndim() != 1 || win.shape(0) != win_length) {
    throw std::invalid_argument(
        "[istft] window must be one-dimensional with win_length samples.");
  }

  // Centre a short window inside the n_fft-sample frame.
  if (win_length < n_fft) {
    const int pad_left = (n_fft - win_length) / 2;
    const int pad_right = n_fft - win_length - pad_left;
    win = concatenate(
        {zeros({pad_left}, float32, s), win, zeros({pad_right}, float32, s)},
        0,
        s);
  }

  array frames = mlx::core::fft::irfftn(stft_matrix, {n_fft}, {-1}, s);

  if (normalized) {
    array n_fft_array = full({1}, static_cast<float>(n_fft), float32, s);
    frames = multiply(frames, sqrt(n_fft_array, s), s);
  }

  frames = multiply(frames, win, s);
  array win_sq = square(win, s);

  const int signal_length = (n_frames - 1) * hop_length + n_fft;
  array signal = zeros({signal_length}, float32, s);
  array window_sum = zeros({signal_length}, float32, s);

  // Overlap-add. Arrays are immutable values, so each accumulation has to be
  // written back with slice_update; adding into a sliced copy would leave
  // `signal` and `window_sum` untouched.
  for (int i = 0; i < n_frames; ++i) {
    const int start = i * hop_length;
    const int stop = start + n_fft;

    array frame = reshape(slice(frames, {i, 0}, {i + 1, n_fft}, s), {n_fft}, s);

    signal = slice_update(
        signal,
        add(slice(signal, {start}, {stop}, s), frame, s),
        {start},
        {stop},
        s);
    window_sum = slice_update(
        window_sum,
        add(slice(window_sum, {start}, {stop}, s), win_sq, s),
        {start},
        {stop},
        s);
  }

  // Normalise only where the envelope is non-negligible, so samples covered
  // solely by window zeros do not become NaN or infinity.
  array has_support = greater(window_sum, array(1e-10f), s);
  signal = where(has_support, divide(signal, window_sum, s), signal, s);

  if (center) {
    const int pad_width = n_fft / 2;
    signal = slice(signal, {pad_width}, {signal.shape(0) - pad_width}, s);
  }

  if (length > 0) {
    if (signal.shape(0) > length) {
      signal = slice(signal, {0}, {length}, s);
    } else if (signal.shape(0) < length) {
      const int pad_length = length - signal.shape(0);
      array pad_array = zeros({pad_length}, signal.dtype(), s);
      signal = concatenate({signal, pad_array}, 0, s);
    }
  }

  return signal;
}
|
||||
|
||||
} // namespace mlx::core::stft
|
||||
|
Loading…
Reference in New Issue
Block a user