mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00

* initial encodec * works * nits * use fast group norm * fix for rnn layer * fix mlx version * use custom LSTM kernel * audio encodec * fix example, support batched inference * nits
130 lines
3.7 KiB
Python
130 lines
3.7 KiB
Python
# Copyright © 2024 Apple Inc.
|
|
|
|
import functools
|
|
import json
|
|
from pathlib import Path
|
|
from types import SimpleNamespace
|
|
from typing import List, Optional, Union
|
|
|
|
import mlx.core as mx
|
|
import numpy as np
|
|
from huggingface_hub import snapshot_download
|
|
|
|
import encodec
|
|
|
|
|
|
def save_audio(file: str, audio: mx.array, sampling_rate: int):
    """
    Save audio to a wave (.wav) file.

    The waveform is converted to 16-bit signed PCM by scaling with 32767.
    Samples outside of ``[-1, 1]`` are clipped first so that they saturate
    instead of overflowing the int16 range.

    Args:
        file (str): Path of the wave file to write.
        audio (mx.array): The floating point waveform to save.
        sampling_rate (int): The sample rate stored in the wave file.
    """
    from scipy.io.wavfile import write

    # Without clipping, any sample whose magnitude exceeds 1.0 would not fit
    # in int16 after scaling and would be written as garbage.
    audio = mx.clip(audio, -1, 1)
    audio = (audio * 32767).astype(mx.int16)
    write(file, sampling_rate, np.array(audio))
|
|
|
|
|
|
def load_audio(file: str, sampling_rate: int, channels: int):
    """
    Decode an audio file into an mx.array, resampling if necessary.

    Decoding is delegated to the ``ffmpeg`` command line tool, which must be
    available in PATH. ffmpeg is asked to emit raw signed 16-bit
    little-endian PCM on stdout with the requested channel count and sample
    rate.

    Args:
        file (str): The audio file to open.
        sampling_rate (int): The sample rate to resample the audio at if needed.
        channels (int): The number of audio channels.

    Returns:
        An mx.array containing the audio waveform in float32, reshaped to
        ``(-1, channels)``.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status.
    """
    from subprocess import CalledProcessError, run

    # ffmpeg handles decoding, down-mixing and resampling in one pass and
    # streams the raw samples to stdout ("-").
    # fmt: off
    ffmpeg_args = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", str(channels),
        "-acodec", "pcm_s16le",
        "-ar", str(sampling_rate),
        "-"
    ]
    # fmt: on
    try:
        completed = run(ffmpeg_args, capture_output=True, check=True)
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    # Reinterpret the raw bytes as int16 samples, then scale to float32.
    samples = np.frombuffer(completed.stdout, np.int16)
    waveform = mx.array(samples).reshape(-1, channels)
    return waveform.astype(mx.float32) / 32767.0
|
|
|
|
|
|
def preprocess_audio(
    raw_audio: Union[mx.array, List[mx.array]],
    sampling_rate: int = 24000,
    chunk_length: Optional[int] = None,
    chunk_stride: Optional[int] = None,
):
    r"""
    Prepare inputs for the EnCodec model.

    Every sequence is zero-padded at the end along its first axis to a common
    length and paired with a boolean mask that is ``True`` on the original
    samples and ``False`` on the padding.

    Args:
        raw_audio (mx.array or List[mx.array]): The sequence or batch of
            sequences to be processed. One-dimensional sequences get a
            trailing axis of size one appended.
        sampling_rate (int): The sampling rate at which the audio waveform
            was digitized. Not used by this function.
        chunk_length (int, optional): The model's chunk length. When given,
            the common length is extended by
            ``chunk_length - (max_length % chunk_stride)``.
        chunk_stride (int, optional): The model's chunk stride. Must be set
            whenever ``chunk_length`` is set.

    Returns:
        A tuple ``(inputs, masks)`` holding the stacked padded sequences and
        the stacked boolean masks.
    """
    batch = raw_audio if isinstance(raw_audio, list) else [raw_audio]

    # Give one-dimensional sequences an explicit trailing axis.
    batch = [clip[..., None] if clip.ndim == 1 else clip for clip in batch]

    target_length = max(clip.shape[0] for clip in batch)
    if chunk_length is not None:
        overhang = target_length % chunk_stride
        target_length += chunk_length - overhang

    def _pad_with_mask(clip):
        # Pad a single sequence to target_length and build its validity mask.
        num_samples = clip.shape[0]
        valid = mx.ones((num_samples,), dtype=mx.bool_)
        missing = target_length - num_samples
        if missing > 0:
            valid = mx.pad(valid, (0, missing))
            clip = mx.pad(clip, ((0, missing), (0, 0)))
        return clip, valid

    padded_clips = []
    valid_masks = []
    for clip in batch:
        padded, valid = _pad_with_mask(clip)
        padded_clips.append(padded)
        valid_masks.append(valid)

    return mx.stack(padded_clips), mx.stack(valid_masks)
|
|
|
|
|
|
def load(path_or_repo):
    """
    Load the model and audio preprocessor.

    Args:
        path_or_repo: A local directory containing ``config.json`` and
            ``model.safetensors``. If the path does not exist it is passed
            to ``snapshot_download`` as a repo id and the snapshot is used
            instead.

    Returns:
        A tuple ``(model, processor)`` where ``processor`` is
        ``preprocess_audio`` pre-bound to the model's sampling rate, chunk
        length and chunk stride.
    """
    model_dir = Path(path_or_repo)
    if not model_dir.exists():
        # Not a local path: fetch only the config and weight files.
        downloaded = snapshot_download(
            repo_id=path_or_repo,
            allow_patterns=["*.json", "*.safetensors", "*.model"],
        )
        model_dir = Path(downloaded)

    with open(model_dir / "config.json", "r") as config_file:
        config = SimpleNamespace(**json.load(config_file))

    model = encodec.EncodecModel(config)
    model.load_weights(str(model_dir / "model.safetensors"))

    # Bind the model-specific settings so callers only pass the raw audio.
    processor_kwargs = {
        "sampling_rate": config.sampling_rate,
        "chunk_length": model.chunk_length,
        "chunk_stride": model.chunk_stride,
    }
    processor = functools.partial(preprocess_audio, **processor_kwargs)

    # Evaluate the model so the loaded weights are materialized before returning.
    mx.eval(model)
    return model, processor
|