mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Formatted
This commit is contained in:
parent
9e8befbe8d
commit
f50cce83a5
@@ -1,6 +1,7 @@
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def stft(
|
def stft(
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
n_fft: int = 2048,
|
n_fft: int = 2048,
|
||||||
@@ -16,10 +17,10 @@ def stft(
|
|||||||
|
|
||||||
if hop_length is None:
|
if hop_length is None:
|
||||||
hop_length = n_fft // 4
|
hop_length = n_fft // 4
|
||||||
|
|
||||||
if win_length is None:
|
if win_length is None:
|
||||||
win_length = n_fft
|
win_length = n_fft
|
||||||
|
|
||||||
if window is None:
|
if window is None:
|
||||||
window = mx.ones(win_length)
|
window = mx.ones(win_length)
|
||||||
|
|
||||||
@@ -34,24 +35,24 @@ def stft(
|
|||||||
|
|
||||||
n_frames = 1 + (x.shape[0] - n_fft) // hop_length
|
n_frames = 1 + (x.shape[0] - n_fft) // hop_length
|
||||||
|
|
||||||
frames = mx.stack([
|
frames = mx.stack(
|
||||||
x[i * hop_length : i * hop_length + n_fft] * window
|
[x[i * hop_length : i * hop_length + n_fft] * window for i in range(n_frames)]
|
||||||
for i in range(n_frames)
|
)
|
||||||
])
|
|
||||||
|
|
||||||
stft = mx.fft.fft(frames, n=n_fft, axis=-1)
|
stft = mx.fft.fft(frames, n=n_fft, axis=-1)
|
||||||
|
|
||||||
if normalized:
|
if normalized:
|
||||||
stft = stft / mx.sqrt(n_fft)
|
stft = stft / mx.sqrt(n_fft)
|
||||||
|
|
||||||
if onesided:
|
if onesided:
|
||||||
stft = stft[..., :n_fft//2 + 1]
|
stft = stft[..., : n_fft // 2 + 1]
|
||||||
|
|
||||||
if not return_complex:
|
if not return_complex:
|
||||||
stft = mx.stack([stft.real, stft.imag], axis=-1)
|
stft = mx.stack([stft.real, stft.imag], axis=-1)
|
||||||
|
|
||||||
return stft
|
return stft
|
||||||
|
|
||||||
|
|
||||||
def istft(
|
def istft(
|
||||||
stft_matrix: mx.array,
|
stft_matrix: mx.array,
|
||||||
hop_length: int = None,
|
hop_length: int = None,
|
||||||
@@ -64,52 +65,52 @@ def istft(
|
|||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
if hop_length is None:
|
if hop_length is None:
|
||||||
hop_length = stft_matrix.shape[-2] // 4
|
hop_length = stft_matrix.shape[-2] // 4
|
||||||
|
|
||||||
if win_length is None:
|
if win_length is None:
|
||||||
win_length = stft_matrix.shape[-2]
|
win_length = stft_matrix.shape[-2]
|
||||||
|
|
||||||
if window is None:
|
if window is None:
|
||||||
window = mx.ones(win_length)
|
window = mx.ones(win_length)
|
||||||
|
|
||||||
if win_length < stft_matrix.shape[-2]:
|
if win_length < stft_matrix.shape[-2]:
|
||||||
pad_left = (stft_matrix.shape[-2] - win_length) // 2
|
pad_left = (stft_matrix.shape[-2] - win_length) // 2
|
||||||
pad_right = stft_matrix.shape[-2] - win_length - pad_left
|
pad_right = stft_matrix.shape[-2] - win_length - pad_left
|
||||||
window = mx.pad(window, [(pad_left, pad_right)])
|
window = mx.pad(window, [(pad_left, pad_right)])
|
||||||
|
|
||||||
if stft_matrix.shape[-1] == 2:
|
if stft_matrix.shape[-1] == 2:
|
||||||
stft_matrix = stft_matrix[..., 0] + 1j * stft_matrix[..., 1]
|
stft_matrix = stft_matrix[..., 0] + 1j * stft_matrix[..., 1]
|
||||||
|
|
||||||
if onesided:
|
if onesided:
|
||||||
n_fft = 2 * (stft_matrix.shape[-1] - 1)
|
n_fft = 2 * (stft_matrix.shape[-1] - 1)
|
||||||
full_stft = mx.zeros((*stft_matrix.shape[:-1], n_fft), dtype=stft_matrix.dtype)
|
full_stft = mx.zeros((*stft_matrix.shape[:-1], n_fft), dtype=stft_matrix.dtype)
|
||||||
full_stft[..., :stft_matrix.shape[-1]] = stft_matrix
|
full_stft[..., : stft_matrix.shape[-1]] = stft_matrix
|
||||||
full_stft[..., stft_matrix.shape[-1]:] = mx.conj(stft_matrix[..., -2:0:-1])
|
full_stft[..., stft_matrix.shape[-1] :] = mx.conj(stft_matrix[..., -2:0:-1])
|
||||||
stft_matrix = full_stft
|
stft_matrix = full_stft
|
||||||
|
|
||||||
frames = mx.fft.ifft(stft_matrix, n=stft_matrix.shape[-1], axis=-1)
|
frames = mx.fft.ifft(stft_matrix, n=stft_matrix.shape[-1], axis=-1)
|
||||||
|
|
||||||
if normalized:
|
if normalized:
|
||||||
frames = frames * mx.sqrt(frames.shape[-1])
|
frames = frames * mx.sqrt(frames.shape[-1])
|
||||||
|
|
||||||
frames = frames * window
|
frames = frames * window
|
||||||
|
|
||||||
signal_length = (frames.shape[0] - 1) * hop_length + frames.shape[1]
|
signal_length = (frames.shape[0] - 1) * hop_length + frames.shape[1]
|
||||||
signal = mx.zeros(signal_length, dtype=frames.dtype)
|
signal = mx.zeros(signal_length, dtype=frames.dtype)
|
||||||
|
|
||||||
for i in range(frames.shape[0]):
|
for i in range(frames.shape[0]):
|
||||||
signal[i * hop_length : i * hop_length + frames.shape[1]] += frames[i]
|
signal[i * hop_length : i * hop_length + frames.shape[1]] += frames[i]
|
||||||
|
|
||||||
window_sum = mx.zeros(signal_length, dtype=frames.dtype)
|
window_sum = mx.zeros(signal_length, dtype=frames.dtype)
|
||||||
for i in range(frames.shape[0]):
|
for i in range(frames.shape[0]):
|
||||||
window_sum[i * hop_length : i * hop_length + frames.shape[1]] += window
|
window_sum[i * hop_length : i * hop_length + frames.shape[1]] += window
|
||||||
|
|
||||||
signal = signal / window_sum
|
signal = signal / window_sum
|
||||||
|
|
||||||
if center:
|
if center:
|
||||||
pad_width = frames.shape[1] // 2
|
pad_width = frames.shape[1] // 2
|
||||||
signal = signal[pad_width:-pad_width]
|
signal = signal[pad_width:-pad_width]
|
||||||
|
|
||||||
if length is not None:
|
if length is not None:
|
||||||
signal = signal[:length]
|
signal = signal[:length]
|
||||||
|
|
||||||
return signal
|
return signal
|
||||||
|
Loading…
Reference in New Issue
Block a user