add readme and requirements

This commit is contained in:
Alex Barron 2024-10-11 09:46:51 -07:00
parent 6a37b5106a
commit 90779df767
4 changed files with 61 additions and 8 deletions

49
musicgen/README.md Normal file
View File

@ -0,0 +1,49 @@
# MusicGen
An example of Meta's MusicGen model in MLX.[^1] MusicGen is used to generate
music from text descriptions.
### Setup
Install the requirements:
```
pip install -r requirements.txt
```
Optionally install FFmpeg and SciPy for loading and saving audio files,
respectively.
Install [FFmpeg](https://ffmpeg.org/):
```
# on macOS using Homebrew (https://brew.sh/)
brew install ffmpeg
```
Install SciPy:
```
pip install scipy
```
### Example
An example using the model:
```python
import mlx.core as mx
from music_gen import MusicGen
from utils import save_audio
# Load the pretrained 32 kHz model and preprocessor.
model, processor = MusicGen.from_pretrained("facebook/musicgen-medium")
audio = model.generate("happy rock")
# Save the audio as a WAV file.
save_audio("out.wav", audio, model.sampling_rate)
```
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2306.05284) and
[code](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md) for more details.

View File

@ -154,11 +154,11 @@ def top_k_sampling(
Args:
logits: The logits from the model's output.
top_k: The cumulative probability threshold for top-p filtering.
top_k: Sample from the top k logits.
temperature: Temperature parameter for softmax distribution reshaping.
axis: Axis along which to sample
axis: Axis along which to sample.
Returns:
token selected based on the top-p criterion.
token selected based on the top-k criterion.
"""
# referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
probs = mx.softmax(logits * (1 / temperature), axis=axis)
@ -177,7 +177,7 @@ def top_k_sampling(
sorted_token = mx.random.categorical(mx.log(top_probs), axis=axis)
token = mx.take_along_axis(
sorted_indices, sorted_token[:, :, mx.newaxis], axis=axis
sorted_indices, mx.expand_dims(sorted_token, axis), axis=axis
)
return token
@ -236,9 +236,8 @@ class MusicGen(nn.Module):
x = sum([self.emb[k](audio_tokens[..., k]) for k in range(self.num_codebooks)])
offset = cache[0].offset if cache[0] is not None else 0
offset = mx.full((audio_tokens.shape[0], 1, 1), offset)
pos_emb = create_sin_embedding(offset, self.hidden_size)
x += pos_emb
x += pos_emb.astype(x.dtype)
for layer, c in zip(self.layers, cache):
x = layer(x, conditioning, cache=c)
@ -351,7 +350,7 @@ class MusicGen(nn.Module):
config.decoder = SimpleNamespace(**config.decoder)
weights = torch.load(path / "state_dict.bin", weights_only=True)["best_state"]
weights = {k: mx.array(v.numpy()) for k, v in weights.items()}
weights = {k: mx.array(v) for k, v in weights.items()}
weights = cls.sanitize(weights)
model = MusicGen(config)

View File

@ -0,0 +1,6 @@
mlx>=0.18
numpy
huggingface_hub
torch
transformers
scipy

View File

@ -416,7 +416,6 @@ class T5(nn.Module):
allow_patterns=["*.json", "*.safetensors", "*.model"],
)
)
print(path)
with open(path / "config.json", "r") as f:
config = SimpleNamespace(**json.load(f))