mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Added docstrings
This commit is contained in:
parent
f50cce83a5
commit
91097e6179
@ -1,5 +1,4 @@
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def stft(
|
||||
@ -14,6 +13,25 @@ def stft(
|
||||
onesided: bool = True,
|
||||
return_complex: bool = True,
|
||||
) -> mx.array:
|
||||
"""
|
||||
Computes the Short-Time Fourier Transform (STFT) of a signal.
|
||||
|
||||
Args:
|
||||
x (mx.array): Input signal array.
|
||||
n_fft (int): Number of FFT points. Default is 2048.
|
||||
hop_length (int, optional): Number of samples between successive frames.
|
||||
Defaults to `n_fft // 4`.
|
||||
win_length (int, optional): Window size. Defaults to `n_fft`.
|
||||
window (mx.array, optional): Window function. Defaults to a rectangular window.
|
||||
center (bool): If True, pads the input signal so that the frame is centered. Default is True.
|
||||
pad_mode (str): Padding mode to use if `center` is True. Default is "reflect".
|
||||
normalized (bool): If True, normalizes the STFT by the square root of `n_fft`. Default is False.
|
||||
onesided (bool): If True, returns only the positive frequencies. Default is True.
|
||||
return_complex (bool): If True, returns a complex-valued STFT. If False, returns real and imaginary parts separately. Default is True.
|
||||
|
||||
Returns:
|
||||
mx.array: The STFT of the input signal. The shape depends on the input and parameters.
|
||||
"""
|
||||
|
||||
if hop_length is None:
|
||||
hop_length = n_fft // 4
|
||||
@ -63,6 +81,23 @@ def istft(
|
||||
normalized: bool = False,
|
||||
onesided: bool = True,
|
||||
) -> mx.array:
|
||||
"""
|
||||
Computes the inverse Short-Time Fourier Transform (iSTFT) of a signal.
|
||||
|
||||
Args:
|
||||
stft_matrix (mx.array): STFT matrix (output of `stft`).
|
||||
hop_length (int, optional): Number of samples between successive frames.
|
||||
Defaults to `stft_matrix.shape[-2] // 4`.
|
||||
win_length (int, optional): Window size. Defaults to `stft_matrix.shape[-2]`.
|
||||
window (mx.array, optional): Window function. Defaults to a rectangular window.
|
||||
center (bool): If True, removes padding added during STFT computation. Default is True.
|
||||
length (int, optional): Length of the output signal. If provided, the output is trimmed or zero-padded to this length.
|
||||
normalized (bool): If True, normalizes the iSTFT by the square root of the FFT size. Default is False.
|
||||
onesided (bool): If True, assumes the input STFT matrix is onesided. Default is True.
|
||||
|
||||
Returns:
|
||||
mx.array: The reconstructed time-domain signal.
|
||||
"""
|
||||
if hop_length is None:
|
||||
hop_length = stft_matrix.shape[-2] // 4
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user