diff --git a/mlx/fft.cpp b/mlx/fft.cpp index dfb585dfc..a8af2eea8 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -191,10 +191,6 @@ array irfftn(const array& a, StreamOrDevice s /* = {} */) { return fft_impl(a, true, true, s); } -} // namespace mlx::core::fft - -namespace mlx::core::stft { - array stft( const array& x, int n_fft = 2048, @@ -216,19 +212,15 @@ array stft( if (win_length < n_fft) { int pad_left = (n_fft - win_length) / 2; int pad_right = n_fft - win_length - pad_left; - - array left_pad = zeros({pad_left}, float32, s); - array right_pad = zeros({pad_right}, float32, s); - win = concatenate({left_pad, win, right_pad}, 0, s); + win = mlx::core::pad( + win, {{pad_left, pad_right}}, array(0, float32), "constant", s); } array padded_x = x; if (center) { int pad_width = n_fft / 2; - - array left_pad = zeros({pad_width}, x.dtype(), s); - array right_pad = zeros({pad_width}, x.dtype(), s); - padded_x = concatenate({left_pad, padded_x, right_pad}, 0, s); + padded_x = mlx::core::pad( + padded_x, {{pad_width, pad_width}}, array(0, x.dtype()), pad_mode, s); } int n_frames = 1 + (padded_x.shape(0) - n_fft) / hop_length; @@ -276,10 +268,8 @@ array istft( if (win_length < n_fft) { int pad_left = (n_fft - win_length) / 2; int pad_right = n_fft - win_length - pad_left; - - array left_pad = zeros({pad_left}, float32, s); - array right_pad = zeros({pad_right}, float32, s); - win = concatenate({left_pad, win, right_pad}, 0, s); + win = mlx::core::pad( + win, {{pad_left, pad_right}}, array(0, float32), "constant", s); } array frames = mlx::core::fft::irfftn(stft_matrix, {n_fft}, {-1}, s); @@ -313,12 +303,12 @@ array istft( signal = slice(signal, {0}, {length}, s); } else if (signal.shape(0) < length) { int pad_length = length - signal.shape(0); - array pad_array = zeros({pad_length}, signal.dtype(), s); - signal = concatenate({signal, pad_array}, 0, s); + signal = mlx::core::pad( + signal, {{0, pad_length}}, array(0, signal.dtype()), "constant", s); } } return signal; } -} // namespace mlx::core::stft +} // namespace mlx::core::fft \ No newline at end of file