change to from_pretrained

This commit is contained in:
Alex Barron 2024-10-08 15:46:04 -07:00
parent 9432f1a643
commit 4d2ee67402
12 changed files with 231 additions and 227 deletions

View File

@ -33,11 +33,11 @@ An example using the model:
```python
import mlx.core as mx
from encodec import load
from encodec import EncodecModel
from utils import load_audio, save_audio
# Load the 48 kHz model and preprocessor.
model, processor = load("mlx-community/encodec-48khz-float32")
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
# Load an audio file
audio = load_audio("path/to/audio", model.sampling_rate, model.channels)

View File

@ -3,9 +3,10 @@
import time
import mlx.core as mx
from utils import load
model, processor = load("mlx-community/encodec-48khz-float32")
from encodec import EncodecModel
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
audio = mx.random.uniform(shape=(288000, 2))
feats, mask = processor(audio)

View File

@ -673,6 +673,33 @@ class EncodecModel(nn.Module):
audio_values = audio_values[:, : padding_mask.shape[1]]
return audio_values
@classmethod
def from_pretrained(cls, path_or_repo: str):
    """
    Load a pretrained model together with its audio preprocessor.

    Args:
        path_or_repo (str): A local directory, or a repository id that is
            fetched with ``snapshot_download`` when no such path exists.

    Returns:
        A ``(model, processor)`` tuple. ``processor`` is
        ``preprocess_audio`` with ``sampling_rate``, ``chunk_length`` and
        ``chunk_stride`` already bound from the config and the model.
    """
    from huggingface_hub import snapshot_download

    path = Path(path_or_repo)
    if not path.exists():
        # Not a local path, so treat the argument as a repository id.
        path = Path(
            snapshot_download(
                repo_id=path_or_repo,
                allow_patterns=["*.json", "*.safetensors", "*.model"],
            )
        )

    with open(path / "config.json", "r") as f:
        config = SimpleNamespace(**json.load(f))

    # Build through ``cls`` so a subclass gets an instance of its own type.
    model = cls(config)
    model.load_weights(str(path / "model.safetensors"))
    processor = functools.partial(
        preprocess_audio,
        sampling_rate=config.sampling_rate,
        chunk_length=model.chunk_length,
        chunk_stride=model.chunk_stride,
    )
    mx.eval(model)
    return model, processor
def preprocess_audio(
raw_audio: Union[mx.array, List[mx.array]],
@ -712,33 +739,3 @@ def preprocess_audio(
inputs.append(x)
masks.append(mask)
return mx.stack(inputs), mx.stack(masks)
def load(path_or_repo):
    """
    Load the model and audio preprocessor.

    ``path_or_repo`` is used as a local directory when it exists; otherwise
    it is passed to ``snapshot_download`` as a repository id. Returns a
    ``(model, processor)`` tuple.
    """
    from huggingface_hub import snapshot_download

    model_dir = Path(path_or_repo)
    if not model_dir.exists():
        downloaded = snapshot_download(
            repo_id=path_or_repo,
            allow_patterns=["*.json", "*.safetensors", "*.model"],
        )
        model_dir = Path(downloaded)

    with open(model_dir / "config.json", "r") as config_file:
        config = SimpleNamespace(**json.load(config_file))

    weights_file = str(model_dir / "model.safetensors")
    model = EncodecModel(config)
    model.load_weights(weights_file)

    # Bind the chunking parameters so callers only pass the raw audio.
    processor = functools.partial(
        preprocess_audio,
        sampling_rate=config.sampling_rate,
        chunk_length=model.chunk_length,
        chunk_stride=model.chunk_stride,
    )
    mx.eval(model)
    return model, processor

View File

@ -3,10 +3,10 @@
import mlx.core as mx
from utils import load_audio, save_audio
from encodec import load
from encodec import EncodecModel
# Load the 48 kHz model and preprocessor.
model, processor = load("mlx-community/encodec-48khz-float32")
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
# Load an audio file
audio = load_audio("/path/to/audio", model.sampling_rate, model.channels)

View File

@ -3,9 +3,10 @@
import mlx.core as mx
import numpy as np
import torch
from transformers import AutoProcessor, EncodecModel
from transformers import AutoProcessor
from transformers import EncodecModel as PTEncodecModel
from encodec import load, preprocess_audio
from encodec import EncodecModel, preprocess_audio
def compare_processors():
@ -30,8 +31,8 @@ def compare_processors():
def compare_models():
pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz")
mx_model, _ = load("mlx-community/encodec-48khz-float32")
pt_model = PTEncodecModel.from_pretrained("facebook/encodec_48khz")
mx_model, _ = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
np.random.seed(0)
audio_length = 190560

View File

@ -1,12 +1,18 @@
# Copyright © 2024 Apple Inc.
import sys
import time
from pathlib import Path
import mlx.core as mx
from utils import load
cur_path = Path(__file__).parents[1].resolve()
sys.path.append(str(cur_path))
from musicgen import MusicGen
text = "folk ballad"
model = load("facebook/musicgen-medium")
model = MusicGen.from_pretrained("facebook/musicgen-medium")
max_steps = 100

View File

@ -1,3 +1,5 @@
# Copyright © 2024 Apple Inc.
import time
import torch

View File

@ -2,11 +2,13 @@
import argparse
from utils import load, save_audio
from utils import save_audio
from musicgen import MusicGen
def main(text: str, output_path: str, model_name: str, max_steps: int):
model = load(model_name)
model = MusicGen.from_pretrained(model_name)
audio = model.generate(text, max_steps=max_steps)
save_audio(output_path, audio, model.sampling_rate)

View File

@ -1,8 +1,10 @@
# Copyright © 2024 Apple Inc.
import json
import sys
from functools import partial
from pathlib import Path
from types import SimpleNamespace
from typing import Optional
import mlx.core as mx
@ -13,14 +15,14 @@ cur_path = Path(__file__).parents[1].resolve()
sys.path.append(str(cur_path / "encodec"))
sys.path.append(str(cur_path / "t5"))
from encodec import load as load_encodec
from t5 import load_model as load_t5
from encodec import EncodecModel
from t5 import T5
class TextConditioner(nn.Module):
def __init__(self, t5_name, input_dim, output_dim):
super().__init__()
self._t5, self.tokenizer = load_t5(t5_name)
self._t5, self.tokenizer = T5.from_pretrained(t5_name)
self.output_proj = nn.Linear(input_dim, output_dim)
def __call__(self, text):
@ -222,7 +224,9 @@ class MusicGen(nn.Module):
]
encodec_name = config.audio_encoder._name_or_path.split("/")[-1]
encodec_name = encodec_name.replace("_", "-")
self._audio_decoder, _ = load_encodec(f"mlx-community/{encodec_name}-float32")
self._audio_decoder, _ = EncodecModel.from_pretrained(
f"mlx-community/{encodec_name}-float32"
)
def __call__(
self,
@ -304,3 +308,57 @@ class MusicGen(nn.Module):
audio_seq = mx.swapaxes(audio_seq, -1, -2)[:, mx.newaxis]
audio = self._audio_decoder.decode(audio_seq, audio_scales=[None])
return audio[0]
@classmethod
def sanitize(cls, weights):
    """
    Rename checkpoint parameters to the names used by this model.

    A leading ``transformer.`` is dropped, ``cross_attention`` becomes
    ``cross_attn`` and ``condition_provider.conditioners.description``
    becomes ``text_conditioner``. Every ``in_proj_weight`` entry is split
    along its first axis into three equal ``q_proj.weight``,
    ``k_proj.weight`` and ``v_proj.weight`` entries.

    Args:
        weights: Mapping from parameter name to array.

    Returns:
        A new dict with the renamed (and split) entries.
    """
    prefix = "transformer."
    fused = "in_proj_weight"
    renames = (
        ("cross_attention", "cross_attn"),
        ("condition_provider.conditioners.description", "text_conditioner"),
    )

    sanitized = {}
    for key, value in weights.items():
        if key.startswith(prefix):
            key = key[len(prefix) :]
        for old, new in renames:
            key = key.replace(old, new)

        if fused not in key:
            sanitized[key] = value
            continue

        # The fused projection stacks query, key and value along axis 0.
        third = value.shape[0] // 3
        pieces = (value[:third], value[third : third * 2], value[third * 2 :])
        for proj, piece in zip(("q_proj", "k_proj", "v_proj"), pieces):
            sanitized[key.replace(fused, f"{proj}.weight")] = piece
    return sanitized
@classmethod
def from_pretrained(cls, path_or_repo: str):
    """
    Build a model from a local directory or a repository id.

    Args:
        path_or_repo (str): A local directory, or a repository id that is
            fetched with ``snapshot_download`` when no such path exists.

    Returns:
        The model with the weights from ``state_dict.bin`` loaded.
    """
    import torch
    from huggingface_hub import snapshot_download

    path = Path(path_or_repo)
    if not path.exists():
        # Not a local path, so treat the argument as a repository id.
        path = Path(
            snapshot_download(
                repo_id=path_or_repo,
                allow_patterns=["*.json", "state_dict.bin"],
            )
        )

    with open(path / "config.json", "r") as f:
        config = SimpleNamespace(**json.load(f))
        # The nested sections are plain dicts in the JSON file; wrap them so
        # they are read with attribute access like the top-level config.
        config.text_encoder = SimpleNamespace(**config.text_encoder)
        config.audio_encoder = SimpleNamespace(**config.audio_encoder)
        config.decoder = SimpleNamespace(**config.decoder)

    # The checkpoint keeps the parameters under its "best_state" entry.
    weights = torch.load(path / "state_dict.bin", weights_only=True)["best_state"]
    weights = {k: mx.array(v.numpy()) for k, v in weights.items()}
    weights = cls.sanitize(weights)

    # Build through ``cls`` so a subclass gets an instance of its own type.
    model = cls(config)
    model.load_weights(list(weights.items()))
    return model

View File

@ -1,14 +1,7 @@
# Copyright © 2024 Apple Inc.
import json
from pathlib import Path
from types import SimpleNamespace
import mlx.core as mx
import numpy as np
from huggingface_hub import snapshot_download
import musicgen
def save_audio(file: str, audio: mx.array, sampling_rate: int):
@ -20,56 +13,3 @@ def save_audio(file: str, audio: mx.array, sampling_rate: int):
audio = mx.clip(audio, -1, 1)
audio = (audio * 32767).astype(mx.int16)
write(file, sampling_rate, np.array(audio))
def load(path_or_repo):
    """
    Load a MusicGen model from a local directory or a repository id.

    When ``path_or_repo`` is not an existing path it is passed to
    ``snapshot_download`` as a repository id. The parameters stored under
    ``"best_state"`` in ``state_dict.bin`` are renamed to this project's
    layout before being loaded. Returns the model.
    """
    import torch

    model_dir = Path(path_or_repo)
    if not model_dir.exists():
        downloaded = snapshot_download(
            repo_id=path_or_repo,
            allow_patterns=["*.json", "state_dict.bin"],
        )
        model_dir = Path(downloaded)

    with open(model_dir / "config.json", "r") as config_file:
        config = SimpleNamespace(**json.load(config_file))
        # Wrap the nested sections so they support attribute access too.
        for section in ("text_encoder", "audio_encoder", "decoder"):
            setattr(config, section, SimpleNamespace(**getattr(config, section)))

    checkpoint = torch.load(model_dir / "state_dict.bin", weights_only=True)
    weights = {
        name: mx.array(tensor.numpy())
        for name, tensor in checkpoint["best_state"].items()
    }

    prefix = "transformer."
    fused = "in_proj_weight"
    decoder_weights = {}
    for name, arr in weights.items():
        if name.startswith(prefix):
            name = name[len(prefix) :]
        name = name.replace("cross_attention", "cross_attn")
        name = name.replace(
            "condition_provider.conditioners.description", "text_conditioner"
        )

        if fused not in name:
            decoder_weights[name] = arr
            continue

        # Split the fused projection into its query, key and value thirds.
        third = arr.shape[0] // 3
        pieces = (arr[:third], arr[third : third * 2], arr[third * 2 :])
        for proj, piece in zip(("q_proj", "k_proj", "v_proj"), pieces):
            decoder_weights[name.replace(fused, f"{proj}.weight")] = piece

    model = musicgen.MusicGen(config)
    model.load_weights(list(decoder_weights.items()))
    mx.eval(model)
    return model

View File

@ -7,31 +7,6 @@ tasks by prepending task-specific prefixes to the input, e.g.:
This example also supports the FLAN-T5 model variants.[^2]
## Setup
Download and convert the model:
```sh
python convert.py --model <model>
```
This creates the `<model>.npz` file, which MLX can read.
The `<model>` can be any of the following:
| Model Name | Model Size |
| ---------- | ---------- |
| t5-small | 60 million |
| t5-base | 220 million |
| t5-large | 770 million |
| t5-3b | 3 billion |
| t5-11b | 11 billion |
The FLAN variants can be specified with `google/flan-t5-small`,
`google/flan-t5-base`, etc. See the [Hugging Face
page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a
complete list of models.
## Generate
Generate text with:
@ -48,6 +23,21 @@ To see a list of options run:
python t5.py --help
```
The `<model>` can be any of the following:
| Model Name | Model Size |
| ---------- | ---------- |
| t5-small | 60 million |
| t5-base | 220 million |
| t5-large | 770 million |
| t5-3b | 3 billion |
| t5-11b | 11 billion |
The FLAN variants can be specified with `google/flan-t5-small`,
`google/flan-t5-base`, etc. See the [Hugging Face
page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a
complete list of models.
[^1]: For more information on T5 see the [original paper](https://arxiv.org/abs/1910.10683)
or the [Hugging Face page](https://huggingface.co/docs/transformers/model_doc/t5).
[^2]: For more information on FLAN-T5 see the [original paper](https://arxiv.org/abs/2210.11416).

195
t5/t5.py
View File

@ -10,36 +10,36 @@ import mlx.nn as nn
import numpy as np
from transformers import AutoTokenizer
SHARED_REPLACEMENT_PATTERNS = [
(".block.", ".layers."),
(".k.", ".key_proj."),
(".o.", ".out_proj."),
(".q.", ".query_proj."),
(".v.", ".value_proj."),
("shared.", "wte."),
("lm_head.", "lm_head.linear."),
(".layer.0.layer_norm.", ".ln1."),
(".layer.1.layer_norm.", ".ln2."),
(".layer.2.layer_norm.", ".ln3."),
(".final_layer_norm.", ".ln."),
(
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
"relative_attention_bias.embeddings.",
),
]
ENCODER_REPLACEMENT_PATTERNS = [
(".layer.0.SelfAttention.", ".attention."),
(".layer.1.DenseReluDense.", ".dense."),
]
class Tokenizer:
def __init__(self, config, model_name):
self._decoder_start_id = config.decoder_start_token_id
self._tokenizer = AutoTokenizer.from_pretrained(
model_name,
legacy=False,
model_max_length=getattr(config, "n_positions", 512),
)
DECODER_REPLACEMENT_PATTERNS = [
(".layer.0.SelfAttention.", ".self_attention."),
(".layer.1.EncDecAttention.", ".cross_attention."),
(".layer.2.DenseReluDense.", ".dense."),
]
@property
def eos_id(self) -> int:
return self._tokenizer.eos_token_id
IGNORED_KEYS = ["decoder.layers.0.cross_attention.relative_attention_bias.weight"]
@property
def decoder_start_id(self) -> int:
return self._decoder_start_id
def encode(self, s: str) -> mx.array:
return mx.array(
self._tokenizer(
s,
return_tensors="np",
return_attention_mask=False,
)["input_ids"]
)
def decode(self, t: List[int], with_sep: bool = True) -> str:
tokens = self._tokenizer.convert_ids_to_tokens(t)
return "".join(t.replace("", " " if with_sep else "") for t in tokens)
def _relative_position_bucket(
@ -350,36 +350,83 @@ class T5(nn.Module):
):
return self.decode(decoder_inputs, self.encode(inputs))[0]
@classmethod
def sanitize(cls, weights):
shared_replacement_patterns = [
(".block.", ".layers."),
(".k.", ".key_proj."),
(".o.", ".out_proj."),
(".q.", ".query_proj."),
(".v.", ".value_proj."),
("shared.", "wte."),
("lm_head.", "lm_head.linear."),
(".layer.0.layer_norm.", ".ln1."),
(".layer.1.layer_norm.", ".ln2."),
(".layer.2.layer_norm.", ".ln3."),
(".final_layer_norm.", ".ln."),
(
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
"relative_attention_bias.embeddings.",
),
]
class Tokenizer:
def __init__(self, config, model_name):
self._decoder_start_id = config.decoder_start_token_id
self._tokenizer = AutoTokenizer.from_pretrained(
model_name,
legacy=False,
model_max_length=getattr(config, "n_positions", 512),
)
encoder_replacement_patterns = [
(".layer.0.SelfAttention.", ".attention."),
(".layer.1.DenseReluDense.", ".dense."),
]
@property
def eos_id(self) -> int:
return self._tokenizer.eos_token_id
decoder_replacement_patterns = [
(".layer.0.SelfAttention.", ".self_attention."),
(".layer.1.EncDecAttention.", ".cross_attention."),
(".layer.2.DenseReluDense.", ".dense."),
]
@property
def decoder_start_id(self) -> int:
return self._decoder_start_id
ignored_keys = [
"decoder.layers.0.cross_attention.relative_attention_bias.weight"
]
def encode(self, s: str) -> mx.array:
return mx.array(
self._tokenizer(
s,
return_tensors="np",
return_attention_mask=False,
)["input_ids"]
)
def replace_key(key: str) -> str:
for old, new in shared_replacement_patterns:
key = key.replace(old, new)
if key.startswith("encoder."):
for old, new in encoder_replacement_patterns:
key = key.replace(old, new)
elif key.startswith("decoder."):
for old, new in decoder_replacement_patterns:
key = key.replace(old, new)
return key
def decode(self, t: List[int], with_sep: bool = True) -> str:
tokens = self._tokenizer.convert_ids_to_tokens(t)
return "".join(t.replace("", " " if with_sep else "") for t in tokens)
weights = {replace_key(k): v for k, v in weights.items()}
for key in ignored_keys:
if key in weights:
del weights[key]
return weights
@classmethod
def from_pretrained(
    cls, path_or_repo: str, dtype: mx.Dtype = mx.bfloat16
) -> tuple["T5", Tokenizer]:
    """
    Load a model and tokenizer from a local directory or a repository id.

    Args:
        path_or_repo (str): A local directory, or a repository id that is
            fetched with ``snapshot_download`` when no such path exists.
        dtype (mx.Dtype): Type every loaded weight is cast to.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    from huggingface_hub import snapshot_download

    path = Path(path_or_repo)
    if not path.exists():
        # Not a local path, so treat the argument as a repository id.
        path = Path(
            snapshot_download(
                repo_id=path_or_repo,
                allow_patterns=["*.json", "*.safetensors", "*.model"],
            )
        )
    print(path)

    with open(path / "config.json", "r") as f:
        config = SimpleNamespace(**json.load(f))

    # Build through ``cls`` so a subclass gets an instance of its own type.
    model = cls(config)
    weights = mx.load(str(path / "model.safetensors"))
    weights = cls.sanitize(weights)
    weights = {k: v.astype(dtype) for k, v in weights.items()}
    model.load_weights(list(weights.items()))
    # NOTE(review): the tokenizer name is hard-coded to "t5-base" whatever
    # ``path_or_repo`` is -- confirm this is intended for all variants.
    return model, Tokenizer(config, "t5-base")
def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float] = 0.0):
@ -400,47 +447,6 @@ def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float]
yield y.squeeze()
def replace_key(key: str) -> str:
    """
    Map a checkpoint parameter name onto this module's naming.

    The shared substitutions are applied first. The resulting prefix then
    selects the encoder-only or decoder-only substitutions, if any.
    """
    for old, new in SHARED_REPLACEMENT_PATTERNS:
        key = key.replace(old, new)

    if key.startswith("encoder."):
        extra_patterns = ENCODER_REPLACEMENT_PATTERNS
    elif key.startswith("decoder."):
        extra_patterns = DECODER_REPLACEMENT_PATTERNS
    else:
        extra_patterns = ()

    for old, new in extra_patterns:
        key = key.replace(old, new)
    return key
def load_model(path_or_repo: str, dtype: str = "bfloat16"):
    """
    Load a T5 model and tokenizer from a local directory or a repository id.

    Args:
        path_or_repo (str): A local directory, or a repository id that is
            fetched with ``snapshot_download`` when no such path exists.
        dtype (str): Name of the ``mx`` dtype every weight is cast to.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    from huggingface_hub import snapshot_download

    path = Path(path_or_repo)
    if not path.exists():
        # Not a local path, so treat the argument as a repository id.
        path = Path(
            snapshot_download(
                repo_id=path_or_repo,
                allow_patterns=["*.json", "*.safetensors", "*.model"],
            )
        )
    print(path)

    with open(path / "config.json", "r") as f:
        config = SimpleNamespace(**json.load(f))

    # Resolve the dtype name without rebinding the string parameter.
    mx_dtype = getattr(mx, dtype)
    model = T5(config)
    weights = mx.load(str(path / "model.safetensors"))
    weights = {replace_key(k): v.astype(mx_dtype) for k, v in weights.items()}
    for key in IGNORED_KEYS:
        # A checkpoint that lacks an ignored key must not raise KeyError.
        weights.pop(key, None)
    model.load_weights(list(weights.items()))
    mx.eval(model.parameters())
    return model, Tokenizer(config, "t5-base")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="T5 Inference script")
parser.add_argument(
@ -486,7 +492,8 @@ if __name__ == "__main__":
mx.random.seed(args.seed)
model, tokenizer = load_model(args.model, args.dtype)
dtype = getattr(mx, args.dtype)
model, tokenizer = T5.from_pretrained(args.model, dtype)
if args.encode_only:
print("[INFO] Encoding with T5...", flush=True)