mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
MusicGen (#1020)
* Add MusicGen model * add benchmarks * change to from_pretrained * symlinks * add readme and requirements * fix readme * readme
This commit is contained in:
@@ -33,13 +33,14 @@ An example using the model:
|
||||
|
||||
```python
|
||||
import mlx.core as mx
|
||||
from utils import load, load_audio, save_audio
|
||||
from encodec import EncodecModel
|
||||
from utils import load_audio, save_audio
|
||||
|
||||
# Load the 48 KHz model and preprocessor.
|
||||
model, processor = load("mlx-community/encodec-48khz-float32")
|
||||
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
|
||||
|
||||
# Load an audio file
|
||||
audio = load_audio("path/to/aduio", model.sampling_rate, model.channels)
|
||||
audio = load_audio("path/to/audio", model.sampling_rate, model.channels)
|
||||
|
||||
# Preprocess the audio (this can also be a list of arrays for batched
|
||||
# processing).
|
||||
|
@@ -3,9 +3,10 @@
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
from utils import load
|
||||
|
||||
model, processor = load("mlx-community/encodec-48khz-float32")
|
||||
from encodec import EncodecModel
|
||||
|
||||
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
|
||||
|
||||
audio = mx.random.uniform(shape=(288000, 2))
|
||||
feats, mask = processor(audio)
|
||||
|
@@ -10,7 +10,6 @@ from typing import Any, Dict, Union
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
import encodec
|
||||
|
||||
|
@@ -1,7 +1,10 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import functools
|
||||
import json
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
@@ -669,3 +672,70 @@ class EncodecModel(nn.Module):
|
||||
if padding_mask is not None and padding_mask.shape[1] < audio_values.shape[1]:
|
||||
audio_values = audio_values[:, : padding_mask.shape[1]]
|
||||
return audio_values
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, path_or_repo: str):
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
path = Path(path_or_repo)
|
||||
if not path.exists():
|
||||
path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_repo,
|
||||
allow_patterns=["*.json", "*.safetensors", "*.model"],
|
||||
)
|
||||
)
|
||||
|
||||
with open(path / "config.json", "r") as f:
|
||||
config = SimpleNamespace(**json.load(f))
|
||||
|
||||
model = EncodecModel(config)
|
||||
model.load_weights(str(path / "model.safetensors"))
|
||||
processor = functools.partial(
|
||||
preprocess_audio,
|
||||
sampling_rate=config.sampling_rate,
|
||||
chunk_length=model.chunk_length,
|
||||
chunk_stride=model.chunk_stride,
|
||||
)
|
||||
mx.eval(model)
|
||||
return model, processor
|
||||
|
||||
|
||||
def preprocess_audio(
|
||||
raw_audio: Union[mx.array, List[mx.array]],
|
||||
sampling_rate: int = 24000,
|
||||
chunk_length: Optional[int] = None,
|
||||
chunk_stride: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
Prepare inputs for the EnCodec model.
|
||||
|
||||
Args:
|
||||
raw_audio (mx.array or List[mx.array]): The sequence or batch of
|
||||
sequences to be processed.
|
||||
sampling_rate (int): The sampling rate at which the audio waveform
|
||||
should be digitalized.
|
||||
chunk_length (int, optional): The model's chunk length.
|
||||
chunk_stride (int, optional): The model's chunk stride.
|
||||
"""
|
||||
if not isinstance(raw_audio, list):
|
||||
raw_audio = [raw_audio]
|
||||
|
||||
raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]
|
||||
|
||||
max_length = max(array.shape[0] for array in raw_audio)
|
||||
if chunk_length is not None:
|
||||
max_length += chunk_length - (max_length % chunk_stride)
|
||||
|
||||
inputs = []
|
||||
masks = []
|
||||
for x in raw_audio:
|
||||
length = x.shape[0]
|
||||
mask = mx.ones((length,), dtype=mx.bool_)
|
||||
difference = max_length - length
|
||||
if difference > 0:
|
||||
mask = mx.pad(mask, (0, difference))
|
||||
x = mx.pad(x, ((0, difference), (0, 0)))
|
||||
inputs.append(x)
|
||||
masks.append(mask)
|
||||
return mx.stack(inputs), mx.stack(masks)
|
||||
|
@@ -1,10 +1,12 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
from utils import load, load_audio, save_audio
|
||||
from utils import load_audio, save_audio
|
||||
|
||||
from encodec import EncodecModel
|
||||
|
||||
# Load the 48 KHz model and preprocessor.
|
||||
model, processor = load("mlx-community/encodec-48khz-float32")
|
||||
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
|
||||
|
||||
# Load an audio file
|
||||
audio = load_audio("/path/to/audio", model.sampling_rate, model.channels)
|
||||
|
@@ -3,9 +3,10 @@
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import Audio, load_dataset
|
||||
from transformers import AutoProcessor, EncodecModel
|
||||
from utils import load, load_audio, preprocess_audio
|
||||
from transformers import AutoProcessor
|
||||
from transformers import EncodecModel as PTEncodecModel
|
||||
|
||||
from encodec import EncodecModel, preprocess_audio
|
||||
|
||||
|
||||
def compare_processors():
|
||||
@@ -30,8 +31,8 @@ def compare_processors():
|
||||
|
||||
|
||||
def compare_models():
|
||||
pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz")
|
||||
mx_model, _ = load("mlx-community/encodec-48khz-float32")
|
||||
pt_model = PTEncodecModel.from_pretrained("facebook/encodec_48khz")
|
||||
mx_model, _ = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
|
||||
|
||||
np.random.seed(0)
|
||||
audio_length = 190560
|
||||
|
@@ -1,16 +1,7 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import functools
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
import encodec
|
||||
|
||||
|
||||
def save_audio(file: str, audio: mx.array, sampling_rate: int):
|
||||
@@ -59,71 +50,3 @@ def load_audio(file: str, sampling_rate: int, channels: int):
|
||||
|
||||
out = mx.array(np.frombuffer(out, np.int16))
|
||||
return out.reshape(-1, channels).astype(mx.float32) / 32767.0
|
||||
|
||||
|
||||
def preprocess_audio(
|
||||
raw_audio: Union[mx.array, List[mx.array]],
|
||||
sampling_rate: int = 24000,
|
||||
chunk_length: Optional[int] = None,
|
||||
chunk_stride: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
Prepare inputs for the EnCodec model.
|
||||
|
||||
Args:
|
||||
raw_audio (mx.array or List[mx.array]): The sequence or batch of
|
||||
sequences to be processed.
|
||||
sampling_rate (int): The sampling rate at which the audio waveform
|
||||
should be digitalized.
|
||||
chunk_length (int, optional): The model's chunk length.
|
||||
chunk_stride (int, optional): The model's chunk stride.
|
||||
"""
|
||||
if not isinstance(raw_audio, list):
|
||||
raw_audio = [raw_audio]
|
||||
|
||||
raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]
|
||||
|
||||
max_length = max(array.shape[0] for array in raw_audio)
|
||||
if chunk_length is not None:
|
||||
max_length += chunk_length - (max_length % chunk_stride)
|
||||
|
||||
inputs = []
|
||||
masks = []
|
||||
for x in raw_audio:
|
||||
length = x.shape[0]
|
||||
mask = mx.ones((length,), dtype=mx.bool_)
|
||||
difference = max_length - length
|
||||
if difference > 0:
|
||||
mask = mx.pad(mask, (0, difference))
|
||||
x = mx.pad(x, ((0, difference), (0, 0)))
|
||||
inputs.append(x)
|
||||
masks.append(mask)
|
||||
return mx.stack(inputs), mx.stack(masks)
|
||||
|
||||
|
||||
def load(path_or_repo):
|
||||
"""
|
||||
Load the model and audo preprocessor.
|
||||
"""
|
||||
path = Path(path_or_repo)
|
||||
if not path.exists():
|
||||
path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_repo,
|
||||
allow_patterns=["*.json", "*.safetensors", "*.model"],
|
||||
)
|
||||
)
|
||||
|
||||
with open(path / "config.json", "r") as f:
|
||||
config = SimpleNamespace(**json.load(f))
|
||||
|
||||
model = encodec.EncodecModel(config)
|
||||
model.load_weights(str(path / "model.safetensors"))
|
||||
processor = functools.partial(
|
||||
preprocess_audio,
|
||||
sampling_rate=config.sampling_rate,
|
||||
chunk_length=model.chunk_length,
|
||||
chunk_stride=model.chunk_stride,
|
||||
)
|
||||
mx.eval(model)
|
||||
return model, processor
|
||||
|
Reference in New Issue
Block a user