mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-02 13:40:48 +08:00
Make T5 work with official models without conversions
This commit is contained in:
40
t5/README.md
40
t5/README.md
@@ -9,15 +9,29 @@ This example also supports the FLAN-T5 models variants.[^2]
|
||||
|
||||
## Setup
|
||||
|
||||
Download and convert the model:
|
||||
Install the dependencies:
|
||||
|
||||
```sh
|
||||
python convert.py --model <model>
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
This will make the `<model>.npz` file which MLX can read.
|
||||
## Generate
|
||||
|
||||
The `<model>` can be any of the following:
|
||||
Generate text with:
|
||||
|
||||
```sh
|
||||
python t5.py --model google-t5/t5-small --prompt "translate English to German: A tasty apple"
|
||||
```
|
||||
|
||||
This should give the output: `Ein leckerer Apfel`
|
||||
|
||||
To see a list of options run:
|
||||
|
||||
```sh
|
||||
python t5.py --help
|
||||
```
|
||||
|
||||
The `google-t5` has following models:
|
||||
|
||||
| Model Name | Model Size |
|
||||
| ---------- | ----------
|
||||
@@ -32,22 +46,6 @@ The FLAN variants can be specified with `google/flan-t5-small`,
|
||||
page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a
|
||||
complete list of models.
|
||||
|
||||
## Generate
|
||||
|
||||
Generate text with:
|
||||
|
||||
```sh
|
||||
python t5.py --model t5-small --prompt "translate English to German: A tasty apple"
|
||||
```
|
||||
|
||||
This should give the output: `Ein leckerer Apfel`
|
||||
|
||||
To see a list of options run:
|
||||
|
||||
```sh
|
||||
python t5.py --help
|
||||
```
|
||||
|
||||
[^1]: For more information on T5 see the [original paper](https://arxiv.org/abs/1910.10683)
|
||||
or the [Hugging Face page](https://huggingface.co/docs/transformers/model_doc/t5).
|
||||
[^2]: For more information on FLAN-T5 see the [original paper](https://arxiv.org/abs/2210.11416).
|
||||
|
@@ -1,75 +0,0 @@
|
||||
import numpy as np
|
||||
from transformers import T5ForConditionalGeneration
|
||||
|
||||
SHARED_REPLACEMENT_PATTERNS = [
|
||||
(".block.", ".layers."),
|
||||
(".k.", ".key_proj."),
|
||||
(".o.", ".out_proj."),
|
||||
(".q.", ".query_proj."),
|
||||
(".v.", ".value_proj."),
|
||||
("shared.", "wte."),
|
||||
("lm_head.", "lm_head.linear."),
|
||||
(".layer.0.layer_norm.", ".ln1."),
|
||||
(".layer.1.layer_norm.", ".ln2."),
|
||||
(".layer.2.layer_norm.", ".ln3."),
|
||||
(".final_layer_norm.", ".ln."),
|
||||
(
|
||||
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
|
||||
"relative_attention_bias.embeddings.",
|
||||
),
|
||||
]
|
||||
|
||||
ENCODER_REPLACEMENT_PATTERNS = [
|
||||
(".layer.0.SelfAttention.", ".attention."),
|
||||
(".layer.1.DenseReluDense.", ".dense."),
|
||||
]
|
||||
|
||||
DECODER_REPLACEMENT_PATTERNS = [
|
||||
(".layer.0.SelfAttention.", ".self_attention."),
|
||||
(".layer.1.EncDecAttention.", ".cross_attention."),
|
||||
(".layer.2.DenseReluDense.", ".dense."),
|
||||
]
|
||||
|
||||
|
||||
def replace_key(key: str) -> str:
|
||||
for old, new in SHARED_REPLACEMENT_PATTERNS:
|
||||
key = key.replace(old, new)
|
||||
if key.startswith("encoder."):
|
||||
for old, new in ENCODER_REPLACEMENT_PATTERNS:
|
||||
key = key.replace(old, new)
|
||||
elif key.startswith("decoder."):
|
||||
for old, new in DECODER_REPLACEMENT_PATTERNS:
|
||||
key = key.replace(old, new)
|
||||
return key
|
||||
|
||||
|
||||
def convert(model_name, dtype):
|
||||
dtype = getattr(np, dtype)
|
||||
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
|
||||
weights = {
|
||||
replace_key(k): v.numpy().astype(dtype) for k, v in model.state_dict().items()
|
||||
}
|
||||
file_name = model_name.replace("/", "-")
|
||||
print(f"Saving weights to {file_name}.npz")
|
||||
np.savez(f"{file_name}.npz", **weights)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Convert T5 weights to MLX")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
help="Name of the T5 model.",
|
||||
default="t5-small",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
help="The model data type.",
|
||||
type=str,
|
||||
choices=["float16", "float32"],
|
||||
default="float32",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert(args.model, args.dtype)
|
227
t5/t5.py
227
t5/t5.py
@@ -1,10 +1,14 @@
|
||||
import argparse
|
||||
import glob
|
||||
import math
|
||||
from pathlib import Path
|
||||
from time import perf_counter_ns
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx.utils import tree_map, tree_unflatten
|
||||
from transformers import AutoTokenizer, T5Config
|
||||
|
||||
@@ -65,11 +69,8 @@ class RelativePositionBias(nn.Module):
|
||||
self.num_buckets = config.relative_attention_num_buckets
|
||||
self.max_distance = config.relative_attention_max_distance
|
||||
self.n_heads = config.num_heads
|
||||
self.embeddings = nn.Embedding(
|
||||
config.relative_attention_num_buckets, config.num_heads
|
||||
)
|
||||
|
||||
def __call__(self, query_length: int, key_length: int, offset: int = 0):
|
||||
def __call__(self, embeddings: nn.Embedding, query_length: int, key_length: int, offset: int = 0):
|
||||
"""Compute binned relative position bias"""
|
||||
context_position = mx.arange(offset, query_length)[:, None]
|
||||
memory_position = mx.arange(key_length)[None, :]
|
||||
@@ -84,21 +85,23 @@ class RelativePositionBias(nn.Module):
|
||||
)
|
||||
|
||||
# shape (query_length, key_length, num_heads)
|
||||
values = self.embeddings(relative_position_bucket)
|
||||
values = embeddings(relative_position_bucket)
|
||||
|
||||
# shape (num_heads, query_length, key_length)
|
||||
return values.transpose(2, 0, 1)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
def __init__(self, config: T5Config, has_relative_attention_bias: bool):
|
||||
super().__init__()
|
||||
inner_dim = config.d_kv * config.num_heads
|
||||
self.num_heads = config.num_heads
|
||||
self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
|
||||
self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
|
||||
self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
|
||||
self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)
|
||||
self.q = nn.Linear(config.d_model, inner_dim, bias=False)
|
||||
self.k = nn.Linear(config.d_model, inner_dim, bias=False)
|
||||
self.v = nn.Linear(config.d_model, inner_dim, bias=False)
|
||||
self.o = nn.Linear(inner_dim, config.d_model, bias=False)
|
||||
if has_relative_attention_bias:
|
||||
self.relative_attention_bias = nn.Embedding(config.relative_attention_num_buckets, self.num_heads)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -108,30 +111,67 @@ class MultiHeadAttention(nn.Module):
|
||||
mask: Optional[mx.array],
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> [mx.array, Tuple[mx.array, mx.array]]:
|
||||
queries = self.query_proj(queries)
|
||||
keys = self.key_proj(keys)
|
||||
values = self.value_proj(values)
|
||||
queries = self.q(queries)
|
||||
keys = self.k(keys)
|
||||
values = self.v(values)
|
||||
|
||||
num_heads = self.num_heads
|
||||
B, L, _ = queries.shape
|
||||
_, S, _ = keys.shape
|
||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
|
||||
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
key_cache, value_cache = cache
|
||||
keys = mx.concatenate([key_cache, keys], axis=3)
|
||||
keys = mx.concatenate([key_cache, keys], axis=2)
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
|
||||
# Dimensions are [batch x num heads x sequence x hidden dim]
|
||||
scores = queries @ keys
|
||||
if mask is not None:
|
||||
scores = scores + mask.astype(scores.dtype)
|
||||
# scale = math.sqrt(1 / queries.shape[-1])
|
||||
# output = mx.fast.scaled_dot_product_attention(
|
||||
# queries, keys, values, scale=scale, mask=mask
|
||||
# )
|
||||
# output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
scores = queries @ keys.transpose(0, 1, 3, 2)
|
||||
if mask is not None:
|
||||
scores += mask.astype(scores.dtype)
|
||||
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
|
||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.out_proj(values_hat), (keys, values)
|
||||
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
return self.o(output), (keys, values)
|
||||
|
||||
|
||||
class LayerSelfAttention(nn.Module):
|
||||
def __init__(self, config: T5Config, has_relative_attention_bias: bool):
|
||||
super().__init__()
|
||||
self.SelfAttention = MultiHeadAttention(config, has_relative_attention_bias)
|
||||
self.layer_norm = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array],
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> [mx.array, Tuple[mx.array, mx.array]]:
|
||||
y = self.layer_norm(x)
|
||||
return self.SelfAttention(y, y, y, mask, cache)
|
||||
|
||||
|
||||
class LayerCrossAttention(nn.Module):
|
||||
def __init__(self, config: T5Config, has_relative_attention_bias: bool):
|
||||
super().__init__()
|
||||
self.EncDecAttention = MultiHeadAttention(config, has_relative_attention_bias)
|
||||
self.layer_norm = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
memory: mx.array,
|
||||
memory_mask: Optional[mx.array],
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> [mx.array, Tuple[mx.array, mx.array]]:
|
||||
y = self.layer_norm(x)
|
||||
return self.EncDecAttention(y, memory, memory, memory_mask, cache)
|
||||
|
||||
|
||||
class DenseActivation(nn.Module):
|
||||
@@ -165,49 +205,59 @@ class DenseActivation(nn.Module):
|
||||
return self.wo(x)
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
class LayerFF(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
self.attention = MultiHeadAttention(config)
|
||||
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.dense = DenseActivation(config)
|
||||
self.DenseReluDense = DenseActivation(config)
|
||||
self.layer_norm = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
|
||||
def __call__(self, x: mx.array) -> [mx.array, Tuple[mx.array, mx.array]]:
|
||||
return self.DenseReluDense(self.layer_norm(x))
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
def __init__(self, config: T5Config, has_relative_attention_bias: bool):
|
||||
super().__init__()
|
||||
self.layer = [
|
||||
LayerSelfAttention(config, has_relative_attention_bias),
|
||||
LayerFF(config)
|
||||
]
|
||||
|
||||
def __call__(self, x, mask):
|
||||
y = self.ln1(x)
|
||||
y, _ = self.attention(y, y, y, mask=mask)
|
||||
y, _ = self.layer[0](x, mask=mask)
|
||||
x = x + y
|
||||
|
||||
y = self.ln2(x)
|
||||
y = self.dense(y)
|
||||
return x + y
|
||||
y = self.layer[1](x)
|
||||
x = x + y
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
self.layers = [
|
||||
TransformerEncoderLayer(config) for i in range(config.num_layers)
|
||||
self.block = [
|
||||
TransformerEncoderLayer(config, i == 0) for i in range(config.num_layers)
|
||||
]
|
||||
self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.final_layer_norm = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
|
||||
|
||||
def __call__(self, x: mx.array):
|
||||
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
|
||||
for layer in self.layers:
|
||||
pos_bias = self.relative_attention_bias(
|
||||
self.block[0].layer[0].SelfAttention.relative_attention_bias,
|
||||
x.shape[1],
|
||||
x.shape[1])
|
||||
for layer in self.block:
|
||||
x = layer(x, mask=pos_bias)
|
||||
return self.ln(x)
|
||||
return self.final_layer_norm(x)
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
def __init__(self, config: T5Config, has_relative_attention_bias: bool):
|
||||
super().__init__()
|
||||
self.self_attention = MultiHeadAttention(config)
|
||||
self.cross_attention = MultiHeadAttention(config)
|
||||
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln3 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.dense = DenseActivation(config)
|
||||
self.layer = [
|
||||
LayerSelfAttention(config, has_relative_attention_bias),
|
||||
LayerCrossAttention(config, has_relative_attention_bias),
|
||||
LayerFF(config)
|
||||
]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -217,18 +267,12 @@ class TransformerDecoderLayer(nn.Module):
|
||||
memory_mask: mx.array,
|
||||
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
|
||||
):
|
||||
y = self.ln1(x)
|
||||
y, cache = self.self_attention(y, y, y, mask, cache)
|
||||
y, cache = self.layer[0](x, mask, cache)
|
||||
x = x + y
|
||||
|
||||
y = self.ln2(x)
|
||||
y, _ = self.cross_attention(y, memory, memory, memory_mask)
|
||||
y, _ = self.layer[1](x, memory, memory_mask)
|
||||
x = x + y
|
||||
|
||||
y = self.ln3(x)
|
||||
y = self.dense(y)
|
||||
y = self.layer[2](x)
|
||||
x = x + y
|
||||
|
||||
return x, cache
|
||||
|
||||
|
||||
@@ -236,51 +280,49 @@ class TransformerDecoder(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
n_layers = getattr(config, "num_decoder_layers", config.num_layers)
|
||||
self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
|
||||
self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.block = [
|
||||
TransformerDecoderLayer(config, i == 0) for i in range(n_layers)
|
||||
]
|
||||
self.final_layer_norm = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
|
||||
|
||||
def __call__(self, x, memory, mask, memory_mask, cache=None):
|
||||
if cache is not None:
|
||||
offset = cache[0][0].shape[3]
|
||||
offset = cache[0][0].shape[2]
|
||||
else:
|
||||
offset = 0
|
||||
cache = [None] * len(self.layers)
|
||||
cache = [None] * len(self.block)
|
||||
|
||||
T = offset + x.shape[1]
|
||||
pos_bias = self.relative_attention_bias(T, T, offset=offset)
|
||||
pos_bias = self.relative_attention_bias(
|
||||
self.block[0].layer[0].SelfAttention.relative_attention_bias,
|
||||
T,
|
||||
T,
|
||||
offset)
|
||||
if mask is not None:
|
||||
mask += pos_bias
|
||||
else:
|
||||
mask = pos_bias
|
||||
|
||||
for e, layer in enumerate(self.layers):
|
||||
for e, layer in enumerate(self.block):
|
||||
x, cache[e] = layer(x, memory, mask, memory_mask, cache=cache[e])
|
||||
x = self.ln(x)
|
||||
x = self.final_layer_norm(x)
|
||||
|
||||
return x, cache
|
||||
|
||||
|
||||
class OutputHead(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(self, inputs):
|
||||
return self.linear(inputs)
|
||||
|
||||
|
||||
class T5(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
||||
self.shared = nn.Embedding(config.vocab_size, config.d_model)
|
||||
self.encoder = TransformerEncoder(config)
|
||||
self.decoder = TransformerDecoder(config)
|
||||
self.tie_word_embeddings = config.tie_word_embeddings
|
||||
if not self.tie_word_embeddings:
|
||||
self.lm_head = OutputHead(config)
|
||||
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
||||
self.model_dim = config.d_model
|
||||
|
||||
def encode(self, inputs: mx.array):
|
||||
return self.encoder(self.wte(inputs))
|
||||
return self.encoder(self.shared(inputs))
|
||||
|
||||
def decode(
|
||||
self,
|
||||
@@ -288,7 +330,7 @@ class T5(nn.Module):
|
||||
memory: mx.array,
|
||||
cache=None,
|
||||
):
|
||||
inputs = self.wte(inputs)
|
||||
inputs = self.shared(inputs)
|
||||
T = inputs.shape[1]
|
||||
if T > 1:
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(T)
|
||||
@@ -303,7 +345,7 @@ class T5(nn.Module):
|
||||
y = self.lm_head(y)
|
||||
else:
|
||||
y *= self.model_dim**-0.5
|
||||
y = y @ self.wte.weight.T
|
||||
y = y @ self.shared.weight.T
|
||||
return y, cache
|
||||
|
||||
def __call__(
|
||||
@@ -314,6 +356,21 @@ class T5(nn.Module):
|
||||
return self.decode(decoder_inputs, self.encode(inputs))[0]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(config: T5Config, path: Path):
|
||||
weight_files = glob.glob(str(path / "*.safetensors"))
|
||||
if not weight_files:
|
||||
raise FileNotFoundError(f"No safetensors found in {path}")
|
||||
|
||||
weights = {}
|
||||
for wf in weight_files:
|
||||
weights.update(mx.load(wf))
|
||||
|
||||
model = T5(config)
|
||||
model.load_weights(list(weights.items()))
|
||||
return model
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
def __init__(self, config: T5Config):
|
||||
self._decoder_start_id = config.decoder_start_token_id
|
||||
@@ -363,15 +420,21 @@ def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float]
|
||||
yield y.squeeze()
|
||||
|
||||
|
||||
def load_model(model_name: str, dtype: str = "float16"):
|
||||
config = T5Config.from_pretrained(args.model)
|
||||
def load_model(path_or_hf_repo: str, dtype: str = "float16"):
|
||||
path = Path(path_or_hf_repo)
|
||||
if not path.exists():
|
||||
path = Path(
|
||||
snapshot_download(
|
||||
repo_id=path_or_hf_repo,
|
||||
allow_patterns=[
|
||||
"*.json",
|
||||
"*.safetensors",
|
||||
],
|
||||
)
|
||||
)
|
||||
config = T5Config.from_pretrained(path)
|
||||
dtype = getattr(mx, dtype)
|
||||
model = T5(config)
|
||||
file_name = model_name.replace("/", "-")
|
||||
weights = mx.load(f"{file_name}.npz")
|
||||
weights = tree_unflatten(list(weights.items()))
|
||||
weights = tree_map(lambda p: p.astype(dtype), weights)
|
||||
model.update(weights)
|
||||
model = T5.from_pretrained(config, path)
|
||||
mx.eval(model.parameters())
|
||||
return model, Tokenizer(config)
|
||||
|
||||
|
Reference in New Issue
Block a user