* Add MusicGen model

* add benchmarks

* change to from_pretrained

* symlinks

* add readme and requirements

* fix readme

* readme
Author: Alex Barron
Date: 2024-10-11 10:16:20 -07:00
Committed by: GitHub
Parent: 4360e7ccec
Commit: d72fdeb4ee
19 changed files with 722 additions and 245 deletions

t5/t5.py (179 changed lines)

@@ -1,12 +1,45 @@
import argparse
import json
from pathlib import Path
from time import perf_counter_ns
from types import SimpleNamespace
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_map, tree_unflatten
from transformers import AutoTokenizer, T5Config
from transformers import AutoTokenizer
class Tokenizer:
def __init__(self, config, model_name):
self._decoder_start_id = config.decoder_start_token_id
self._tokenizer = AutoTokenizer.from_pretrained(
model_name,
legacy=False,
model_max_length=getattr(config, "n_positions", 512),
)
@property
def eos_id(self) -> int:
return self._tokenizer.eos_token_id
@property
def decoder_start_id(self) -> int:
return self._decoder_start_id
def encode(self, s: str) -> mx.array:
return mx.array(
self._tokenizer(
s,
return_tensors="np",
return_attention_mask=False,
)["input_ids"]
)
def decode(self, t: List[int], with_sep: bool = True) -> str:
tokens = self._tokenizer.convert_ids_to_tokens(t)
return "".join(t.replace("", " " if with_sep else "") for t in tokens)
def _relative_position_bucket(
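For context, the relocated Tokenizer now takes the parsed config plus a model name and wraps AutoTokenizer directly. A minimal usage sketch, assuming a T5 config.json is available locally and using "t5-small" purely as an illustrative checkpoint:

import json
from types import SimpleNamespace

# Hypothetical local path; any directory holding a T5 config.json works the same way.
with open("t5-small/config.json") as f:
    config = SimpleNamespace(**json.load(f))

tok = Tokenizer(config, "t5-small")
ids = tok.encode("translate English to German: That is good.")  # mx.array of shape (1, seq_len)
text = tok.decode(ids[0].tolist())  # joins tokens, mapping the SentencePiece "▁" marker back to spaces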
@@ -60,10 +93,10 @@ def _relative_position_bucket(
class RelativePositionBias(nn.Module):
def __init__(self, config: T5Config, bidirectional: bool):
def __init__(self, config, bidirectional: bool):
self.bidirectional = bidirectional
self.num_buckets = config.relative_attention_num_buckets
self.max_distance = config.relative_attention_max_distance
self.max_distance = getattr(config, "relative_attention_max_distance", 128)
self.n_heads = config.num_heads
self.embeddings = nn.Embedding(
config.relative_attention_num_buckets, config.num_heads
@@ -91,7 +124,7 @@ class RelativePositionBias(nn.Module):
class MultiHeadAttention(nn.Module):
def __init__(self, config: T5Config):
def __init__(self, config):
super().__init__()
inner_dim = config.d_kv * config.num_heads
self.num_heads = config.num_heads
@@ -135,17 +168,21 @@ class MultiHeadAttention(nn.Module):
class DenseActivation(nn.Module):
def __init__(self, config: T5Config):
def __init__(self, config):
super().__init__()
mlp_dims = config.d_ff or config.d_model * 4
self.gated = config.feed_forward_proj.startswith("gated")
self.gated = hasattr(config, "feed_forward_proj")
activation = (
"relu"
if not self.gated
else config.feed_forward_proj.removeprefix("gated-")
)
if self.gated:
self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
else:
self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
activation = config.feed_forward_proj.removeprefix("gated-")
if activation == "relu":
self.act = nn.relu
elif activation == "gelu":
@@ -166,7 +203,7 @@ class DenseActivation(nn.Module):
class TransformerEncoderLayer(nn.Module):
def __init__(self, config: T5Config):
def __init__(self, config):
super().__init__()
self.attention = MultiHeadAttention(config)
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
@@ -184,7 +221,7 @@ class TransformerEncoderLayer(nn.Module):
class TransformerEncoder(nn.Module):
def __init__(self, config: T5Config):
def __init__(self, config):
super().__init__()
self.layers = [
TransformerEncoderLayer(config) for i in range(config.num_layers)
@@ -200,7 +237,7 @@ class TransformerEncoder(nn.Module):
class TransformerDecoderLayer(nn.Module):
def __init__(self, config: T5Config):
def __init__(self, config):
super().__init__()
self.self_attention = MultiHeadAttention(config)
self.cross_attention = MultiHeadAttention(config)
@@ -233,7 +270,7 @@ class TransformerDecoderLayer(nn.Module):
class TransformerDecoder(nn.Module):
def __init__(self, config: T5Config):
def __init__(self, config):
super().__init__()
n_layers = getattr(config, "num_decoder_layers", config.num_layers)
self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
@@ -262,7 +299,7 @@ class TransformerDecoder(nn.Module):
class OutputHead(nn.Module):
def __init__(self, config: T5Config):
def __init__(self, config):
self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
def __call__(self, inputs):
@@ -270,11 +307,11 @@ class OutputHead(nn.Module):
class T5(nn.Module):
def __init__(self, config: T5Config):
def __init__(self, config):
self.wte = nn.Embedding(config.vocab_size, config.d_model)
self.encoder = TransformerEncoder(config)
self.decoder = TransformerDecoder(config)
self.tie_word_embeddings = config.tie_word_embeddings
self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
if not self.tie_word_embeddings:
self.lm_head = OutputHead(config)
self.model_dim = config.d_model
@@ -313,36 +350,82 @@ class T5(nn.Module):
):
return self.decode(decoder_inputs, self.encode(inputs))[0]
@classmethod
def sanitize(cls, weights):
shared_replacement_patterns = [
(".block.", ".layers."),
(".k.", ".key_proj."),
(".o.", ".out_proj."),
(".q.", ".query_proj."),
(".v.", ".value_proj."),
("shared.", "wte."),
("lm_head.", "lm_head.linear."),
(".layer.0.layer_norm.", ".ln1."),
(".layer.1.layer_norm.", ".ln2."),
(".layer.2.layer_norm.", ".ln3."),
(".final_layer_norm.", ".ln."),
(
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
"relative_attention_bias.embeddings.",
),
]
class Tokenizer:
def __init__(self, config: T5Config):
self._decoder_start_id = config.decoder_start_token_id
self._tokenizer = AutoTokenizer.from_pretrained(
args.model,
legacy=False,
model_max_length=getattr(config, "n_positions", 512),
)
encoder_replacement_patterns = [
(".layer.0.SelfAttention.", ".attention."),
(".layer.1.DenseReluDense.", ".dense."),
]
@property
def eos_id(self) -> int:
return self._tokenizer.eos_token_id
decoder_replacement_patterns = [
(".layer.0.SelfAttention.", ".self_attention."),
(".layer.1.EncDecAttention.", ".cross_attention."),
(".layer.2.DenseReluDense.", ".dense."),
]
@property
def decoder_start_id(self) -> int:
return self._decoder_start_id
ignored_keys = [
"decoder.layers.0.cross_attention.relative_attention_bias.weight"
]
def encode(self, s: str) -> mx.array:
return mx.array(
self._tokenizer(
s,
return_tensors="np",
return_attention_mask=False,
)["input_ids"]
)
def replace_key(key: str) -> str:
for old, new in shared_replacement_patterns:
key = key.replace(old, new)
if key.startswith("encoder."):
for old, new in encoder_replacement_patterns:
key = key.replace(old, new)
elif key.startswith("decoder."):
for old, new in decoder_replacement_patterns:
key = key.replace(old, new)
return key
def decode(self, t: List[int], with_sep: bool = True) -> str:
tokens = self._tokenizer.convert_ids_to_tokens(t)
return "".join(t.replace("", " " if with_sep else "") for t in tokens)
weights = {replace_key(k): v for k, v in weights.items()}
for key in ignored_keys:
if key in weights:
del weights[key]
return weights
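To make the key renaming concrete, here is one Hugging Face weight name traced through the patterns above (an illustrative check, not part of the diff):

import mlx.core as mx

# shared patterns: .block. -> .layers., .q. -> .query_proj.
# encoder pattern: .layer.0.SelfAttention. -> .attention.
renamed = T5.sanitize({"encoder.block.0.layer.0.SelfAttention.q.weight": mx.zeros((1,))})
assert "encoder.layers.0.attention.query_proj.weight" in renamed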
@classmethod
def from_pretrained(
cls, path_or_repo: str, dtype: mx.Dtype = mx.bfloat16
) -> tuple["T5", Tokenizer]:
from huggingface_hub import snapshot_download
path = Path(path_or_repo)
if not path.exists():
path = Path(
snapshot_download(
repo_id=path_or_repo,
allow_patterns=["*.json", "*.safetensors", "*.model"],
)
)
with open(path / "config.json", "r") as f:
config = SimpleNamespace(**json.load(f))
model = T5(config)
weights = mx.load(str(path / "model.safetensors"))
weights = cls.sanitize(weights)
weights = {k: v.astype(dtype) for k, v in weights.items()}
model.load_weights(list(weights.items()))
return model, Tokenizer(config, "t5-base")
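A minimal sketch of the new entry point; "t5-small" and the prompt are placeholders, and any repo id or local path providing config.json and model.safetensors should load the same way:

import mlx.core as mx

model, tokenizer = T5.from_pretrained("t5-small", dtype=mx.bfloat16)
inputs = tokenizer.encode("translate English to German: That is good.")
memory = model.encode(inputs)  # encoder pass; token-by-token decoding goes through generate() below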
def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float] = 0.0):
@@ -363,19 +446,6 @@ def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float]
yield y.squeeze()
def load_model(model_name: str, dtype: str = "float16"):
config = T5Config.from_pretrained(args.model)
dtype = getattr(mx, dtype)
model = T5(config)
file_name = model_name.replace("/", "-")
weights = mx.load(f"{file_name}.npz")
weights = tree_unflatten(list(weights.items()))
weights = tree_map(lambda p: p.astype(dtype), weights)
model.update(weights)
mx.eval(model.parameters())
return model, Tokenizer(config)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="T5 Inference script")
parser.add_argument(
@@ -421,7 +491,8 @@ if __name__ == "__main__":
mx.random.seed(args.seed)
model, tokenizer = load_model(args.model, args.dtype)
dtype = getattr(mx, args.dtype)
model, tokenizer = T5.from_pretrained(args.model, dtype)
if args.encode_only:
print("[INFO] Encoding with T5...", flush=True)