mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
change to from_pretrained
This commit is contained in:
@@ -33,11 +33,11 @@ An example using the model:
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from encodec import load
|
from encodec import EncodecModel
|
||||||
from utils import load_audio, save_audio
|
from utils import load_audio, save_audio
|
||||||
|
|
||||||
# Load the 48 KHz model and preprocessor.
|
# Load the 48 KHz model and preprocessor.
|
||||||
model, processor = load("mlx-community/encodec-48khz-float32")
|
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
|
||||||
|
|
||||||
# Load an audio file
|
# Load an audio file
|
||||||
audio = load_audio("path/to/audio", model.sampling_rate, model.channels)
|
audio = load_audio("path/to/audio", model.sampling_rate, model.channels)
|
||||||
|
@@ -3,9 +3,10 @@
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from utils import load
|
|
||||||
|
|
||||||
model, processor = load("mlx-community/encodec-48khz-float32")
|
from encodec import EncodecModel
|
||||||
|
|
||||||
|
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
|
||||||
|
|
||||||
audio = mx.random.uniform(shape=(288000, 2))
|
audio = mx.random.uniform(shape=(288000, 2))
|
||||||
feats, mask = processor(audio)
|
feats, mask = processor(audio)
|
||||||
|
@@ -673,6 +673,33 @@ class EncodecModel(nn.Module):
|
|||||||
audio_values = audio_values[:, : padding_mask.shape[1]]
|
audio_values = audio_values[:, : padding_mask.shape[1]]
|
||||||
return audio_values
|
return audio_values
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, path_or_repo: str):
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
path = Path(path_or_repo)
|
||||||
|
if not path.exists():
|
||||||
|
path = Path(
|
||||||
|
snapshot_download(
|
||||||
|
repo_id=path_or_repo,
|
||||||
|
allow_patterns=["*.json", "*.safetensors", "*.model"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(path / "config.json", "r") as f:
|
||||||
|
config = SimpleNamespace(**json.load(f))
|
||||||
|
|
||||||
|
model = EncodecModel(config)
|
||||||
|
model.load_weights(str(path / "model.safetensors"))
|
||||||
|
processor = functools.partial(
|
||||||
|
preprocess_audio,
|
||||||
|
sampling_rate=config.sampling_rate,
|
||||||
|
chunk_length=model.chunk_length,
|
||||||
|
chunk_stride=model.chunk_stride,
|
||||||
|
)
|
||||||
|
mx.eval(model)
|
||||||
|
return model, processor
|
||||||
|
|
||||||
|
|
||||||
def preprocess_audio(
|
def preprocess_audio(
|
||||||
raw_audio: Union[mx.array, List[mx.array]],
|
raw_audio: Union[mx.array, List[mx.array]],
|
||||||
@@ -712,33 +739,3 @@ def preprocess_audio(
|
|||||||
inputs.append(x)
|
inputs.append(x)
|
||||||
masks.append(mask)
|
masks.append(mask)
|
||||||
return mx.stack(inputs), mx.stack(masks)
|
return mx.stack(inputs), mx.stack(masks)
|
||||||
|
|
||||||
|
|
||||||
def load(path_or_repo):
|
|
||||||
"""
|
|
||||||
Load the model and audo preprocessor.
|
|
||||||
"""
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
path = Path(path_or_repo)
|
|
||||||
if not path.exists():
|
|
||||||
path = Path(
|
|
||||||
snapshot_download(
|
|
||||||
repo_id=path_or_repo,
|
|
||||||
allow_patterns=["*.json", "*.safetensors", "*.model"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
with open(path / "config.json", "r") as f:
|
|
||||||
config = SimpleNamespace(**json.load(f))
|
|
||||||
|
|
||||||
model = EncodecModel(config)
|
|
||||||
model.load_weights(str(path / "model.safetensors"))
|
|
||||||
processor = functools.partial(
|
|
||||||
preprocess_audio,
|
|
||||||
sampling_rate=config.sampling_rate,
|
|
||||||
chunk_length=model.chunk_length,
|
|
||||||
chunk_stride=model.chunk_stride,
|
|
||||||
)
|
|
||||||
mx.eval(model)
|
|
||||||
return model, processor
|
|
||||||
|
@@ -3,10 +3,10 @@
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from utils import load_audio, save_audio
|
from utils import load_audio, save_audio
|
||||||
|
|
||||||
from encodec import load
|
from encodec import EncodecModel
|
||||||
|
|
||||||
# Load the 48 KHz model and preprocessor.
|
# Load the 48 KHz model and preprocessor.
|
||||||
model, processor = load("mlx-community/encodec-48khz-float32")
|
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
|
||||||
|
|
||||||
# Load an audio file
|
# Load an audio file
|
||||||
audio = load_audio("/path/to/audio", model.sampling_rate, model.channels)
|
audio = load_audio("/path/to/audio", model.sampling_rate, model.channels)
|
||||||
|
@@ -3,9 +3,10 @@
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoProcessor, EncodecModel
|
from transformers import AutoProcessor
|
||||||
|
from transformers import EncodecModel as PTEncodecModel
|
||||||
|
|
||||||
from encodec import load, preprocess_audio
|
from encodec import EncodecModel, preprocess_audio
|
||||||
|
|
||||||
|
|
||||||
def compare_processors():
|
def compare_processors():
|
||||||
@@ -30,8 +31,8 @@ def compare_processors():
|
|||||||
|
|
||||||
|
|
||||||
def compare_models():
|
def compare_models():
|
||||||
pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz")
|
pt_model = PTEncodecModel.from_pretrained("facebook/encodec_48khz")
|
||||||
mx_model, _ = load("mlx-community/encodec-48khz-float32")
|
mx_model, _ = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
|
||||||
|
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
audio_length = 190560
|
audio_length = 190560
|
||||||
|
@@ -1,12 +1,18 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from utils import load
|
|
||||||
|
cur_path = Path(__file__).parents[1].resolve()
|
||||||
|
sys.path.append(str(cur_path))
|
||||||
|
|
||||||
|
from musicgen import MusicGen
|
||||||
|
|
||||||
text = "folk ballad"
|
text = "folk ballad"
|
||||||
model = load("facebook/musicgen-medium")
|
model = MusicGen.from_pretrained("facebook/musicgen-medium")
|
||||||
|
|
||||||
max_steps = 100
|
max_steps = 100
|
||||||
|
|
||||||
|
@@ -1,3 +1,5 @@
|
|||||||
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@@ -2,11 +2,13 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from utils import load, save_audio
|
from utils import save_audio
|
||||||
|
|
||||||
|
from musicgen import MusicGen
|
||||||
|
|
||||||
|
|
||||||
def main(text: str, output_path: str, model_name: str, max_steps: int):
|
def main(text: str, output_path: str, model_name: str, max_steps: int):
|
||||||
model = load(model_name)
|
model = MusicGen.from_pretrained(model_name)
|
||||||
audio = model.generate(text, max_steps=max_steps)
|
audio = model.generate(text, max_steps=max_steps)
|
||||||
save_audio(output_path, audio, model.sampling_rate)
|
save_audio(output_path, audio, model.sampling_rate)
|
||||||
|
|
||||||
|
@@ -1,8 +1,10 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
import json
|
||||||
import sys
|
import sys
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
@@ -13,14 +15,14 @@ cur_path = Path(__file__).parents[1].resolve()
|
|||||||
sys.path.append(str(cur_path / "encodec"))
|
sys.path.append(str(cur_path / "encodec"))
|
||||||
sys.path.append(str(cur_path / "t5"))
|
sys.path.append(str(cur_path / "t5"))
|
||||||
|
|
||||||
from encodec import load as load_encodec
|
from encodec import EncodecModel
|
||||||
from t5 import load_model as load_t5
|
from t5 import T5
|
||||||
|
|
||||||
|
|
||||||
class TextConditioner(nn.Module):
|
class TextConditioner(nn.Module):
|
||||||
def __init__(self, t5_name, input_dim, output_dim):
|
def __init__(self, t5_name, input_dim, output_dim):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._t5, self.tokenizer = load_t5(t5_name)
|
self._t5, self.tokenizer = T5.from_pretrained(t5_name)
|
||||||
self.output_proj = nn.Linear(input_dim, output_dim)
|
self.output_proj = nn.Linear(input_dim, output_dim)
|
||||||
|
|
||||||
def __call__(self, text):
|
def __call__(self, text):
|
||||||
@@ -222,7 +224,9 @@ class MusicGen(nn.Module):
|
|||||||
]
|
]
|
||||||
encodec_name = config.audio_encoder._name_or_path.split("/")[-1]
|
encodec_name = config.audio_encoder._name_or_path.split("/")[-1]
|
||||||
encodec_name = encodec_name.replace("_", "-")
|
encodec_name = encodec_name.replace("_", "-")
|
||||||
self._audio_decoder, _ = load_encodec(f"mlx-community/{encodec_name}-float32")
|
self._audio_decoder, _ = EncodecModel.from_pretrained(
|
||||||
|
f"mlx-community/{encodec_name}-float32"
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -304,3 +308,57 @@ class MusicGen(nn.Module):
|
|||||||
audio_seq = mx.swapaxes(audio_seq, -1, -2)[:, mx.newaxis]
|
audio_seq = mx.swapaxes(audio_seq, -1, -2)[:, mx.newaxis]
|
||||||
audio = self._audio_decoder.decode(audio_seq, audio_scales=[None])
|
audio = self._audio_decoder.decode(audio_seq, audio_scales=[None])
|
||||||
return audio[0]
|
return audio[0]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def sanitize(cls, weights):
|
||||||
|
out_weights = {}
|
||||||
|
for k, arr in weights.items():
|
||||||
|
if k.startswith("transformer."):
|
||||||
|
k = k[len("transformer.") :]
|
||||||
|
|
||||||
|
if "cross_attention" in k:
|
||||||
|
k = k.replace("cross_attention", "cross_attn")
|
||||||
|
|
||||||
|
if "condition_provider" in k:
|
||||||
|
k = k.replace(
|
||||||
|
"condition_provider.conditioners.description", "text_conditioner"
|
||||||
|
)
|
||||||
|
|
||||||
|
if "in_proj_weight" in k:
|
||||||
|
dim = arr.shape[0] // 3
|
||||||
|
name = "in_proj_weight"
|
||||||
|
out_weights[k.replace(name, "q_proj.weight")] = arr[:dim]
|
||||||
|
out_weights[k.replace(name, "k_proj.weight")] = arr[dim : dim * 2]
|
||||||
|
out_weights[k.replace(name, "v_proj.weight")] = arr[dim * 2 :]
|
||||||
|
continue
|
||||||
|
|
||||||
|
out_weights[k] = arr
|
||||||
|
return out_weights
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, path_or_repo: str):
|
||||||
|
import torch
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
path = Path(path_or_repo)
|
||||||
|
if not path.exists():
|
||||||
|
path = Path(
|
||||||
|
snapshot_download(
|
||||||
|
repo_id=path_or_repo,
|
||||||
|
allow_patterns=["*.json", "state_dict.bin"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(path / "config.json", "r") as f:
|
||||||
|
config = SimpleNamespace(**json.load(f))
|
||||||
|
config.text_encoder = SimpleNamespace(**config.text_encoder)
|
||||||
|
config.audio_encoder = SimpleNamespace(**config.audio_encoder)
|
||||||
|
config.decoder = SimpleNamespace(**config.decoder)
|
||||||
|
|
||||||
|
weights = torch.load(path / "state_dict.bin", weights_only=True)["best_state"]
|
||||||
|
weights = {k: mx.array(v.numpy()) for k, v in weights.items()}
|
||||||
|
weights = cls.sanitize(weights)
|
||||||
|
|
||||||
|
model = MusicGen(config)
|
||||||
|
model.load_weights(list(weights.items()))
|
||||||
|
return model
|
||||||
|
@@ -1,14 +1,7 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
import musicgen
|
|
||||||
|
|
||||||
|
|
||||||
def save_audio(file: str, audio: mx.array, sampling_rate: int):
|
def save_audio(file: str, audio: mx.array, sampling_rate: int):
|
||||||
@@ -20,56 +13,3 @@ def save_audio(file: str, audio: mx.array, sampling_rate: int):
|
|||||||
audio = mx.clip(audio, -1, 1)
|
audio = mx.clip(audio, -1, 1)
|
||||||
audio = (audio * 32767).astype(mx.int16)
|
audio = (audio * 32767).astype(mx.int16)
|
||||||
write(file, sampling_rate, np.array(audio))
|
write(file, sampling_rate, np.array(audio))
|
||||||
|
|
||||||
|
|
||||||
def load(path_or_repo):
|
|
||||||
"""
|
|
||||||
Load the model and audio preprocessor.
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
|
|
||||||
path = Path(path_or_repo)
|
|
||||||
if not path.exists():
|
|
||||||
path = Path(
|
|
||||||
snapshot_download(
|
|
||||||
repo_id=path_or_repo,
|
|
||||||
allow_patterns=["*.json", "state_dict.bin"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
with open(path / "config.json", "r") as f:
|
|
||||||
config = SimpleNamespace(**json.load(f))
|
|
||||||
config.text_encoder = SimpleNamespace(**config.text_encoder)
|
|
||||||
config.audio_encoder = SimpleNamespace(**config.audio_encoder)
|
|
||||||
config.decoder = SimpleNamespace(**config.decoder)
|
|
||||||
|
|
||||||
weights = torch.load(path / "state_dict.bin", weights_only=True)["best_state"]
|
|
||||||
weights = {k: mx.array(v.numpy()) for k, v in weights.items()}
|
|
||||||
|
|
||||||
decoder_weights = {}
|
|
||||||
for k, arr in weights.items():
|
|
||||||
if k.startswith("transformer."):
|
|
||||||
k = k[len("transformer.") :]
|
|
||||||
|
|
||||||
if "cross_attention" in k:
|
|
||||||
k = k.replace("cross_attention", "cross_attn")
|
|
||||||
|
|
||||||
if "condition_provider" in k:
|
|
||||||
k = k.replace(
|
|
||||||
"condition_provider.conditioners.description", "text_conditioner"
|
|
||||||
)
|
|
||||||
|
|
||||||
if "in_proj_weight" in k:
|
|
||||||
dim = arr.shape[0] // 3
|
|
||||||
name = "in_proj_weight"
|
|
||||||
decoder_weights[k.replace(name, "q_proj.weight")] = arr[:dim]
|
|
||||||
decoder_weights[k.replace(name, "k_proj.weight")] = arr[dim : dim * 2]
|
|
||||||
decoder_weights[k.replace(name, "v_proj.weight")] = arr[dim * 2 :]
|
|
||||||
continue
|
|
||||||
|
|
||||||
decoder_weights[k] = arr
|
|
||||||
|
|
||||||
model = musicgen.MusicGen(config)
|
|
||||||
model.load_weights(list(decoder_weights.items()))
|
|
||||||
mx.eval(model)
|
|
||||||
return model
|
|
||||||
|
40
t5/README.md
40
t5/README.md
@@ -7,31 +7,6 @@ tasks by prepending task-specific prefixes to the input, e.g.:
|
|||||||
|
|
||||||
This example also supports the FLAN-T5 models variants.[^2]
|
This example also supports the FLAN-T5 models variants.[^2]
|
||||||
|
|
||||||
## Setup
|
|
||||||
|
|
||||||
Download and convert the model:
|
|
||||||
|
|
||||||
```sh
|
|
||||||
python convert.py --model <model>
|
|
||||||
```
|
|
||||||
|
|
||||||
This will make the `<model>.npz` file which MLX can read.
|
|
||||||
|
|
||||||
The `<model>` can be any of the following:
|
|
||||||
|
|
||||||
| Model Name | Model Size |
|
|
||||||
| ---------- | ----------
|
|
||||||
| t5-small | 60 million |
|
|
||||||
| t5-base | 220 million |
|
|
||||||
| t5-large | 770 million |
|
|
||||||
| t5-3b | 3 billion |
|
|
||||||
| t5-11b | 11 billion |
|
|
||||||
|
|
||||||
The FLAN variants can be specified with `google/flan-t5-small`,
|
|
||||||
`google/flan-t5-base`, etc. See the [Hugging Face
|
|
||||||
page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a
|
|
||||||
complete list of models.
|
|
||||||
|
|
||||||
## Generate
|
## Generate
|
||||||
|
|
||||||
Generate text with:
|
Generate text with:
|
||||||
@@ -48,6 +23,21 @@ To see a list of options run:
|
|||||||
python t5.py --help
|
python t5.py --help
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The `<model>` can be any of the following:
|
||||||
|
|
||||||
|
| Model Name | Model Size |
|
||||||
|
| ---------- | ----------
|
||||||
|
| t5-small | 60 million |
|
||||||
|
| t5-base | 220 million |
|
||||||
|
| t5-large | 770 million |
|
||||||
|
| t5-3b | 3 billion |
|
||||||
|
| t5-11b | 11 billion |
|
||||||
|
|
||||||
|
The FLAN variants can be specified with `google/flan-t5-small`,
|
||||||
|
`google/flan-t5-base`, etc. See the [Hugging Face
|
||||||
|
page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a
|
||||||
|
complete list of models.
|
||||||
|
|
||||||
[^1]: For more information on T5 see the [original paper](https://arxiv.org/abs/1910.10683)
|
[^1]: For more information on T5 see the [original paper](https://arxiv.org/abs/1910.10683)
|
||||||
or the [Hugging Face page](https://huggingface.co/docs/transformers/model_doc/t5).
|
or the [Hugging Face page](https://huggingface.co/docs/transformers/model_doc/t5).
|
||||||
[^2]: For more information on FLAN-T5 see the [original paper](https://arxiv.org/abs/2210.11416).
|
[^2]: For more information on FLAN-T5 see the [original paper](https://arxiv.org/abs/2210.11416).
|
||||||
|
195
t5/t5.py
195
t5/t5.py
@@ -10,36 +10,36 @@ import mlx.nn as nn
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
SHARED_REPLACEMENT_PATTERNS = [
|
|
||||||
(".block.", ".layers."),
|
|
||||||
(".k.", ".key_proj."),
|
|
||||||
(".o.", ".out_proj."),
|
|
||||||
(".q.", ".query_proj."),
|
|
||||||
(".v.", ".value_proj."),
|
|
||||||
("shared.", "wte."),
|
|
||||||
("lm_head.", "lm_head.linear."),
|
|
||||||
(".layer.0.layer_norm.", ".ln1."),
|
|
||||||
(".layer.1.layer_norm.", ".ln2."),
|
|
||||||
(".layer.2.layer_norm.", ".ln3."),
|
|
||||||
(".final_layer_norm.", ".ln."),
|
|
||||||
(
|
|
||||||
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
|
|
||||||
"relative_attention_bias.embeddings.",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
ENCODER_REPLACEMENT_PATTERNS = [
|
class Tokenizer:
|
||||||
(".layer.0.SelfAttention.", ".attention."),
|
def __init__(self, config, model_name):
|
||||||
(".layer.1.DenseReluDense.", ".dense."),
|
self._decoder_start_id = config.decoder_start_token_id
|
||||||
]
|
self._tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
legacy=False,
|
||||||
|
model_max_length=getattr(config, "n_positions", 512),
|
||||||
|
)
|
||||||
|
|
||||||
DECODER_REPLACEMENT_PATTERNS = [
|
@property
|
||||||
(".layer.0.SelfAttention.", ".self_attention."),
|
def eos_id(self) -> int:
|
||||||
(".layer.1.EncDecAttention.", ".cross_attention."),
|
return self._tokenizer.eos_token_id
|
||||||
(".layer.2.DenseReluDense.", ".dense."),
|
|
||||||
]
|
|
||||||
|
|
||||||
IGNORED_KEYS = ["decoder.layers.0.cross_attention.relative_attention_bias.weight"]
|
@property
|
||||||
|
def decoder_start_id(self) -> int:
|
||||||
|
return self._decoder_start_id
|
||||||
|
|
||||||
|
def encode(self, s: str) -> mx.array:
|
||||||
|
return mx.array(
|
||||||
|
self._tokenizer(
|
||||||
|
s,
|
||||||
|
return_tensors="np",
|
||||||
|
return_attention_mask=False,
|
||||||
|
)["input_ids"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def decode(self, t: List[int], with_sep: bool = True) -> str:
|
||||||
|
tokens = self._tokenizer.convert_ids_to_tokens(t)
|
||||||
|
return "".join(t.replace("▁", " " if with_sep else "") for t in tokens)
|
||||||
|
|
||||||
|
|
||||||
def _relative_position_bucket(
|
def _relative_position_bucket(
|
||||||
@@ -350,36 +350,83 @@ class T5(nn.Module):
|
|||||||
):
|
):
|
||||||
return self.decode(decoder_inputs, self.encode(inputs))[0]
|
return self.decode(decoder_inputs, self.encode(inputs))[0]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def sanitize(cls, weights):
|
||||||
|
shared_replacement_patterns = [
|
||||||
|
(".block.", ".layers."),
|
||||||
|
(".k.", ".key_proj."),
|
||||||
|
(".o.", ".out_proj."),
|
||||||
|
(".q.", ".query_proj."),
|
||||||
|
(".v.", ".value_proj."),
|
||||||
|
("shared.", "wte."),
|
||||||
|
("lm_head.", "lm_head.linear."),
|
||||||
|
(".layer.0.layer_norm.", ".ln1."),
|
||||||
|
(".layer.1.layer_norm.", ".ln2."),
|
||||||
|
(".layer.2.layer_norm.", ".ln3."),
|
||||||
|
(".final_layer_norm.", ".ln."),
|
||||||
|
(
|
||||||
|
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
|
||||||
|
"relative_attention_bias.embeddings.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
class Tokenizer:
|
encoder_replacement_patterns = [
|
||||||
def __init__(self, config, model_name):
|
(".layer.0.SelfAttention.", ".attention."),
|
||||||
self._decoder_start_id = config.decoder_start_token_id
|
(".layer.1.DenseReluDense.", ".dense."),
|
||||||
self._tokenizer = AutoTokenizer.from_pretrained(
|
]
|
||||||
model_name,
|
|
||||||
legacy=False,
|
|
||||||
model_max_length=getattr(config, "n_positions", 512),
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
decoder_replacement_patterns = [
|
||||||
def eos_id(self) -> int:
|
(".layer.0.SelfAttention.", ".self_attention."),
|
||||||
return self._tokenizer.eos_token_id
|
(".layer.1.EncDecAttention.", ".cross_attention."),
|
||||||
|
(".layer.2.DenseReluDense.", ".dense."),
|
||||||
|
]
|
||||||
|
|
||||||
@property
|
ignored_keys = [
|
||||||
def decoder_start_id(self) -> int:
|
"decoder.layers.0.cross_attention.relative_attention_bias.weight"
|
||||||
return self._decoder_start_id
|
]
|
||||||
|
|
||||||
def encode(self, s: str) -> mx.array:
|
def replace_key(key: str) -> str:
|
||||||
return mx.array(
|
for old, new in shared_replacement_patterns:
|
||||||
self._tokenizer(
|
key = key.replace(old, new)
|
||||||
s,
|
if key.startswith("encoder."):
|
||||||
return_tensors="np",
|
for old, new in encoder_replacement_patterns:
|
||||||
return_attention_mask=False,
|
key = key.replace(old, new)
|
||||||
)["input_ids"]
|
elif key.startswith("decoder."):
|
||||||
)
|
for old, new in decoder_replacement_patterns:
|
||||||
|
key = key.replace(old, new)
|
||||||
|
return key
|
||||||
|
|
||||||
def decode(self, t: List[int], with_sep: bool = True) -> str:
|
weights = {replace_key(k): v for k, v in weights.items()}
|
||||||
tokens = self._tokenizer.convert_ids_to_tokens(t)
|
for key in ignored_keys:
|
||||||
return "".join(t.replace("▁", " " if with_sep else "") for t in tokens)
|
if key in weights:
|
||||||
|
del weights[key]
|
||||||
|
return weights
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls, path_or_repo: str, dtype: mx.Dtype = mx.bfloat16
|
||||||
|
) -> tuple["T5", Tokenizer]:
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
path = Path(path_or_repo)
|
||||||
|
if not path.exists():
|
||||||
|
path = Path(
|
||||||
|
snapshot_download(
|
||||||
|
repo_id=path_or_repo,
|
||||||
|
allow_patterns=["*.json", "*.safetensors", "*.model"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(path)
|
||||||
|
|
||||||
|
with open(path / "config.json", "r") as f:
|
||||||
|
config = SimpleNamespace(**json.load(f))
|
||||||
|
|
||||||
|
model = T5(config)
|
||||||
|
weights = mx.load(str(path / "model.safetensors"))
|
||||||
|
weights = cls.sanitize(weights)
|
||||||
|
weights = {k: v.astype(dtype) for k, v in weights.items()}
|
||||||
|
model.load_weights(list(weights.items()))
|
||||||
|
return model, Tokenizer(config, "t5-base")
|
||||||
|
|
||||||
|
|
||||||
def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float] = 0.0):
|
def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float] = 0.0):
|
||||||
@@ -400,47 +447,6 @@ def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float]
|
|||||||
yield y.squeeze()
|
yield y.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def replace_key(key: str) -> str:
|
|
||||||
for old, new in SHARED_REPLACEMENT_PATTERNS:
|
|
||||||
key = key.replace(old, new)
|
|
||||||
if key.startswith("encoder."):
|
|
||||||
for old, new in ENCODER_REPLACEMENT_PATTERNS:
|
|
||||||
key = key.replace(old, new)
|
|
||||||
elif key.startswith("decoder."):
|
|
||||||
for old, new in DECODER_REPLACEMENT_PATTERNS:
|
|
||||||
key = key.replace(old, new)
|
|
||||||
return key
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(path_or_repo: str, dtype: str = "bfloat16"):
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
path = Path(path_or_repo)
|
|
||||||
if not path.exists():
|
|
||||||
path = Path(
|
|
||||||
snapshot_download(
|
|
||||||
repo_id=path_or_repo,
|
|
||||||
allow_patterns=["*.json", "*.safetensors", "*.model"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
print(path)
|
|
||||||
|
|
||||||
with open(path / "config.json", "r") as f:
|
|
||||||
config = SimpleNamespace(**json.load(f))
|
|
||||||
|
|
||||||
dtype = getattr(mx, dtype)
|
|
||||||
|
|
||||||
model = T5(config)
|
|
||||||
weights = mx.load(str(path / "model.safetensors"))
|
|
||||||
weights = {replace_key(k): v.astype(dtype) for k, v in weights.items()}
|
|
||||||
for key in IGNORED_KEYS:
|
|
||||||
del weights[key]
|
|
||||||
model.load_weights(list(weights.items()))
|
|
||||||
|
|
||||||
mx.eval(model.parameters())
|
|
||||||
return model, Tokenizer(config, "t5-base")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="T5 Inference script")
|
parser = argparse.ArgumentParser(description="T5 Inference script")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -486,7 +492,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
mx.random.seed(args.seed)
|
mx.random.seed(args.seed)
|
||||||
|
|
||||||
model, tokenizer = load_model(args.model, args.dtype)
|
dtype = getattr(mx, args.dtype)
|
||||||
|
model, tokenizer = T5.from_pretrained(args.model, dtype)
|
||||||
|
|
||||||
if args.encode_only:
|
if args.encode_only:
|
||||||
print("[INFO] Encoding with T5...", flush=True)
|
print("[INFO] Encoding with T5...", flush=True)
|
||||||
|
Reference in New Issue
Block a user