Formatted

This commit is contained in:
paramthakkar123 2025-04-17 00:24:49 +05:30
parent 9e8befbe8d
commit f50cce83a5

View File

@ -1,6 +1,7 @@
import mlx.core as mx import mlx.core as mx
import numpy as np import numpy as np
def stft( def stft(
x: mx.array, x: mx.array,
n_fft: int = 2048, n_fft: int = 2048,
@ -16,10 +17,10 @@ def stft(
if hop_length is None: if hop_length is None:
hop_length = n_fft // 4 hop_length = n_fft // 4
if win_length is None: if win_length is None:
win_length = n_fft win_length = n_fft
if window is None: if window is None:
window = mx.ones(win_length) window = mx.ones(win_length)
@ -34,24 +35,24 @@ def stft(
n_frames = 1 + (x.shape[0] - n_fft) // hop_length n_frames = 1 + (x.shape[0] - n_fft) // hop_length
frames = mx.stack([ frames = mx.stack(
x[i * hop_length : i * hop_length + n_fft] * window [x[i * hop_length : i * hop_length + n_fft] * window for i in range(n_frames)]
for i in range(n_frames) )
])
stft = mx.fft.fft(frames, n=n_fft, axis=-1) stft = mx.fft.fft(frames, n=n_fft, axis=-1)
if normalized: if normalized:
stft = stft / mx.sqrt(n_fft) stft = stft / mx.sqrt(n_fft)
if onesided: if onesided:
stft = stft[..., :n_fft//2 + 1] stft = stft[..., : n_fft // 2 + 1]
if not return_complex: if not return_complex:
stft = mx.stack([stft.real, stft.imag], axis=-1) stft = mx.stack([stft.real, stft.imag], axis=-1)
return stft return stft
def istft( def istft(
stft_matrix: mx.array, stft_matrix: mx.array,
hop_length: int = None, hop_length: int = None,
@ -64,52 +65,52 @@ def istft(
) -> mx.array: ) -> mx.array:
if hop_length is None: if hop_length is None:
hop_length = stft_matrix.shape[-2] // 4 hop_length = stft_matrix.shape[-2] // 4
if win_length is None: if win_length is None:
win_length = stft_matrix.shape[-2] win_length = stft_matrix.shape[-2]
if window is None: if window is None:
window = mx.ones(win_length) window = mx.ones(win_length)
if win_length < stft_matrix.shape[-2]: if win_length < stft_matrix.shape[-2]:
pad_left = (stft_matrix.shape[-2] - win_length) // 2 pad_left = (stft_matrix.shape[-2] - win_length) // 2
pad_right = stft_matrix.shape[-2] - win_length - pad_left pad_right = stft_matrix.shape[-2] - win_length - pad_left
window = mx.pad(window, [(pad_left, pad_right)]) window = mx.pad(window, [(pad_left, pad_right)])
if stft_matrix.shape[-1] == 2: if stft_matrix.shape[-1] == 2:
stft_matrix = stft_matrix[..., 0] + 1j * stft_matrix[..., 1] stft_matrix = stft_matrix[..., 0] + 1j * stft_matrix[..., 1]
if onesided: if onesided:
n_fft = 2 * (stft_matrix.shape[-1] - 1) n_fft = 2 * (stft_matrix.shape[-1] - 1)
full_stft = mx.zeros((*stft_matrix.shape[:-1], n_fft), dtype=stft_matrix.dtype) full_stft = mx.zeros((*stft_matrix.shape[:-1], n_fft), dtype=stft_matrix.dtype)
full_stft[..., :stft_matrix.shape[-1]] = stft_matrix full_stft[..., : stft_matrix.shape[-1]] = stft_matrix
full_stft[..., stft_matrix.shape[-1]:] = mx.conj(stft_matrix[..., -2:0:-1]) full_stft[..., stft_matrix.shape[-1] :] = mx.conj(stft_matrix[..., -2:0:-1])
stft_matrix = full_stft stft_matrix = full_stft
frames = mx.fft.ifft(stft_matrix, n=stft_matrix.shape[-1], axis=-1) frames = mx.fft.ifft(stft_matrix, n=stft_matrix.shape[-1], axis=-1)
if normalized: if normalized:
frames = frames * mx.sqrt(frames.shape[-1]) frames = frames * mx.sqrt(frames.shape[-1])
frames = frames * window frames = frames * window
signal_length = (frames.shape[0] - 1) * hop_length + frames.shape[1] signal_length = (frames.shape[0] - 1) * hop_length + frames.shape[1]
signal = mx.zeros(signal_length, dtype=frames.dtype) signal = mx.zeros(signal_length, dtype=frames.dtype)
for i in range(frames.shape[0]): for i in range(frames.shape[0]):
signal[i * hop_length : i * hop_length + frames.shape[1]] += frames[i] signal[i * hop_length : i * hop_length + frames.shape[1]] += frames[i]
window_sum = mx.zeros(signal_length, dtype=frames.dtype) window_sum = mx.zeros(signal_length, dtype=frames.dtype)
for i in range(frames.shape[0]): for i in range(frames.shape[0]):
window_sum[i * hop_length : i * hop_length + frames.shape[1]] += window window_sum[i * hop_length : i * hop_length + frames.shape[1]] += window
signal = signal / window_sum signal = signal / window_sum
if center: if center:
pad_width = frames.shape[1] // 2 pad_width = frames.shape[1] // 2
signal = signal[pad_width:-pad_width] signal = signal[pad_width:-pad_width]
if length is not None: if length is not None:
signal = signal[:length] signal = signal[:length]
return signal return signal