Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 01:17:28 +08:00.

Commit: Add MusicGen model
- add benchmarks
- change to from_pretrained
- symlinks
- add readme and requirements
- fix readme
- readme
32 lines | 823 B | Python
# Copyright © 2024 Apple Inc.

"""Benchmark Hugging Face MusicGen token-generation speed on Apple-silicon MPS.

Loads facebook/musicgen-medium, runs a short warmup generation, then times a
100-step generation and prints the average milliseconds per decoding step.
"""

import time

import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

model_name = "facebook/musicgen-medium"

processor = AutoProcessor.from_pretrained(model_name)
model = MusicgenForConditionalGeneration.from_pretrained(model_name).to("mps")

inputs = processor(
    text=["folk ballad"],
    padding=True,
    return_tensors="pt",
)
# Move the tokenized prompt to the GPU so generation runs entirely on MPS.
inputs["input_ids"] = inputs["input_ids"].to("mps")
inputs["attention_mask"] = inputs["attention_mask"].to("mps")

# Warmup: the first generate() pays one-time kernel compilation / allocation
# costs, so it is excluded from the timed run.
audio_values = model.generate(**inputs, max_new_tokens=10)
torch.mps.synchronize()

max_steps = 100
# perf_counter is the monotonic, high-resolution clock recommended for
# benchmarking (time.time can jump if the wall clock is adjusted).
tic = time.perf_counter()
audio_values = model.generate(**inputs, max_new_tokens=max_steps)
# MPS work is asynchronous; wait for the device before stopping the clock.
torch.mps.synchronize()
toc = time.perf_counter()

ms = 1000 * (toc - tic) / max_steps
print(f"Time (ms) per step: {ms:.3f}")