* Add MusicGen model

* add benchmarks

* change to from_pretrained

* symlinks

* add readme and requirements

* fix readme

* readme
This commit is contained in:
Alex Barron
2024-10-11 10:16:20 -07:00
committed by GitHub
parent 4360e7ccec
commit d72fdeb4ee
19 changed files with 722 additions and 245 deletions

View File

@@ -1,7 +1,10 @@
# Copyright © 2024 Apple Inc.
import functools
import json
import math
from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace
from typing import List, Optional, Tuple, Union
import mlx.core as mx
@@ -669,3 +672,70 @@ class EncodecModel(nn.Module):
if padding_mask is not None and padding_mask.shape[1] < audio_values.shape[1]:
audio_values = audio_values[:, : padding_mask.shape[1]]
return audio_values
@classmethod
def from_pretrained(cls, path_or_repo: str):
from huggingface_hub import snapshot_download
path = Path(path_or_repo)
if not path.exists():
path = Path(
snapshot_download(
repo_id=path_or_repo,
allow_patterns=["*.json", "*.safetensors", "*.model"],
)
)
with open(path / "config.json", "r") as f:
config = SimpleNamespace(**json.load(f))
model = EncodecModel(config)
model.load_weights(str(path / "model.safetensors"))
processor = functools.partial(
preprocess_audio,
sampling_rate=config.sampling_rate,
chunk_length=model.chunk_length,
chunk_stride=model.chunk_stride,
)
mx.eval(model)
return model, processor
def preprocess_audio(
raw_audio: Union[mx.array, List[mx.array]],
sampling_rate: int = 24000,
chunk_length: Optional[int] = None,
chunk_stride: Optional[int] = None,
):
r"""
Prepare inputs for the EnCodec model.
Args:
raw_audio (mx.array or List[mx.array]): The sequence or batch of
sequences to be processed.
sampling_rate (int): The sampling rate at which the audio waveform
should be digitalized.
chunk_length (int, optional): The model's chunk length.
chunk_stride (int, optional): The model's chunk stride.
"""
if not isinstance(raw_audio, list):
raw_audio = [raw_audio]
raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]
max_length = max(array.shape[0] for array in raw_audio)
if chunk_length is not None:
max_length += chunk_length - (max_length % chunk_stride)
inputs = []
masks = []
for x in raw_audio:
length = x.shape[0]
mask = mx.ones((length,), dtype=mx.bool_)
difference = max_length - length
if difference > 0:
mask = mx.pad(mask, (0, difference))
x = mx.pad(x, ((0, difference), (0, 0)))
inputs.append(x)
masks.append(mask)
return mx.stack(inputs), mx.stack(masks)