mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-30 01:58:09 +08:00
MusicGen (#1020)
* Add MusicGen model * add benchmarks * change to from_pretrained * symlinks * add readme and requirements * fix readme * readme
This commit is contained in:
28
musicgen/benchmarks/bench_mx.py
Normal file
28
musicgen/benchmarks/bench_mx.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
cur_path = Path(__file__).parents[1].resolve()
|
||||
sys.path.append(str(cur_path))
|
||||
|
||||
from musicgen import MusicGen
|
||||
|
||||
text = "folk ballad"
|
||||
model = MusicGen.from_pretrained("facebook/musicgen-medium")
|
||||
|
||||
max_steps = 100
|
||||
|
||||
audio = model.generate(text, max_steps=10)
|
||||
mx.eval(audio)
|
||||
|
||||
tic = time.time()
|
||||
audio = model.generate(text, max_steps=max_steps)
|
||||
mx.eval(audio)
|
||||
toc = time.time()
|
||||
|
||||
ms = 1000 * (toc - tic) / max_steps
|
||||
print(f"Time (ms) per step: {ms:.3f}")
|
||||
31
musicgen/benchmarks/bench_pt.py
Normal file
31
musicgen/benchmarks/bench_pt.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import time
|
||||
|
||||
import torch
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
|
||||
model_name = "facebook/musicgen-medium"
|
||||
processor = AutoProcessor.from_pretrained(model_name)
|
||||
model = MusicgenForConditionalGeneration.from_pretrained(model_name).to("mps")
|
||||
|
||||
inputs = processor(
|
||||
text=["folk ballad"],
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs["input_ids"] = inputs["input_ids"].to("mps")
|
||||
inputs["attention_mask"] = inputs["attention_mask"].to("mps")
|
||||
|
||||
# warmup
|
||||
audio_values = model.generate(**inputs, max_new_tokens=10)
|
||||
torch.mps.synchronize()
|
||||
|
||||
max_steps = 100
|
||||
tic = time.time()
|
||||
audio_values = model.generate(**inputs, max_new_tokens=max_steps)
|
||||
torch.mps.synchronize()
|
||||
toc = time.time()
|
||||
|
||||
ms = 1000 * (toc - tic) / max_steps
|
||||
print(f"Time (ms) per step: {ms:.3f}")
|
||||
Reference in New Issue
Block a user