* Add MusicGen model

* add benchmarks

* change to from_pretrained

* symlinks

* add readme and requirements

* fix readme

* readme
This commit is contained in:
Alex Barron
2024-10-11 10:16:20 -07:00
committed by GitHub
parent 4360e7ccec
commit d72fdeb4ee
19 changed files with 722 additions and 245 deletions

23
musicgen/generate.py Normal file
View File

@@ -0,0 +1,23 @@
# Copyright © 2024 Apple Inc.
import argparse
from utils import save_audio
from musicgen import MusicGen
def main(text: str, output_path: str, model_name: str, max_steps: int):
model = MusicGen.from_pretrained(model_name)
audio = model.generate(text, max_steps=max_steps)
save_audio(output_path, audio, model.sampling_rate)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=False, default="facebook/musicgen-medium")
parser.add_argument("--text", required=False, default="happy rock")
parser.add_argument("--output-path", required=False, default="0.wav")
parser.add_argument("--max-steps", required=False, default=500, type=int)
args = parser.parse_args()
main(args.text, args.output_path, args.model, args.max_steps)