diff --git a/python/mlx/stft.py b/python/mlx/stft.py index ac2cbf349..a4abfdc87 100644 --- a/python/mlx/stft.py +++ b/python/mlx/stft.py @@ -1,5 +1,4 @@ import mlx.core as mx -import numpy as np def stft( @@ -14,6 +13,25 @@ def stft( onesided: bool = True, return_complex: bool = True, ) -> mx.array: + """ + Computes the Short-Time Fourier Transform (STFT) of a signal. + + Args: + x (mx.array): Input signal array. + n_fft (int): Number of FFT points. Default is 2048. + hop_length (int, optional): Number of samples between successive frames. + Defaults to `n_fft // 4`. + win_length (int, optional): Window size. Defaults to `n_fft`. + window (mx.array, optional): Window function. Defaults to a rectangular window. + center (bool): If True, pads the input signal so that each frame is centered. Default is True. + pad_mode (str): Padding mode to use if `center` is True. Default is "reflect". + normalized (bool): If True, normalizes the STFT by the square root of `n_fft`. Default is False. + onesided (bool): If True, returns only the positive frequencies. Default is True. + return_complex (bool): If True, returns a complex-valued STFT. If False, returns real and imaginary parts separately. Default is True. + + Returns: + mx.array: The STFT of the input signal. The shape depends on the input and parameters. + """ if hop_length is None: hop_length = n_fft // 4 @@ -63,6 +81,23 @@ def istft( normalized: bool = False, onesided: bool = True, ) -> mx.array: + """ + Computes the inverse Short-Time Fourier Transform (iSTFT) of an STFT matrix. + + Args: + stft_matrix (mx.array): STFT matrix (output of `stft`). + hop_length (int, optional): Number of samples between successive frames. + Defaults to `stft_matrix.shape[-2] // 4`. + win_length (int, optional): Window size. Defaults to `stft_matrix.shape[-2]`. + window (mx.array, optional): Window function. Defaults to a rectangular window. + center (bool): If True, removes padding added during STFT computation. Default is True. 
+ length (int, optional): Length of the output signal. If provided, the output is trimmed or zero-padded to this length. + normalized (bool): If True, normalizes the iSTFT by the square root of the FFT size. Default is False. + onesided (bool): If True, assumes the input STFT matrix is one-sided (as produced by `stft` with `onesided=True`). Default is True. + + Returns: + mx.array: The reconstructed time-domain signal. + """ if hop_length is None: hop_length = stft_matrix.shape[-2] // 4