mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Formatted
This commit is contained in:
parent
9e8befbe8d
commit
f50cce83a5
@ -1,6 +1,7 @@
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def stft(
|
||||
x: mx.array,
|
||||
n_fft: int = 2048,
|
||||
@ -34,10 +35,9 @@ def stft(
|
||||
|
||||
n_frames = 1 + (x.shape[0] - n_fft) // hop_length
|
||||
|
||||
frames = mx.stack([
|
||||
x[i * hop_length : i * hop_length + n_fft] * window
|
||||
for i in range(n_frames)
|
||||
])
|
||||
frames = mx.stack(
|
||||
[x[i * hop_length : i * hop_length + n_fft] * window for i in range(n_frames)]
|
||||
)
|
||||
|
||||
stft = mx.fft.fft(frames, n=n_fft, axis=-1)
|
||||
|
||||
@ -45,13 +45,14 @@ def stft(
|
||||
stft = stft / mx.sqrt(n_fft)
|
||||
|
||||
if onesided:
|
||||
stft = stft[..., :n_fft//2 + 1]
|
||||
stft = stft[..., : n_fft // 2 + 1]
|
||||
|
||||
if not return_complex:
|
||||
stft = mx.stack([stft.real, stft.imag], axis=-1)
|
||||
|
||||
return stft
|
||||
|
||||
|
||||
def istft(
|
||||
stft_matrix: mx.array,
|
||||
hop_length: int = None,
|
||||
@ -82,8 +83,8 @@ def istft(
|
||||
if onesided:
|
||||
n_fft = 2 * (stft_matrix.shape[-1] - 1)
|
||||
full_stft = mx.zeros((*stft_matrix.shape[:-1], n_fft), dtype=stft_matrix.dtype)
|
||||
full_stft[..., :stft_matrix.shape[-1]] = stft_matrix
|
||||
full_stft[..., stft_matrix.shape[-1]:] = mx.conj(stft_matrix[..., -2:0:-1])
|
||||
full_stft[..., : stft_matrix.shape[-1]] = stft_matrix
|
||||
full_stft[..., stft_matrix.shape[-1] :] = mx.conj(stft_matrix[..., -2:0:-1])
|
||||
stft_matrix = full_stft
|
||||
|
||||
frames = mx.fft.ifft(stft_matrix, n=stft_matrix.shape[-1], axis=-1)
|
||||
|
Loading…
Reference in New Issue
Block a user