mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
MusicGen (#1020)
* Add MusicGen model * add benchmarks * change to from_pretrained * symlinks * add readme and requirements * fix readme * readme
This commit is contained in:
parent
4360e7ccec
commit
d72fdeb4ee
@ -33,13 +33,14 @@ An example using the model:
|
||||
|
||||
```python
|
||||
import mlx.core as mx
|
||||
from utils import load, load_audio, save_audio
|
||||
from encodec import EncodecModel
|
||||
from utils import load_audio, save_audio
|
||||
|
||||
# Load the 48 KHz model and preprocessor.
|
||||
model, processor = load("mlx-community/encodec-48khz-float32")
|
||||
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
|
||||
|
||||
# Load an audio file
|
||||
audio = load_audio("path/to/aduio", model.sampling_rate, model.channels)
|
||||
audio = load_audio("path/to/audio", model.sampling_rate, model.channels)
|
||||
|
||||
# Preprocess the audio (this can also be a list of arrays for batched
|
||||
# processing).
|
||||
|
@ -3,9 +3,10 @@
|
||||
import time
|
||||
|
||||
import mlx.core as mx
|
||||
from utils import load
|
||||
|
||||
model, processor = load("mlx-community/encodec-48khz-float32")
|
||||
from encodec import EncodecModel
|
||||
|
||||
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
|
||||
|
||||
audio = mx.random.uniform(shape=(288000, 2))
|
||||
feats, mask = processor(audio)
|
||||
|
@ -10,7 +10,6 @@ from typing import Any, Dict, Union
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
import encodec
|
||||
|
||||
|
@ -1,7 +1,10 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import functools
|
||||
import json
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
@ -669,3 +672,70 @@ class EncodecModel(nn.Module):
|
||||
if padding_mask is not None and padding_mask.shape[1] < audio_values.shape[1]:
|
||||
audio_values = audio_values[:, : padding_mask.shape[1]]
|
||||
return audio_values
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, path_or_repo: str):
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
path = Path(path_or_repo)
|
||||
if not path.exists():
|
||||
path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_repo,
|
||||
allow_patterns=["*.json", "*.safetensors", "*.model"],
|
||||
)
|
||||
)
|
||||
|
||||
with open(path / "config.json", "r") as f:
|
||||
config = SimpleNamespace(**json.load(f))
|
||||
|
||||
model = EncodecModel(config)
|
||||
model.load_weights(str(path / "model.safetensors"))
|
||||
processor = functools.partial(
|
||||
preprocess_audio,
|
||||
sampling_rate=config.sampling_rate,
|
||||
chunk_length=model.chunk_length,
|
||||
chunk_stride=model.chunk_stride,
|
||||
)
|
||||
mx.eval(model)
|
||||
return model, processor
|
||||
|
||||
|
||||
def preprocess_audio(
|
||||
raw_audio: Union[mx.array, List[mx.array]],
|
||||
sampling_rate: int = 24000,
|
||||
chunk_length: Optional[int] = None,
|
||||
chunk_stride: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
Prepare inputs for the EnCodec model.
|
||||
|
||||
Args:
|
||||
raw_audio (mx.array or List[mx.array]): The sequence or batch of
|
||||
sequences to be processed.
|
||||
sampling_rate (int): The sampling rate at which the audio waveform
|
||||
should be digitalized.
|
||||
chunk_length (int, optional): The model's chunk length.
|
||||
chunk_stride (int, optional): The model's chunk stride.
|
||||
"""
|
||||
if not isinstance(raw_audio, list):
|
||||
raw_audio = [raw_audio]
|
||||
|
||||
raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]
|
||||
|
||||
max_length = max(array.shape[0] for array in raw_audio)
|
||||
if chunk_length is not None:
|
||||
max_length += chunk_length - (max_length % chunk_stride)
|
||||
|
||||
inputs = []
|
||||
masks = []
|
||||
for x in raw_audio:
|
||||
length = x.shape[0]
|
||||
mask = mx.ones((length,), dtype=mx.bool_)
|
||||
difference = max_length - length
|
||||
if difference > 0:
|
||||
mask = mx.pad(mask, (0, difference))
|
||||
x = mx.pad(x, ((0, difference), (0, 0)))
|
||||
inputs.append(x)
|
||||
masks.append(mask)
|
||||
return mx.stack(inputs), mx.stack(masks)
|
||||
|
@ -1,10 +1,12 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
from utils import load, load_audio, save_audio
|
||||
from utils import load_audio, save_audio
|
||||
|
||||
from encodec import EncodecModel
|
||||
|
||||
# Load the 48 KHz model and preprocessor.
|
||||
model, processor = load("mlx-community/encodec-48khz-float32")
|
||||
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
|
||||
|
||||
# Load an audio file
|
||||
audio = load_audio("/path/to/audio", model.sampling_rate, model.channels)
|
||||
|
@ -3,9 +3,10 @@
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import Audio, load_dataset
|
||||
from transformers import AutoProcessor, EncodecModel
|
||||
from utils import load, load_audio, preprocess_audio
|
||||
from transformers import AutoProcessor
|
||||
from transformers import EncodecModel as PTEncodecModel
|
||||
|
||||
from encodec import EncodecModel, preprocess_audio
|
||||
|
||||
|
||||
def compare_processors():
|
||||
@ -30,8 +31,8 @@ def compare_processors():
|
||||
|
||||
|
||||
def compare_models():
|
||||
pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz")
|
||||
mx_model, _ = load("mlx-community/encodec-48khz-float32")
|
||||
pt_model = PTEncodecModel.from_pretrained("facebook/encodec_48khz")
|
||||
mx_model, _ = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
|
||||
|
||||
np.random.seed(0)
|
||||
audio_length = 190560
|
||||
|
@ -1,16 +1,7 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import functools
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
import encodec
|
||||
|
||||
|
||||
def save_audio(file: str, audio: mx.array, sampling_rate: int):
|
||||
@ -59,71 +50,3 @@ def load_audio(file: str, sampling_rate: int, channels: int):
|
||||
|
||||
out = mx.array(np.frombuffer(out, np.int16))
|
||||
return out.reshape(-1, channels).astype(mx.float32) / 32767.0
|
||||
|
||||
|
||||
def preprocess_audio(
|
||||
raw_audio: Union[mx.array, List[mx.array]],
|
||||
sampling_rate: int = 24000,
|
||||
chunk_length: Optional[int] = None,
|
||||
chunk_stride: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
Prepare inputs for the EnCodec model.
|
||||
|
||||
Args:
|
||||
raw_audio (mx.array or List[mx.array]): The sequence or batch of
|
||||
sequences to be processed.
|
||||
sampling_rate (int): The sampling rate at which the audio waveform
|
||||
should be digitalized.
|
||||
chunk_length (int, optional): The model's chunk length.
|
||||
chunk_stride (int, optional): The model's chunk stride.
|
||||
"""
|
||||
if not isinstance(raw_audio, list):
|
||||
raw_audio = [raw_audio]
|
||||
|
||||
raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]
|
||||
|
||||
max_length = max(array.shape[0] for array in raw_audio)
|
||||
if chunk_length is not None:
|
||||
max_length += chunk_length - (max_length % chunk_stride)
|
||||
|
||||
inputs = []
|
||||
masks = []
|
||||
for x in raw_audio:
|
||||
length = x.shape[0]
|
||||
mask = mx.ones((length,), dtype=mx.bool_)
|
||||
difference = max_length - length
|
||||
if difference > 0:
|
||||
mask = mx.pad(mask, (0, difference))
|
||||
x = mx.pad(x, ((0, difference), (0, 0)))
|
||||
inputs.append(x)
|
||||
masks.append(mask)
|
||||
return mx.stack(inputs), mx.stack(masks)
|
||||
|
||||
|
||||
def load(path_or_repo):
|
||||
"""
|
||||
Load the model and audo preprocessor.
|
||||
"""
|
||||
path = Path(path_or_repo)
|
||||
if not path.exists():
|
||||
path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_repo,
|
||||
allow_patterns=["*.json", "*.safetensors", "*.model"],
|
||||
)
|
||||
)
|
||||
|
||||
with open(path / "config.json", "r") as f:
|
||||
config = SimpleNamespace(**json.load(f))
|
||||
|
||||
model = encodec.EncodecModel(config)
|
||||
model.load_weights(str(path / "model.safetensors"))
|
||||
processor = functools.partial(
|
||||
preprocess_audio,
|
||||
sampling_rate=config.sampling_rate,
|
||||
chunk_length=model.chunk_length,
|
||||
chunk_stride=model.chunk_stride,
|
||||
)
|
||||
mx.eval(model)
|
||||
return model, processor
|
||||
|
31
musicgen/README.md
Normal file
31
musicgen/README.md
Normal file
@ -0,0 +1,31 @@
|
||||
# MusicGen
|
||||
|
||||
An example of Meta's MusicGen model in MLX.[^1] MusicGen is used to generate
|
||||
music from text descriptions.
|
||||
|
||||
### Setup
|
||||
|
||||
Install the requirements:
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### Example
|
||||
|
||||
An example using the model:
|
||||
|
||||
```python
|
||||
import mlx.core as mx
|
||||
from music_gen import MusicGen
|
||||
from utils import save_audio
|
||||
|
||||
model = MusicGen.from_pretrained("facebook/musicgen-medium")
|
||||
|
||||
audio = model.generate("happy rock")
|
||||
|
||||
save_audio("out.wav", audio, model.sampling_rate)
|
||||
```
|
||||
|
||||
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2306.05284) and
|
||||
[code](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md) for more details.
|
28
musicgen/benchmarks/bench_mx.py
Normal file
28
musicgen/benchmarks/bench_mx.py
Normal file
@ -0,0 +1,28 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
cur_path = Path(__file__).parents[1].resolve()
|
||||
sys.path.append(str(cur_path))
|
||||
|
||||
from musicgen import MusicGen
|
||||
|
||||
text = "folk ballad"
|
||||
model = MusicGen.from_pretrained("facebook/musicgen-medium")
|
||||
|
||||
max_steps = 100
|
||||
|
||||
audio = model.generate(text, max_steps=10)
|
||||
mx.eval(audio)
|
||||
|
||||
tic = time.time()
|
||||
audio = model.generate(text, max_steps=max_steps)
|
||||
mx.eval(audio)
|
||||
toc = time.time()
|
||||
|
||||
ms = 1000 * (toc - tic) / max_steps
|
||||
print(f"Time (ms) per step: {ms:.3f}")
|
31
musicgen/benchmarks/bench_pt.py
Normal file
31
musicgen/benchmarks/bench_pt.py
Normal file
@ -0,0 +1,31 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import time
|
||||
|
||||
import torch
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
|
||||
model_name = "facebook/musicgen-medium"
|
||||
processor = AutoProcessor.from_pretrained(model_name)
|
||||
model = MusicgenForConditionalGeneration.from_pretrained(model_name).to("mps")
|
||||
|
||||
inputs = processor(
|
||||
text=["folk ballad"],
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs["input_ids"] = inputs["input_ids"].to("mps")
|
||||
inputs["attention_mask"] = inputs["attention_mask"].to("mps")
|
||||
|
||||
# warmup
|
||||
audio_values = model.generate(**inputs, max_new_tokens=10)
|
||||
torch.mps.synchronize()
|
||||
|
||||
max_steps = 100
|
||||
tic = time.time()
|
||||
audio_values = model.generate(**inputs, max_new_tokens=max_steps)
|
||||
torch.mps.synchronize()
|
||||
toc = time.time()
|
||||
|
||||
ms = 1000 * (toc - tic) / max_steps
|
||||
print(f"Time (ms) per step: {ms:.3f}")
|
1
musicgen/encodec.py
Symbolic link
1
musicgen/encodec.py
Symbolic link
@ -0,0 +1 @@
|
||||
../encodec/encodec.py
|
23
musicgen/generate.py
Normal file
23
musicgen/generate.py
Normal file
@ -0,0 +1,23 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
|
||||
from utils import save_audio
|
||||
|
||||
from musicgen import MusicGen
|
||||
|
||||
|
||||
def main(text: str, output_path: str, model_name: str, max_steps: int):
|
||||
model = MusicGen.from_pretrained(model_name)
|
||||
audio = model.generate(text, max_steps=max_steps)
|
||||
save_audio(output_path, audio, model.sampling_rate)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", required=False, default="facebook/musicgen-medium")
|
||||
parser.add_argument("--text", required=False, default="happy rock")
|
||||
parser.add_argument("--output-path", required=False, default="0.wav")
|
||||
parser.add_argument("--max-steps", required=False, default=500, type=int)
|
||||
args = parser.parse_args()
|
||||
main(args.text, args.output_path, args.model, args.max_steps)
|
358
musicgen/musicgen.py
Normal file
358
musicgen/musicgen.py
Normal file
@ -0,0 +1,358 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import json
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from tqdm import tqdm
|
||||
|
||||
from encodec import EncodecModel
|
||||
from t5 import T5
|
||||
|
||||
|
||||
class TextConditioner(nn.Module):
|
||||
def __init__(self, t5_name, input_dim, output_dim):
|
||||
super().__init__()
|
||||
self._t5, self.tokenizer = T5.from_pretrained(t5_name)
|
||||
self.output_proj = nn.Linear(input_dim, output_dim)
|
||||
|
||||
def __call__(self, text):
|
||||
x = self.tokenizer.encode(text)
|
||||
x = self._t5.encode(x)
|
||||
return self.output_proj(x)
|
||||
|
||||
|
||||
class KVCache:
|
||||
def __init__(self, head_dim, n_kv_heads):
|
||||
self.n_kv_heads = n_kv_heads
|
||||
if isinstance(head_dim, int):
|
||||
self.k_head_dim = self.v_head_dim = head_dim
|
||||
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
|
||||
self.k_head_dim, self.v_head_dim = head_dim
|
||||
else:
|
||||
raise ValueError("head_dim must be an int or a tuple of two ints")
|
||||
self.keys = None
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.step = 256
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
prev = self.offset
|
||||
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
|
||||
B = keys.shape[0]
|
||||
n_steps = (self.step + keys.shape[2] - 1) // self.step
|
||||
k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
|
||||
v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
|
||||
new_k = mx.zeros(k_shape, keys.dtype)
|
||||
new_v = mx.zeros(v_shape, values.dtype)
|
||||
if self.keys is not None:
|
||||
if prev % self.step != 0:
|
||||
self.keys = self.keys[..., :prev, :]
|
||||
self.values = self.values[..., :prev, :]
|
||||
self.keys = mx.concatenate([self.keys, new_k], axis=2)
|
||||
self.values = mx.concatenate([self.values, new_v], axis=2)
|
||||
else:
|
||||
self.keys, self.values = new_k, new_v
|
||||
|
||||
self.offset += keys.shape[2]
|
||||
self.keys[..., prev : self.offset, :] = keys
|
||||
self.values[..., prev : self.offset, :] = values
|
||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self.keys, self.values
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, dim, n_heads):
|
||||
super().__init__()
|
||||
|
||||
self.n_heads = n_heads
|
||||
|
||||
head_dim = dim // n_heads
|
||||
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Linear(dim, dim, bias=False)
|
||||
self.k_proj = nn.Linear(dim, dim, bias=False)
|
||||
self.v_proj = nn.Linear(dim, dim, bias=False)
|
||||
self.out_proj = nn.Linear(dim, dim, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
queries: mx.array,
|
||||
keys: mx.array,
|
||||
values: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
B, L_q, D = queries.shape
|
||||
L_k = keys.shape[1]
|
||||
|
||||
queries, keys, values = (
|
||||
self.q_proj(queries),
|
||||
self.k_proj(keys),
|
||||
self.v_proj(values),
|
||||
)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L_q, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L_k, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L_k, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
keys, values = cache.update_and_fetch(keys, values)
|
||||
|
||||
output = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=self.scale, mask=mask
|
||||
)
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L_q, -1)
|
||||
return self.out_proj(output)
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.num_attention_heads = config.decoder.num_attention_heads
|
||||
self.hidden_size = config.decoder.hidden_size
|
||||
self.self_attn = MultiHeadAttention(self.hidden_size, self.num_attention_heads)
|
||||
self.cross_attn = MultiHeadAttention(self.hidden_size, self.num_attention_heads)
|
||||
self.linear1 = nn.Linear(self.hidden_size, config.decoder.ffn_dim, bias=False)
|
||||
self.linear2 = nn.Linear(config.decoder.ffn_dim, self.hidden_size, bias=False)
|
||||
|
||||
self.norm1 = nn.LayerNorm(self.hidden_size, eps=1e-5)
|
||||
self.norm_cross = nn.LayerNorm(self.hidden_size, eps=1e-5)
|
||||
self.norm2 = nn.LayerNorm(self.hidden_size, eps=1e-5)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
conditioning: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
) -> mx.array:
|
||||
xn = self.norm1(x)
|
||||
x += self.self_attn(xn, xn, xn, mask, cache)
|
||||
xn = self.norm_cross(x)
|
||||
x += self.cross_attn(xn, conditioning, conditioning, mask)
|
||||
xn = self.norm2(x)
|
||||
x += self.linear2(nn.gelu(self.linear1(xn)))
|
||||
return x
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def top_k_sampling(
|
||||
logits: mx.array, top_k: float, temperature: float, axis: int = -1
|
||||
) -> mx.array:
|
||||
"""
|
||||
Apply top-k sampling to logits.
|
||||
|
||||
Args:
|
||||
logits: The logits from the model's output.
|
||||
top_k: Sample from the top k logits.
|
||||
temperature: Temperature parameter for softmax distribution reshaping.
|
||||
axis: Axis along which to sample.
|
||||
Returns:
|
||||
token selected based on the top-k criterion.
|
||||
"""
|
||||
# referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
|
||||
probs = mx.softmax(logits * (1 / temperature), axis=axis)
|
||||
|
||||
# sort probs in ascending order
|
||||
sorted_indices = mx.argsort(probs, axis=axis)
|
||||
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=axis)
|
||||
prob_threshold = mx.take(sorted_probs, mx.array(-top_k), axis=axis)
|
||||
|
||||
# select the top K tokens in probability
|
||||
top_probs = mx.where(
|
||||
sorted_probs > prob_threshold,
|
||||
sorted_probs,
|
||||
0,
|
||||
)
|
||||
|
||||
sorted_token = mx.random.categorical(mx.log(top_probs), axis=axis)
|
||||
token = mx.take_along_axis(
|
||||
sorted_indices, mx.expand_dims(sorted_token, axis), axis=axis
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def create_sin_embedding(positions: mx.array, dim: int, max_period: float = 10000):
|
||||
assert dim % 2 == 0
|
||||
half_dim = dim // 2
|
||||
adim = mx.arange(half_dim).reshape(1, 1, -1)
|
||||
phase = positions / (max_period ** (adim / (half_dim - 1)))
|
||||
return mx.concatenate([mx.cos(phase), mx.sin(phase)], axis=-1)
|
||||
|
||||
|
||||
class MusicGen(nn.Module):
|
||||
def __init__(self, config):
|
||||
self.num_codebooks = config.decoder.num_codebooks
|
||||
self.codebook_size = config.audio_encoder.codebook_size
|
||||
self.bos_token_id = config.decoder.bos_token_id
|
||||
self.hidden_size = config.decoder.hidden_size
|
||||
self.num_attention_heads = config.decoder.num_attention_heads
|
||||
self.sampling_rate = config.audio_encoder.sampling_rate
|
||||
|
||||
self.text_conditioner = TextConditioner(
|
||||
config.text_encoder._name_or_path,
|
||||
config.text_encoder.d_model,
|
||||
self.hidden_size,
|
||||
)
|
||||
self.emb = [
|
||||
nn.Embedding(self.codebook_size + 1, self.hidden_size)
|
||||
for _ in range(self.num_codebooks)
|
||||
]
|
||||
self.layers = [
|
||||
TransformerBlock(config) for _ in range(config.decoder.num_hidden_layers)
|
||||
]
|
||||
self.out_norm = nn.LayerNorm(self.hidden_size, eps=1e-5)
|
||||
self.linears = [
|
||||
nn.Linear(self.hidden_size, self.codebook_size, bias=False)
|
||||
for _ in range(self.num_codebooks)
|
||||
]
|
||||
encodec_name = config.audio_encoder._name_or_path.split("/")[-1]
|
||||
encodec_name = encodec_name.replace("_", "-")
|
||||
self._audio_decoder, _ = EncodecModel.from_pretrained(
|
||||
f"mlx-community/{encodec_name}-float32"
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
audio_tokens: mx.array,
|
||||
conditioning: mx.array,
|
||||
cache: list[KVCache] = None,
|
||||
):
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
x = sum([self.emb[k](audio_tokens[..., k]) for k in range(self.num_codebooks)])
|
||||
|
||||
offset = cache[0].offset if cache[0] is not None else 0
|
||||
pos_emb = create_sin_embedding(offset, self.hidden_size)
|
||||
x += pos_emb.astype(x.dtype)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
x = layer(x, conditioning, cache=c)
|
||||
|
||||
x = self.out_norm(x)
|
||||
x = mx.stack([self.linears[k](x) for k in range(self.num_codebooks)], axis=-1)
|
||||
return x
|
||||
|
||||
def generate(
|
||||
self,
|
||||
text: str,
|
||||
max_steps: int = 200,
|
||||
top_k: int = 250,
|
||||
temp: float = 1.0,
|
||||
guidance_coef: float = 3.0,
|
||||
) -> mx.array:
|
||||
"""
|
||||
Generates a waveform conditioned on `text`.
|
||||
|
||||
Args:
|
||||
text (str): The text to condition generation on.
|
||||
max_steps (int): Max steps to generate.
|
||||
top_k (int): Top k used in sampling.
|
||||
temp (float): Sampling softmax temperature.
|
||||
guidance_coef (float): Classifier free guidance coefficent.
|
||||
Used to combine conditional and unconditional logits.
|
||||
|
||||
Returns:
|
||||
An mx.array of audio samples of shape ``(num_samples,)``.
|
||||
"""
|
||||
# Assuming no audio prompt we start with all bos token for the codebooks
|
||||
audio_shape = (1, max_steps + 1, self.num_codebooks)
|
||||
audio_seq = mx.full(audio_shape, self.bos_token_id)
|
||||
|
||||
text_tokens = self.text_conditioner(text)
|
||||
# Compute conditional and unconditional logits in one batch
|
||||
text_tokens = mx.concatenate([text_tokens, mx.zeros_like(text_tokens)], axis=0)
|
||||
|
||||
head_dim = self.hidden_size // self.num_attention_heads
|
||||
cache = [
|
||||
KVCache(head_dim, self.num_attention_heads) for _ in range(len(self.layers))
|
||||
]
|
||||
for offset in tqdm(range(max_steps)):
|
||||
audio_input = mx.tile(audio_seq[:, offset : offset + 1], [2, 1, 1])
|
||||
audio_logits = self(audio_input, text_tokens, cache)
|
||||
cond_logits, uncond_logits = audio_logits[:1], audio_logits[1:2]
|
||||
audio_logits = uncond_logits + (cond_logits - uncond_logits) * guidance_coef
|
||||
audio_tokens = top_k_sampling(audio_logits, top_k, temp, axis=-2)
|
||||
# "delay" pattern
|
||||
audio_tokens[..., offset + 1 :] = self.bos_token_id
|
||||
audio_tokens[..., : -max_steps + offset] = self.bos_token_id
|
||||
audio_seq[:, offset + 1 : offset + 2] = audio_tokens
|
||||
mx.eval(audio_seq)
|
||||
|
||||
# Undo delay
|
||||
for i in range(self.num_codebooks):
|
||||
audio_seq[:, : -self.num_codebooks, i] = audio_seq[
|
||||
:, i : -self.num_codebooks + i, i
|
||||
]
|
||||
audio_seq = audio_seq[:, 1 : -self.num_codebooks + 1]
|
||||
|
||||
audio_seq = mx.swapaxes(audio_seq, -1, -2)[:, mx.newaxis]
|
||||
audio = self._audio_decoder.decode(audio_seq, audio_scales=[None])
|
||||
return audio[0]
|
||||
|
||||
@classmethod
|
||||
def sanitize(cls, weights):
|
||||
out_weights = {}
|
||||
for k, arr in weights.items():
|
||||
if k.startswith("transformer."):
|
||||
k = k[len("transformer.") :]
|
||||
|
||||
if "cross_attention" in k:
|
||||
k = k.replace("cross_attention", "cross_attn")
|
||||
|
||||
if "condition_provider" in k:
|
||||
k = k.replace(
|
||||
"condition_provider.conditioners.description", "text_conditioner"
|
||||
)
|
||||
|
||||
if "in_proj_weight" in k:
|
||||
dim = arr.shape[0] // 3
|
||||
name = "in_proj_weight"
|
||||
out_weights[k.replace(name, "q_proj.weight")] = arr[:dim]
|
||||
out_weights[k.replace(name, "k_proj.weight")] = arr[dim : dim * 2]
|
||||
out_weights[k.replace(name, "v_proj.weight")] = arr[dim * 2 :]
|
||||
continue
|
||||
|
||||
out_weights[k] = arr
|
||||
return out_weights
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, path_or_repo: str):
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
path = Path(path_or_repo)
|
||||
if not path.exists():
|
||||
path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_repo,
|
||||
allow_patterns=["*.json", "state_dict.bin"],
|
||||
)
|
||||
)
|
||||
|
||||
with open(path / "config.json", "r") as f:
|
||||
config = SimpleNamespace(**json.load(f))
|
||||
config.text_encoder = SimpleNamespace(**config.text_encoder)
|
||||
config.audio_encoder = SimpleNamespace(**config.audio_encoder)
|
||||
config.decoder = SimpleNamespace(**config.decoder)
|
||||
|
||||
weights = torch.load(path / "state_dict.bin", weights_only=True)["best_state"]
|
||||
weights = {k: mx.array(v) for k, v in weights.items()}
|
||||
weights = cls.sanitize(weights)
|
||||
|
||||
model = MusicGen(config)
|
||||
model.load_weights(list(weights.items()))
|
||||
return model
|
6
musicgen/requirements.txt
Normal file
6
musicgen/requirements.txt
Normal file
@ -0,0 +1,6 @@
|
||||
mlx>=0.18
|
||||
numpy
|
||||
huggingface_hub
|
||||
torch
|
||||
transformers
|
||||
scipy
|
1
musicgen/t5.py
Symbolic link
1
musicgen/t5.py
Symbolic link
@ -0,0 +1 @@
|
||||
../t5/t5.py
|
15
musicgen/utils.py
Normal file
15
musicgen/utils.py
Normal file
@ -0,0 +1,15 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
def save_audio(file: str, audio: mx.array, sampling_rate: int):
|
||||
"""
|
||||
Save audio to a wave (.wav) file.
|
||||
"""
|
||||
from scipy.io.wavfile import write
|
||||
|
||||
audio = mx.clip(audio, -1, 1)
|
||||
audio = (audio * 32767).astype(mx.int16)
|
||||
write(file, sampling_rate, np.array(audio))
|
40
t5/README.md
40
t5/README.md
@ -7,31 +7,6 @@ tasks by prepending task-specific prefixes to the input, e.g.:
|
||||
|
||||
This example also supports the FLAN-T5 models variants.[^2]
|
||||
|
||||
## Setup
|
||||
|
||||
Download and convert the model:
|
||||
|
||||
```sh
|
||||
python convert.py --model <model>
|
||||
```
|
||||
|
||||
This will make the `<model>.npz` file which MLX can read.
|
||||
|
||||
The `<model>` can be any of the following:
|
||||
|
||||
| Model Name | Model Size |
|
||||
| ---------- | ----------
|
||||
| t5-small | 60 million |
|
||||
| t5-base | 220 million |
|
||||
| t5-large | 770 million |
|
||||
| t5-3b | 3 billion |
|
||||
| t5-11b | 11 billion |
|
||||
|
||||
The FLAN variants can be specified with `google/flan-t5-small`,
|
||||
`google/flan-t5-base`, etc. See the [Hugging Face
|
||||
page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a
|
||||
complete list of models.
|
||||
|
||||
## Generate
|
||||
|
||||
Generate text with:
|
||||
@ -48,6 +23,21 @@ To see a list of options run:
|
||||
python t5.py --help
|
||||
```
|
||||
|
||||
The `<model>` can be any of the following:
|
||||
|
||||
| Model Name | Model Size |
|
||||
| ---------- | ----------
|
||||
| t5-small | 60 million |
|
||||
| t5-base | 220 million |
|
||||
| t5-large | 770 million |
|
||||
| t5-3b | 3 billion |
|
||||
| t5-11b | 11 billion |
|
||||
|
||||
The FLAN variants can be specified with `google/flan-t5-small`,
|
||||
`google/flan-t5-base`, etc. See the [Hugging Face
|
||||
page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a
|
||||
complete list of models.
|
||||
|
||||
[^1]: For more information on T5 see the [original paper](https://arxiv.org/abs/1910.10683)
|
||||
or the [Hugging Face page](https://huggingface.co/docs/transformers/model_doc/t5).
|
||||
[^2]: For more information on FLAN-T5 see the [original paper](https://arxiv.org/abs/2210.11416).
|
||||
|
@ -1,75 +0,0 @@
|
||||
import numpy as np
|
||||
from transformers import T5ForConditionalGeneration
|
||||
|
||||
SHARED_REPLACEMENT_PATTERNS = [
|
||||
(".block.", ".layers."),
|
||||
(".k.", ".key_proj."),
|
||||
(".o.", ".out_proj."),
|
||||
(".q.", ".query_proj."),
|
||||
(".v.", ".value_proj."),
|
||||
("shared.", "wte."),
|
||||
("lm_head.", "lm_head.linear."),
|
||||
(".layer.0.layer_norm.", ".ln1."),
|
||||
(".layer.1.layer_norm.", ".ln2."),
|
||||
(".layer.2.layer_norm.", ".ln3."),
|
||||
(".final_layer_norm.", ".ln."),
|
||||
(
|
||||
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
|
||||
"relative_attention_bias.embeddings.",
|
||||
),
|
||||
]
|
||||
|
||||
ENCODER_REPLACEMENT_PATTERNS = [
|
||||
(".layer.0.SelfAttention.", ".attention."),
|
||||
(".layer.1.DenseReluDense.", ".dense."),
|
||||
]
|
||||
|
||||
DECODER_REPLACEMENT_PATTERNS = [
|
||||
(".layer.0.SelfAttention.", ".self_attention."),
|
||||
(".layer.1.EncDecAttention.", ".cross_attention."),
|
||||
(".layer.2.DenseReluDense.", ".dense."),
|
||||
]
|
||||
|
||||
|
||||
def replace_key(key: str) -> str:
|
||||
for old, new in SHARED_REPLACEMENT_PATTERNS:
|
||||
key = key.replace(old, new)
|
||||
if key.startswith("encoder."):
|
||||
for old, new in ENCODER_REPLACEMENT_PATTERNS:
|
||||
key = key.replace(old, new)
|
||||
elif key.startswith("decoder."):
|
||||
for old, new in DECODER_REPLACEMENT_PATTERNS:
|
||||
key = key.replace(old, new)
|
||||
return key
|
||||
|
||||
|
||||
def convert(model_name, dtype):
|
||||
dtype = getattr(np, dtype)
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
|
||||
weights = {
|
||||
replace_key(k): v.numpy().astype(dtype) for k, v in model.state_dict().items()
|
||||
}
|
||||
file_name = model_name.replace("/", "-")
|
||||
print(f"Saving weights to {file_name}.npz")
|
||||
np.savez(f"{file_name}.npz", **weights)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Convert T5 weights to MLX")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
help="Name of the T5 model.",
|
||||
default="t5-small",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
help="The model data type.",
|
||||
type=str,
|
||||
choices=["float16", "float32"],
|
||||
default="float32",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert(args.model, args.dtype)
|
181
t5/t5.py
181
t5/t5.py
@ -1,12 +1,45 @@
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from time import perf_counter_ns
|
||||
from types import SimpleNamespace
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from mlx.utils import tree_map, tree_unflatten
|
||||
from transformers import AutoTokenizer, T5Config
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
def __init__(self, config, model_name):
|
||||
self._decoder_start_id = config.decoder_start_token_id
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name,
|
||||
legacy=False,
|
||||
model_max_length=getattr(config, "n_positions", 512),
|
||||
)
|
||||
|
||||
@property
|
||||
def eos_id(self) -> int:
|
||||
return self._tokenizer.eos_token_id
|
||||
|
||||
@property
|
||||
def decoder_start_id(self) -> int:
|
||||
return self._decoder_start_id
|
||||
|
||||
def encode(self, s: str) -> mx.array:
|
||||
return mx.array(
|
||||
self._tokenizer(
|
||||
s,
|
||||
return_tensors="np",
|
||||
return_attention_mask=False,
|
||||
)["input_ids"]
|
||||
)
|
||||
|
||||
def decode(self, t: List[int], with_sep: bool = True) -> str:
|
||||
tokens = self._tokenizer.convert_ids_to_tokens(t)
|
||||
return "".join(t.replace("▁", " " if with_sep else "") for t in tokens)
|
||||
|
||||
|
||||
def _relative_position_bucket(
|
||||
@ -60,10 +93,10 @@ def _relative_position_bucket(
|
||||
|
||||
|
||||
class RelativePositionBias(nn.Module):
|
||||
def __init__(self, config: T5Config, bidirectional: bool):
|
||||
def __init__(self, config, bidirectional: bool):
|
||||
self.bidirectional = bidirectional
|
||||
self.num_buckets = config.relative_attention_num_buckets
|
||||
self.max_distance = config.relative_attention_max_distance
|
||||
self.max_distance = getattr(config, "relative_attention_max_distance", 128)
|
||||
self.n_heads = config.num_heads
|
||||
self.embeddings = nn.Embedding(
|
||||
config.relative_attention_num_buckets, config.num_heads
|
||||
@ -91,7 +124,7 @@ class RelativePositionBias(nn.Module):
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
inner_dim = config.d_kv * config.num_heads
|
||||
self.num_heads = config.num_heads
|
||||
@ -135,17 +168,21 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
|
||||
class DenseActivation(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
mlp_dims = config.d_ff or config.d_model * 4
|
||||
self.gated = config.feed_forward_proj.startswith("gated")
|
||||
self.gated = hasattr(config, "feed_forward_proj")
|
||||
activation = (
|
||||
"relu"
|
||||
if not self.gated
|
||||
else config.feed_forward_proj.removeprefix("gated-")
|
||||
)
|
||||
if self.gated:
|
||||
self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||
self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||
else:
|
||||
self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||
self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
|
||||
activation = config.feed_forward_proj.removeprefix("gated-")
|
||||
if activation == "relu":
|
||||
self.act = nn.relu
|
||||
elif activation == "gelu":
|
||||
@ -166,7 +203,7 @@ class DenseActivation(nn.Module):
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.attention = MultiHeadAttention(config)
|
||||
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
@ -184,7 +221,7 @@ class TransformerEncoderLayer(nn.Module):
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.layers = [
|
||||
TransformerEncoderLayer(config) for i in range(config.num_layers)
|
||||
@ -200,7 +237,7 @@ class TransformerEncoder(nn.Module):
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.self_attention = MultiHeadAttention(config)
|
||||
self.cross_attention = MultiHeadAttention(config)
|
||||
@ -233,7 +270,7 @@ class TransformerDecoderLayer(nn.Module):
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
n_layers = getattr(config, "num_decoder_layers", config.num_layers)
|
||||
self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
|
||||
@ -262,7 +299,7 @@ class TransformerDecoder(nn.Module):
|
||||
|
||||
|
||||
class OutputHead(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
def __init__(self, config):
|
||||
self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(self, inputs):
|
||||
@ -270,11 +307,11 @@ class OutputHead(nn.Module):
|
||||
|
||||
|
||||
class T5(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
def __init__(self, config):
|
||||
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
||||
self.encoder = TransformerEncoder(config)
|
||||
self.decoder = TransformerDecoder(config)
|
||||
self.tie_word_embeddings = config.tie_word_embeddings
|
||||
self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
|
||||
if not self.tie_word_embeddings:
|
||||
self.lm_head = OutputHead(config)
|
||||
self.model_dim = config.d_model
|
||||
@ -313,36 +350,82 @@ class T5(nn.Module):
|
||||
):
|
||||
return self.decode(decoder_inputs, self.encode(inputs))[0]
|
||||
|
||||
@classmethod
|
||||
def sanitize(cls, weights):
|
||||
shared_replacement_patterns = [
|
||||
(".block.", ".layers."),
|
||||
(".k.", ".key_proj."),
|
||||
(".o.", ".out_proj."),
|
||||
(".q.", ".query_proj."),
|
||||
(".v.", ".value_proj."),
|
||||
("shared.", "wte."),
|
||||
("lm_head.", "lm_head.linear."),
|
||||
(".layer.0.layer_norm.", ".ln1."),
|
||||
(".layer.1.layer_norm.", ".ln2."),
|
||||
(".layer.2.layer_norm.", ".ln3."),
|
||||
(".final_layer_norm.", ".ln."),
|
||||
(
|
||||
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
|
||||
"relative_attention_bias.embeddings.",
|
||||
),
|
||||
]
|
||||
|
||||
class Tokenizer:
|
||||
def __init__(self, config: T5Config):
|
||||
self._decoder_start_id = config.decoder_start_token_id
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.model,
|
||||
legacy=False,
|
||||
model_max_length=getattr(config, "n_positions", 512),
|
||||
encoder_replacement_patterns = [
|
||||
(".layer.0.SelfAttention.", ".attention."),
|
||||
(".layer.1.DenseReluDense.", ".dense."),
|
||||
]
|
||||
|
||||
decoder_replacement_patterns = [
|
||||
(".layer.0.SelfAttention.", ".self_attention."),
|
||||
(".layer.1.EncDecAttention.", ".cross_attention."),
|
||||
(".layer.2.DenseReluDense.", ".dense."),
|
||||
]
|
||||
|
||||
ignored_keys = [
|
||||
"decoder.layers.0.cross_attention.relative_attention_bias.weight"
|
||||
]
|
||||
|
||||
def replace_key(key: str) -> str:
|
||||
for old, new in shared_replacement_patterns:
|
||||
key = key.replace(old, new)
|
||||
if key.startswith("encoder."):
|
||||
for old, new in encoder_replacement_patterns:
|
||||
key = key.replace(old, new)
|
||||
elif key.startswith("decoder."):
|
||||
for old, new in decoder_replacement_patterns:
|
||||
key = key.replace(old, new)
|
||||
return key
|
||||
|
||||
weights = {replace_key(k): v for k, v in weights.items()}
|
||||
for key in ignored_keys:
|
||||
if key in weights:
|
||||
del weights[key]
|
||||
return weights
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, path_or_repo: str, dtype: mx.Dtype = mx.bfloat16
|
||||
) -> tuple["T5", Tokenizer]:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
path = Path(path_or_repo)
|
||||
if not path.exists():
|
||||
path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_repo,
|
||||
allow_patterns=["*.json", "*.safetensors", "*.model"],
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def eos_id(self) -> int:
|
||||
return self._tokenizer.eos_token_id
|
||||
with open(path / "config.json", "r") as f:
|
||||
config = SimpleNamespace(**json.load(f))
|
||||
|
||||
@property
|
||||
def decoder_start_id(self) -> int:
|
||||
return self._decoder_start_id
|
||||
|
||||
def encode(self, s: str) -> mx.array:
|
||||
return mx.array(
|
||||
self._tokenizer(
|
||||
s,
|
||||
return_tensors="np",
|
||||
return_attention_mask=False,
|
||||
)["input_ids"]
|
||||
)
|
||||
|
||||
def decode(self, t: List[int], with_sep: bool = True) -> str:
|
||||
tokens = self._tokenizer.convert_ids_to_tokens(t)
|
||||
return "".join(t.replace("▁", " " if with_sep else "") for t in tokens)
|
||||
model = T5(config)
|
||||
weights = mx.load(str(path / "model.safetensors"))
|
||||
weights = cls.sanitize(weights)
|
||||
weights = {k: v.astype(dtype) for k, v in weights.items()}
|
||||
model.load_weights(list(weights.items()))
|
||||
return model, Tokenizer(config, "t5-base")
|
||||
|
||||
|
||||
def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float] = 0.0):
|
||||
@ -363,19 +446,6 @@ def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float]
|
||||
yield y.squeeze()
|
||||
|
||||
|
||||
def load_model(model_name: str, dtype: str = "float16"):
|
||||
config = T5Config.from_pretrained(args.model)
|
||||
dtype = getattr(mx, dtype)
|
||||
model = T5(config)
|
||||
file_name = model_name.replace("/", "-")
|
||||
weights = mx.load(f"{file_name}.npz")
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
weights = tree_map(lambda p: p.astype(dtype), weights)
|
||||
model.update(weights)
|
||||
mx.eval(model.parameters())
|
||||
return model, Tokenizer(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="T5 Inference script")
|
||||
parser.add_argument(
|
||||
@ -421,7 +491,8 @@ if __name__ == "__main__":
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
model, tokenizer = load_model(args.model, args.dtype)
|
||||
dtype = getattr(mx, args.dtype)
|
||||
model, tokenizer = T5.from_pretrained(args.model, dtype)
|
||||
|
||||
if args.encode_only:
|
||||
print("[INFO] Encoding with T5...", flush=True)
|
||||
|
Loading…
Reference in New Issue
Block a user