mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
changed to mx::pad and added stft to a single fft namespace
This commit is contained in:
parent
faf76a8133
commit
c92a2bc679
28
mlx/fft.cpp
28
mlx/fft.cpp
@ -191,10 +191,6 @@ array irfftn(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
return fft_impl(a, true, true, s);
|
return fft_impl(a, true, true, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::fft
|
|
||||||
|
|
||||||
namespace mlx::core::stft {
|
|
||||||
|
|
||||||
array stft(
|
array stft(
|
||||||
const array& x,
|
const array& x,
|
||||||
int n_fft = 2048,
|
int n_fft = 2048,
|
||||||
@ -216,19 +212,15 @@ array stft(
|
|||||||
if (win_length < n_fft) {
|
if (win_length < n_fft) {
|
||||||
int pad_left = (n_fft - win_length) / 2;
|
int pad_left = (n_fft - win_length) / 2;
|
||||||
int pad_right = n_fft - win_length - pad_left;
|
int pad_right = n_fft - win_length - pad_left;
|
||||||
|
win = mlx::core::pad(
|
||||||
array left_pad = zeros({pad_left}, float32, s);
|
win, {{pad_left, pad_right}}, array(0, float32), "constant", s);
|
||||||
array right_pad = zeros({pad_right}, float32, s);
|
|
||||||
win = concatenate({left_pad, win, right_pad}, 0, s);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array padded_x = x;
|
array padded_x = x;
|
||||||
if (center) {
|
if (center) {
|
||||||
int pad_width = n_fft / 2;
|
int pad_width = n_fft / 2;
|
||||||
|
padded_x = mlx::core::pad(
|
||||||
array left_pad = zeros({pad_width}, x.dtype(), s);
|
padded_x, {{pad_width, pad_width}}, array(0, x.dtype()), pad_mode, s);
|
||||||
array right_pad = zeros({pad_width}, x.dtype(), s);
|
|
||||||
padded_x = concatenate({left_pad, padded_x, right_pad}, 0, s);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int n_frames = 1 + (padded_x.shape(0) - n_fft) / hop_length;
|
int n_frames = 1 + (padded_x.shape(0) - n_fft) / hop_length;
|
||||||
@ -276,10 +268,8 @@ array istft(
|
|||||||
if (win_length < n_fft) {
|
if (win_length < n_fft) {
|
||||||
int pad_left = (n_fft - win_length) / 2;
|
int pad_left = (n_fft - win_length) / 2;
|
||||||
int pad_right = n_fft - win_length - pad_left;
|
int pad_right = n_fft - win_length - pad_left;
|
||||||
|
win = mlx::core::pad(
|
||||||
array left_pad = zeros({pad_left}, float32, s);
|
win, {{pad_left, pad_right}}, array(0, float32), "constant", s);
|
||||||
array right_pad = zeros({pad_right}, float32, s);
|
|
||||||
win = concatenate({left_pad, win, right_pad}, 0, s);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
array frames = mlx::core::fft::irfftn(stft_matrix, {n_fft}, {-1}, s);
|
array frames = mlx::core::fft::irfftn(stft_matrix, {n_fft}, {-1}, s);
|
||||||
@ -313,12 +303,12 @@ array istft(
|
|||||||
signal = slice(signal, {0}, {length}, s);
|
signal = slice(signal, {0}, {length}, s);
|
||||||
} else if (signal.shape(0) < length) {
|
} else if (signal.shape(0) < length) {
|
||||||
int pad_length = length - signal.shape(0);
|
int pad_length = length - signal.shape(0);
|
||||||
array pad_array = zeros({pad_length}, signal.dtype(), s);
|
signal = mlx::core::pad(
|
||||||
signal = concatenate({signal, pad_array}, 0, s);
|
signal, {{0, pad_length}}, array(0, signal.dtype()), "constant", s);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return signal;
|
return signal;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::stft
|
} // namespace mlx::core::fft
|
Loading…
Reference in New Issue
Block a user