add benchmarks

This commit is contained in:
Alex Barron
2024-10-08 14:54:09 -07:00
parent d301705a2f
commit 9432f1a643
6 changed files with 58 additions and 81 deletions

View File

@@ -33,7 +33,8 @@ An example using the model:
```python
import mlx.core as mx
from utils import load, load_audio, save_audio
from encodec import load
from utils import load_audio, save_audio
# Load the 48 KHz model and preprocessor.
model, processor = load("mlx-community/encodec-48khz-float32")

View File

@@ -1,7 +1,9 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
from utils import load, load_audio, save_audio
from utils import load_audio, save_audio
from encodec import load
# Load the 48 KHz model and preprocessor.
model, processor = load("mlx-community/encodec-48khz-float32")

View File

@@ -3,9 +3,9 @@
import mlx.core as mx
import numpy as np
import torch
from datasets import Audio, load_dataset
from transformers import AutoProcessor, EncodecModel
from utils import load, load_audio, preprocess_audio
from encodec import load, preprocess_audio
def compare_processors():

View File

@@ -1,16 +1,7 @@
# Copyright © 2024 Apple Inc.
import functools
import json
from pathlib import Path
from types import SimpleNamespace
from typing import List, Optional, Union
import mlx.core as mx
import numpy as np
from huggingface_hub import snapshot_download
import encodec
def save_audio(file: str, audio: mx.array, sampling_rate: int):
@@ -59,71 +50,3 @@ def load_audio(file: str, sampling_rate: int, channels: int):
out = mx.array(np.frombuffer(out, np.int16))
return out.reshape(-1, channels).astype(mx.float32) / 32767.0
def preprocess_audio(
raw_audio: Union[mx.array, List[mx.array]],
sampling_rate: int = 24000,
chunk_length: Optional[int] = None,
chunk_stride: Optional[int] = None,
):
r"""
Prepare inputs for the EnCodec model.
Args:
raw_audio (mx.array or List[mx.array]): The sequence or batch of
sequences to be processed.
sampling_rate (int): The sampling rate at which the audio waveform
should be digitalized.
chunk_length (int, optional): The model's chunk length.
chunk_stride (int, optional): The model's chunk stride.
"""
if not isinstance(raw_audio, list):
raw_audio = [raw_audio]
raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]
max_length = max(array.shape[0] for array in raw_audio)
if chunk_length is not None:
max_length += chunk_length - (max_length % chunk_stride)
inputs = []
masks = []
for x in raw_audio:
length = x.shape[0]
mask = mx.ones((length,), dtype=mx.bool_)
difference = max_length - length
if difference > 0:
mask = mx.pad(mask, (0, difference))
x = mx.pad(x, ((0, difference), (0, 0)))
inputs.append(x)
masks.append(mask)
return mx.stack(inputs), mx.stack(masks)
def load(path_or_repo):
"""
Load the model and audo preprocessor.
"""
path = Path(path_or_repo)
if not path.exists():
path = Path(
snapshot_download(
repo_id=path_or_repo,
allow_patterns=["*.json", "*.safetensors", "*.model"],
)
)
with open(path / "config.json", "r") as f:
config = SimpleNamespace(**json.load(f))
model = encodec.EncodecModel(config)
model.load_weights(str(path / "model.safetensors"))
processor = functools.partial(
preprocess_audio,
sampling_rate=config.sampling_rate,
chunk_length=model.chunk_length,
chunk_stride=model.chunk_stride,
)
mx.eval(model)
return model, processor

View File

@@ -0,0 +1,22 @@
import sys
import time
from pathlib import Path
import mlx.core as mx
from utils import load
text = "folk ballad"
model = load("facebook/musicgen-medium")
max_steps = 100
audio = model.generate(text, max_steps=10)
mx.eval(audio)
tic = time.time()
audio = model.generate(text, max_steps=max_steps)
mx.eval(audio)
toc = time.time()
ms = 1000 * (toc - tic) / max_steps
print(f"Time (ms) per step: {ms:.3f}")

View File

@@ -0,0 +1,29 @@
import time
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration
model_name = "facebook/musicgen-medium"
processor = AutoProcessor.from_pretrained(model_name)
model = MusicgenForConditionalGeneration.from_pretrained(model_name).to("mps")
inputs = processor(
text=["folk ballad"],
padding=True,
return_tensors="pt",
)
inputs["input_ids"] = inputs["input_ids"].to("mps")
inputs["attention_mask"] = inputs["attention_mask"].to("mps")
# warmup
audio_values = model.generate(**inputs, max_new_tokens=10)
torch.mps.synchronize()
max_steps = 100
tic = time.time()
audio_values = model.generate(**inputs, max_new_tokens=max_steps)
torch.mps.synchronize()
toc = time.time()
ms = 1000 * (toc - tic) / max_steps
print(f"Time (ms) per step: {ms:.3f}")