Added docstrings

This commit is contained in:
paramthakkar123 2025-04-17 08:20:32 +05:30
parent f50cce83a5
commit 91097e6179

View File

@ -1,5 +1,4 @@
import mlx.core as mx
import numpy as np
def stft(
@ -14,6 +13,25 @@ def stft(
onesided: bool = True,
return_complex: bool = True,
) -> mx.array:
"""
Computes the Short-Time Fourier Transform (STFT) of a signal.
Args:
x (mx.array): Input signal array.
n_fft (int): Number of FFT points. Default is 2048.
hop_length (int, optional): Number of samples between successive frames.
Defaults to `n_fft // 4`.
win_length (int, optional): Window size. Defaults to `n_fft`.
window (mx.array, optional): Window function. Defaults to a rectangular window.
center (bool): If True, pads the input signal so that the frame is centered. Default is True.
pad_mode (str): Padding mode to use if `center` is True. Default is "reflect".
normalized (bool): If True, normalizes the STFT by the square root of `n_fft`. Default is False.
onesided (bool): If True, returns only the positive frequencies. Default is True.
return_complex (bool): If True, returns a complex-valued STFT. If False, returns real and imaginary parts separately. Default is True.
Returns:
mx.array: The STFT of the input signal. The shape depends on the input and parameters.
"""
if hop_length is None:
hop_length = n_fft // 4
@ -63,6 +81,23 @@ def istft(
normalized: bool = False,
onesided: bool = True,
) -> mx.array:
"""
Computes the inverse Short-Time Fourier Transform (iSTFT) of a signal.
Args:
stft_matrix (mx.array): STFT matrix (output of `stft`).
hop_length (int, optional): Number of samples between successive frames.
Defaults to `stft_matrix.shape[-2] // 4`.
win_length (int, optional): Window size. Defaults to `stft_matrix.shape[-2]`.
window (mx.array, optional): Window function. Defaults to a rectangular window.
center (bool): If True, removes padding added during STFT computation. Default is True.
length (int, optional): Length of the output signal. If provided, the output is trimmed or zero-padded to this length.
normalized (bool): If True, normalizes the iSTFT by the square root of the FFT size. Default is False.
onesided (bool): If True, assumes the input STFT matrix is onesided. Default is True.
Returns:
mx.array: The reconstructed time-domain signal.
"""
if hop_length is None:
hop_length = stft_matrix.shape[-2] // 4