mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Added docstrings
This commit is contained in:
parent
f50cce83a5
commit
91097e6179
@ -1,5 +1,4 @@
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def stft(
|
def stft(
|
||||||
@ -14,6 +13,25 @@ def stft(
|
|||||||
onesided: bool = True,
|
onesided: bool = True,
|
||||||
return_complex: bool = True,
|
return_complex: bool = True,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
|
"""
|
||||||
|
Computes the Short-Time Fourier Transform (STFT) of a signal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (mx.array): Input signal array.
|
||||||
|
n_fft (int): Number of FFT points. Default is 2048.
|
||||||
|
hop_length (int, optional): Number of samples between successive frames.
|
||||||
|
Defaults to `n_fft // 4`.
|
||||||
|
win_length (int, optional): Window size. Defaults to `n_fft`.
|
||||||
|
window (mx.array, optional): Window function. Defaults to a rectangular window.
|
||||||
|
center (bool): If True, pads the input signal so that the frame is centered. Default is True.
|
||||||
|
pad_mode (str): Padding mode to use if `center` is True. Default is "reflect".
|
||||||
|
normalized (bool): If True, normalizes the STFT by the square root of `n_fft`. Default is False.
|
||||||
|
onesided (bool): If True, returns only the positive frequencies. Default is True.
|
||||||
|
return_complex (bool): If True, returns a complex-valued STFT. If False, returns real and imaginary parts separately. Default is True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
mx.array: The STFT of the input signal. The shape depends on the input and parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
if hop_length is None:
|
if hop_length is None:
|
||||||
hop_length = n_fft // 4
|
hop_length = n_fft // 4
|
||||||
@ -63,6 +81,23 @@ def istft(
|
|||||||
normalized: bool = False,
|
normalized: bool = False,
|
||||||
onesided: bool = True,
|
onesided: bool = True,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
|
"""
|
||||||
|
Computes the inverse Short-Time Fourier Transform (iSTFT) of a signal.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stft_matrix (mx.array): STFT matrix (output of `stft`).
|
||||||
|
hop_length (int, optional): Number of samples between successive frames.
|
||||||
|
Defaults to `stft_matrix.shape[-2] // 4`.
|
||||||
|
win_length (int, optional): Window size. Defaults to `stft_matrix.shape[-2]`.
|
||||||
|
window (mx.array, optional): Window function. Defaults to a rectangular window.
|
||||||
|
center (bool): If True, removes padding added during STFT computation. Default is True.
|
||||||
|
length (int, optional): Length of the output signal. If provided, the output is trimmed or zero-padded to this length.
|
||||||
|
normalized (bool): If True, normalizes the iSTFT by the square root of the FFT size. Default is False.
|
||||||
|
onesided (bool): If True, assumes the input STFT matrix is onesided. Default is True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
mx.array: The reconstructed time-domain signal.
|
||||||
|
"""
|
||||||
if hop_length is None:
|
if hop_length is None:
|
||||||
hop_length = stft_matrix.shape[-2] // 4
|
hop_length = stft_matrix.shape[-2] // 4
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user