* Add MusicGen model

* add benchmarks

* change to from_pretrained

* symlinks

* add readme and requirements

* fix readme

* readme
This commit is contained in:
Alex Barron
2024-10-11 10:16:20 -07:00
committed by GitHub
parent 4360e7ccec
commit d72fdeb4ee
19 changed files with 722 additions and 245 deletions

View File

@@ -33,13 +33,14 @@ An example using the model:
```python
import mlx.core as mx
from utils import load, load_audio, save_audio
from encodec import EncodecModel
from utils import load_audio, save_audio
# Load the 48 KHz model and preprocessor.
model, processor = load("mlx-community/encodec-48khz-float32")
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
# Load an audio file
audio = load_audio("path/to/aduio", model.sampling_rate, model.channels)
audio = load_audio("path/to/audio", model.sampling_rate, model.channels)
# Preprocess the audio (this can also be a list of arrays for batched
# processing).

View File

@@ -3,9 +3,10 @@
import time
import mlx.core as mx
from utils import load
model, processor = load("mlx-community/encodec-48khz-float32")
from encodec import EncodecModel
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
audio = mx.random.uniform(shape=(288000, 2))
feats, mask = processor(audio)

View File

@@ -10,7 +10,6 @@ from typing import Any, Dict, Union
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
import encodec

View File

@@ -1,7 +1,10 @@
# Copyright © 2024 Apple Inc.
import functools
import json
import math
from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace
from typing import List, Optional, Tuple, Union
import mlx.core as mx
@@ -669,3 +672,70 @@ class EncodecModel(nn.Module):
if padding_mask is not None and padding_mask.shape[1] < audio_values.shape[1]:
audio_values = audio_values[:, : padding_mask.shape[1]]
return audio_values
@classmethod
def from_pretrained(cls, path_or_repo: str):
from huggingface_hub import snapshot_download
path = Path(path_or_repo)
if not path.exists():
path = Path(
snapshot_download(
repo_id=path_or_repo,
allow_patterns=["*.json", "*.safetensors", "*.model"],
)
)
with open(path / "config.json", "r") as f:
config = SimpleNamespace(**json.load(f))
model = EncodecModel(config)
model.load_weights(str(path / "model.safetensors"))
processor = functools.partial(
preprocess_audio,
sampling_rate=config.sampling_rate,
chunk_length=model.chunk_length,
chunk_stride=model.chunk_stride,
)
mx.eval(model)
return model, processor
def preprocess_audio(
raw_audio: Union[mx.array, List[mx.array]],
sampling_rate: int = 24000,
chunk_length: Optional[int] = None,
chunk_stride: Optional[int] = None,
):
r"""
Prepare inputs for the EnCodec model.
Args:
raw_audio (mx.array or List[mx.array]): The sequence or batch of
sequences to be processed.
sampling_rate (int): The sampling rate at which the audio waveform
should be digitalized.
chunk_length (int, optional): The model's chunk length.
chunk_stride (int, optional): The model's chunk stride.
"""
if not isinstance(raw_audio, list):
raw_audio = [raw_audio]
raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]
max_length = max(array.shape[0] for array in raw_audio)
if chunk_length is not None:
max_length += chunk_length - (max_length % chunk_stride)
inputs = []
masks = []
for x in raw_audio:
length = x.shape[0]
mask = mx.ones((length,), dtype=mx.bool_)
difference = max_length - length
if difference > 0:
mask = mx.pad(mask, (0, difference))
x = mx.pad(x, ((0, difference), (0, 0)))
inputs.append(x)
masks.append(mask)
return mx.stack(inputs), mx.stack(masks)

View File

@@ -1,10 +1,12 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
from utils import load, load_audio, save_audio
from utils import load_audio, save_audio
from encodec import EncodecModel
# Load the 48 KHz model and preprocessor.
model, processor = load("mlx-community/encodec-48khz-float32")
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
# Load an audio file
audio = load_audio("/path/to/audio", model.sampling_rate, model.channels)

View File

@@ -3,9 +3,10 @@
import mlx.core as mx
import numpy as np
import torch
from datasets import Audio, load_dataset
from transformers import AutoProcessor, EncodecModel
from utils import load, load_audio, preprocess_audio
from transformers import AutoProcessor
from transformers import EncodecModel as PTEncodecModel
from encodec import EncodecModel, preprocess_audio
def compare_processors():
@@ -30,8 +31,8 @@ def compare_processors():
def compare_models():
pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz")
mx_model, _ = load("mlx-community/encodec-48khz-float32")
pt_model = PTEncodecModel.from_pretrained("facebook/encodec_48khz")
mx_model, _ = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
np.random.seed(0)
audio_length = 190560

View File

@@ -1,16 +1,7 @@
# Copyright © 2024 Apple Inc.
import functools
import json
from pathlib import Path
from types import SimpleNamespace
from typing import List, Optional, Union
import mlx.core as mx
import numpy as np
from huggingface_hub import snapshot_download
import encodec
def save_audio(file: str, audio: mx.array, sampling_rate: int):
@@ -59,71 +50,3 @@ def load_audio(file: str, sampling_rate: int, channels: int):
out = mx.array(np.frombuffer(out, np.int16))
return out.reshape(-1, channels).astype(mx.float32) / 32767.0
def preprocess_audio(
raw_audio: Union[mx.array, List[mx.array]],
sampling_rate: int = 24000,
chunk_length: Optional[int] = None,
chunk_stride: Optional[int] = None,
):
r"""
Prepare inputs for the EnCodec model.
Args:
raw_audio (mx.array or List[mx.array]): The sequence or batch of
sequences to be processed.
sampling_rate (int): The sampling rate at which the audio waveform
should be digitalized.
chunk_length (int, optional): The model's chunk length.
chunk_stride (int, optional): The model's chunk stride.
"""
if not isinstance(raw_audio, list):
raw_audio = [raw_audio]
raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]
max_length = max(array.shape[0] for array in raw_audio)
if chunk_length is not None:
max_length += chunk_length - (max_length % chunk_stride)
inputs = []
masks = []
for x in raw_audio:
length = x.shape[0]
mask = mx.ones((length,), dtype=mx.bool_)
difference = max_length - length
if difference > 0:
mask = mx.pad(mask, (0, difference))
x = mx.pad(x, ((0, difference), (0, 0)))
inputs.append(x)
masks.append(mask)
return mx.stack(inputs), mx.stack(masks)
def load(path_or_repo):
"""
Load the model and audo preprocessor.
"""
path = Path(path_or_repo)
if not path.exists():
path = Path(
snapshot_download(
repo_id=path_or_repo,
allow_patterns=["*.json", "*.safetensors", "*.model"],
)
)
with open(path / "config.json", "r") as f:
config = SimpleNamespace(**json.load(f))
model = encodec.EncodecModel(config)
model.load_weights(str(path / "model.safetensors"))
processor = functools.partial(
preprocess_audio,
sampling_rate=config.sampling_rate,
chunk_length=model.chunk_length,
chunk_stride=model.chunk_stride,
)
mx.eval(model)
return model, processor