mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 11:54:37 +08:00
add readme and requirements
This commit is contained in:
parent
6a37b5106a
commit
90779df767
49
musicgen/README.md
Normal file
49
musicgen/README.md
Normal file
@ -0,0 +1,49 @@
|
||||
# MusicGen
|
||||
|
||||
An example of Meta's MusicGen model in MLX.[^1] MusicGen is used to generate
|
||||
music from text descriptions.
|
||||
|
||||
### Setup
|
||||
|
||||
Install the requirements:
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Optionally install FFmpeg and SciPy for loading and saving audio files,
|
||||
respectively.
|
||||
|
||||
Install [FFmpeg](https://ffmpeg.org/):
|
||||
|
||||
```
|
||||
# on macOS using Homebrew (https://brew.sh/)
|
||||
brew install ffmpeg
|
||||
```
|
||||
|
||||
Install SciPy:
|
||||
|
||||
```
|
||||
pip install scipy
|
||||
```
|
||||
|
||||
### Example
|
||||
|
||||
An example using the model:
|
||||
|
||||
```python
|
||||
import mlx.core as mx
|
||||
from music_gen import MusicGen
|
||||
from utils import save_audio
|
||||
|
||||
# Load the 48 KHz model and preprocessor.
|
||||
model, processor = MusicGen.from_pretrained("facebook/musicgen-medium")
|
||||
|
||||
audio = model.generate("happy rock")
|
||||
|
||||
# Save the audio as a wave file
|
||||
save_audio("out.wav", audio, model.sampling_rate)
|
||||
```
|
||||
|
||||
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2306.05284) and
|
||||
[code](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md) for more details.
|
@ -154,11 +154,11 @@ def top_k_sampling(
|
||||
|
||||
Args:
|
||||
logits: The logits from the model's output.
|
||||
top_k: The cumulative probability threshold for top-p filtering.
|
||||
top_k: Sample from the top k logits.
|
||||
temperature: Temperature parameter for softmax distribution reshaping.
|
||||
axis: Axis along which to sample
|
||||
axis: Axis along which to sample.
|
||||
Returns:
|
||||
token selected based on the top-p criterion.
|
||||
token selected based on the top-k criterion.
|
||||
"""
|
||||
# referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
|
||||
probs = mx.softmax(logits * (1 / temperature), axis=axis)
|
||||
@ -177,7 +177,7 @@ def top_k_sampling(
|
||||
|
||||
sorted_token = mx.random.categorical(mx.log(top_probs), axis=axis)
|
||||
token = mx.take_along_axis(
|
||||
sorted_indices, sorted_token[:, :, mx.newaxis], axis=axis
|
||||
sorted_indices, mx.expand_dims(sorted_token, axis), axis=axis
|
||||
)
|
||||
|
||||
return token
|
||||
@ -236,9 +236,8 @@ class MusicGen(nn.Module):
|
||||
x = sum([self.emb[k](audio_tokens[..., k]) for k in range(self.num_codebooks)])
|
||||
|
||||
offset = cache[0].offset if cache[0] is not None else 0
|
||||
offset = mx.full((audio_tokens.shape[0], 1, 1), offset)
|
||||
pos_emb = create_sin_embedding(offset, self.hidden_size)
|
||||
x += pos_emb
|
||||
x += pos_emb.astype(x.dtype)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
x = layer(x, conditioning, cache=c)
|
||||
@ -351,7 +350,7 @@ class MusicGen(nn.Module):
|
||||
config.decoder = SimpleNamespace(**config.decoder)
|
||||
|
||||
weights = torch.load(path / "state_dict.bin", weights_only=True)["best_state"]
|
||||
weights = {k: mx.array(v.numpy()) for k, v in weights.items()}
|
||||
weights = {k: mx.array(v) for k, v in weights.items()}
|
||||
weights = cls.sanitize(weights)
|
||||
|
||||
model = MusicGen(config)
|
||||
|
6
musicgen/requirements.txt
Normal file
6
musicgen/requirements.txt
Normal file
@ -0,0 +1,6 @@
|
||||
mlx>=0.18
|
||||
numpy
|
||||
huggingface_hub
|
||||
torch
|
||||
transformers
|
||||
scipy
|
Loading…
Reference in New Issue
Block a user