Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-06 00:44:34 +08:00
add benchmarks
@@ -33,7 +33,8 @@ An example using the model:
 ```python
 import mlx.core as mx
-from utils import load, load_audio, save_audio
+from encodec import load
+from utils import load_audio, save_audio
 
 # Load the 48 KHz model and preprocessor.
 model, processor = load("mlx-community/encodec-48khz-float32")
@@ -1,7 +1,9 @@
 # Copyright © 2024 Apple Inc.
 
 import mlx.core as mx
-from utils import load, load_audio, save_audio
+from utils import load_audio, save_audio
+
+from encodec import load
 
 # Load the 48 KHz model and preprocessor.
 model, processor = load("mlx-community/encodec-48khz-float32")
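For context, a minimal sketch of how the reorganized imports fit together after this change. The input file name is a placeholder, and the encode/decode steps of the full example are unchanged by this commit and omitted here:

```python
# Hypothetical usage of the new layout: load comes from the encodec package,
# while the file helpers stay in utils.
from encodec import load
from utils import load_audio, save_audio  # save_audio would write the decoded result back out

model, processor = load("mlx-community/encodec-48khz-float32")

# load_audio keeps the signature shown in utils.py: (file, sampling_rate, channels).
# The 48 kHz checkpoint is a stereo model, so two channels are requested.
audio = load_audio("example.wav", 48000, 2)

# The preprocessor pads the waveform to the model's chunk layout and returns a mask.
feats, mask = processor(audio)
```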
@@ -3,9 +3,9 @@
 import mlx.core as mx
 import numpy as np
 import torch
-from datasets import Audio, load_dataset
 from transformers import AutoProcessor, EncodecModel
-from utils import load, load_audio, preprocess_audio
+
+from encodec import load, preprocess_audio
 
 
 def compare_processors():
@@ -1,16 +1,7 @@
 # Copyright © 2024 Apple Inc.
 
-import functools
-import json
-from pathlib import Path
-from types import SimpleNamespace
-from typing import List, Optional, Union
-
 import mlx.core as mx
 import numpy as np
-from huggingface_hub import snapshot_download
-
-import encodec
 
 
 def save_audio(file: str, audio: mx.array, sampling_rate: int):
@@ -59,71 +50,3 @@ def load_audio(file: str, sampling_rate: int, channels: int):
 
     out = mx.array(np.frombuffer(out, np.int16))
     return out.reshape(-1, channels).astype(mx.float32) / 32767.0
-
-
-def preprocess_audio(
-    raw_audio: Union[mx.array, List[mx.array]],
-    sampling_rate: int = 24000,
-    chunk_length: Optional[int] = None,
-    chunk_stride: Optional[int] = None,
-):
-    r"""
-    Prepare inputs for the EnCodec model.
-
-    Args:
-        raw_audio (mx.array or List[mx.array]): The sequence or batch of
-            sequences to be processed.
-        sampling_rate (int): The sampling rate at which the audio waveform
-            should be digitalized.
-        chunk_length (int, optional): The model's chunk length.
-        chunk_stride (int, optional): The model's chunk stride.
-    """
-    if not isinstance(raw_audio, list):
-        raw_audio = [raw_audio]
-
-    raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]
-
-    max_length = max(array.shape[0] for array in raw_audio)
-    if chunk_length is not None:
-        max_length += chunk_length - (max_length % chunk_stride)
-
-    inputs = []
-    masks = []
-    for x in raw_audio:
-        length = x.shape[0]
-        mask = mx.ones((length,), dtype=mx.bool_)
-        difference = max_length - length
-        if difference > 0:
-            mask = mx.pad(mask, (0, difference))
-            x = mx.pad(x, ((0, difference), (0, 0)))
-        inputs.append(x)
-        masks.append(mask)
-    return mx.stack(inputs), mx.stack(masks)
-
-
-def load(path_or_repo):
-    """
-    Load the model and audo preprocessor.
-    """
-    path = Path(path_or_repo)
-    if not path.exists():
-        path = Path(
-            snapshot_download(
-                repo_id=path_or_repo,
-                allow_patterns=["*.json", "*.safetensors", "*.model"],
-            )
-        )
-
-    with open(path / "config.json", "r") as f:
-        config = SimpleNamespace(**json.load(f))
-
-    model = encodec.EncodecModel(config)
-    model.load_weights(str(path / "model.safetensors"))
-    processor = functools.partial(
-        preprocess_audio,
-        sampling_rate=config.sampling_rate,
-        chunk_length=model.chunk_length,
-        chunk_stride=model.chunk_stride,
-    )
-    mx.eval(model)
-    return model, processor
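Since `preprocess_audio` and `load` move out of `utils.py` here (they are now imported from the `encodec` package, per the hunks above), the following is a small hedged check of the padding and masking behavior that `preprocess_audio` implements, assuming the moved function keeps the signature shown in the removed code:

```python
# Hypothetical check of preprocess_audio's padding/masking (values chosen for illustration).
import mlx.core as mx

from encodec import preprocess_audio

short_clip = mx.zeros((320, 1))  # 320-sample mono clip
long_clip = mx.zeros((480, 1))   # 480-sample mono clip

inputs, masks = preprocess_audio(
    [short_clip, long_clip], sampling_rate=24000, chunk_length=160, chunk_stride=160
)

# The batch is padded so its length lands on a chunk boundary past the longest clip,
# and the boolean mask records which samples are real.
print(inputs.shape)                          # (2, 640, 1) with these chunk settings
print(masks.astype(mx.int32).sum(axis=-1))   # [320, 480] valid samples per clip
```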
musicgen/benchmarks/bench_mx.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+import sys
+import time
+from pathlib import Path
+
+import mlx.core as mx
+from utils import load
+
+text = "folk ballad"
+model = load("facebook/musicgen-medium")
+
+max_steps = 100
+
+audio = model.generate(text, max_steps=10)
+mx.eval(audio)
+
+tic = time.time()
+audio = model.generate(text, max_steps=max_steps)
+mx.eval(audio)
+toc = time.time()
+
+ms = 1000 * (toc - tic) / max_steps
+print(f"Time (ms) per step: {ms:.3f}")
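The script above times generation after a short warm-up run so one-time setup cost is excluded from the measurement. A hedged variant that keeps the best of a few timed trials, using only the calls the benchmark already relies on, might look like this (not part of the commit):

```python
# Hypothetical multi-trial version of bench_mx.py.
import time

import mlx.core as mx

from utils import load

model = load("facebook/musicgen-medium")
text = "folk ballad"
max_steps = 100

# Warm up once, as bench_mx.py does, so compilation and loading are excluded.
mx.eval(model.generate(text, max_steps=10))

trials = []
for _ in range(3):
    tic = time.time()
    mx.eval(model.generate(text, max_steps=max_steps))
    trials.append(1000 * (time.time() - tic) / max_steps)

print(f"Time (ms) per step: {min(trials):.3f} (best of {len(trials)} runs)")
```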
musicgen/benchmarks/bench_pt.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+import time
+
+import torch
+from transformers import AutoProcessor, MusicgenForConditionalGeneration
+
+model_name = "facebook/musicgen-medium"
+processor = AutoProcessor.from_pretrained(model_name)
+model = MusicgenForConditionalGeneration.from_pretrained(model_name).to("mps")
+
+inputs = processor(
+    text=["folk ballad"],
+    padding=True,
+    return_tensors="pt",
+)
+inputs["input_ids"] = inputs["input_ids"].to("mps")
+inputs["attention_mask"] = inputs["attention_mask"].to("mps")
+
+# warmup
+audio_values = model.generate(**inputs, max_new_tokens=10)
+torch.mps.synchronize()
+
+max_steps = 100
+tic = time.time()
+audio_values = model.generate(**inputs, max_new_tokens=max_steps)
+torch.mps.synchronize()
+toc = time.time()
+
+ms = 1000 * (toc - tic) / max_steps
+print(f"Time (ms) per step: {ms:.3f}")
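Both scripts print the same `Time (ms) per step` line, so their numbers are directly comparable. A hypothetical wrapper (not part of this commit) that runs both from `musicgen/benchmarks/` and reports the ratio:

```python
# Hypothetical comparison wrapper; assumes both benchmark scripts are runnable
# from the current directory and print "Time (ms) per step: X.XXX".
import re
import subprocess


def per_step_ms(script: str) -> float:
    out = subprocess.run(
        ["python", script], capture_output=True, text=True, check=True
    ).stdout
    return float(re.search(r"Time \(ms\) per step: ([\d.]+)", out).group(1))


mx_ms = per_step_ms("bench_mx.py")
pt_ms = per_step_ms("bench_pt.py")
print(f"MLX: {mx_ms:.3f} ms/step, PyTorch (MPS): {pt_ms:.3f} ms/step, ratio: {pt_ms / mx_ms:.2f}x")
```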