Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01 12:49:50 +08:00
MusicGen (#1020)
* Add MusicGen model
* add benchmarks
* change to from_pretrained
* symlinks
* add readme and requirements
* fix readme
* readme
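The switch to from_pretrained means the T5 example now downloads and converts Hugging Face weights on the fly instead of reading a pre-converted .npz file. A minimal usage sketch of the new API (the model name, prompt, and stopping logic here are illustrative, not taken from the commit):

    import mlx.core as mx
    from t5 import T5, generate  # assumes t5/t5.py is importable as a module

    model, tokenizer = T5.from_pretrained("t5-base", dtype=mx.bfloat16)
    tokens = []
    for token in generate("translate English to German: That is good.", model, tokenizer):
        if token.item() == tokenizer.eos_id:  # stop at end-of-sequence
            break
        tokens.append(token.item())
    print(tokenizer.decode(tokens))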
t5/t5.py (179 lines changed)
@@ -1,12 +1,45 @@
 import argparse
+import json
+from pathlib import Path
 from time import perf_counter_ns
+from types import SimpleNamespace
 from typing import List, Optional, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 from mlx.utils import tree_map, tree_unflatten
-from transformers import AutoTokenizer, T5Config
+from transformers import AutoTokenizer
 
 
+class Tokenizer:
+    def __init__(self, config, model_name):
+        self._decoder_start_id = config.decoder_start_token_id
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            legacy=False,
+            model_max_length=getattr(config, "n_positions", 512),
+        )
+
+    @property
+    def eos_id(self) -> int:
+        return self._tokenizer.eos_token_id
+
+    @property
+    def decoder_start_id(self) -> int:
+        return self._decoder_start_id
+
+    def encode(self, s: str) -> mx.array:
+        return mx.array(
+            self._tokenizer(
+                s,
+                return_tensors="np",
+                return_attention_mask=False,
+            )["input_ids"]
+        )
+
+    def decode(self, t: List[int], with_sep: bool = True) -> str:
+        tokens = self._tokenizer.convert_ids_to_tokens(t)
+        return "".join(t.replace("▁", " " if with_sep else "") for t in tokens)
+
+
 def _relative_position_bucket(
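The Tokenizer class added above wraps a Hugging Face SentencePiece tokenizer: encode returns a (1, sequence_length) mx.array of token ids, and decode reassembles word pieces by turning the "▁" word-boundary marker back into spaces. A small round-trip sketch (illustrative, not part of the diff):

    ids = tokenizer.encode("translate English to German: That is good.")
    pieces = [int(t) for t in ids[0]]  # flatten the (1, n) array to a list of ids
    print(tokenizer.decode(pieces))    # prints the prompt back, "▁" pieces joined with spaces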
@@ -60,10 +93,10 @@ def _relative_position_bucket(
 
 
 class RelativePositionBias(nn.Module):
-    def __init__(self, config: T5Config, bidirectional: bool):
+    def __init__(self, config, bidirectional: bool):
         self.bidirectional = bidirectional
         self.num_buckets = config.relative_attention_num_buckets
-        self.max_distance = config.relative_attention_max_distance
+        self.max_distance = getattr(config, "relative_attention_max_distance", 128)
         self.n_heads = config.num_heads
         self.embeddings = nn.Embedding(
             config.relative_attention_num_buckets, config.num_heads
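Note the new getattr fallback: configs loaded as a raw SimpleNamespace may omit relative_attention_max_distance, so the T5 default of 128 is used. For orientation, the surrounding _relative_position_bucket assigns half of its buckets to exact small offsets and spreads the rest logarithmically out to max_distance; a scalar sketch of that scheme (illustrative, not the vectorized code in this file):

    import math

    def bucket(relative_position: int, bidirectional: bool = True,
               num_buckets: int = 32, max_distance: int = 128) -> int:
        ret = 0
        if bidirectional:
            num_buckets //= 2
            if relative_position > 0:      # one half of the buckets per direction
                ret += num_buckets
            relative_position = abs(relative_position)
        else:
            relative_position = max(-relative_position, 0)
        max_exact = num_buckets // 2
        if relative_position < max_exact:  # small offsets: one bucket per position
            return ret + relative_position
        val = max_exact + int(             # larger offsets: log-spaced up to max_distance
            math.log(relative_position / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        )
        return ret + min(val, num_buckets - 1)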
@@ -91,7 +124,7 @@ class RelativePositionBias(nn.Module):
 
 
 class MultiHeadAttention(nn.Module):
-    def __init__(self, config: T5Config):
+    def __init__(self, config):
         super().__init__()
         inner_dim = config.d_kv * config.num_heads
         self.num_heads = config.num_heads
@@ -135,17 +168,21 @@ class MultiHeadAttention(nn.Module):
 
 
 class DenseActivation(nn.Module):
-    def __init__(self, config: T5Config):
+    def __init__(self, config):
         super().__init__()
         mlp_dims = config.d_ff or config.d_model * 4
-        self.gated = config.feed_forward_proj.startswith("gated")
+        self.gated = hasattr(config, "feed_forward_proj")
+        activation = (
+            "relu"
+            if not self.gated
+            else config.feed_forward_proj.removeprefix("gated-")
+        )
         if self.gated:
             self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
             self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
         else:
             self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
         self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
-        activation = config.feed_forward_proj.removeprefix("gated-")
         if activation == "relu":
             self.act = nn.relu
         elif activation == "gelu":
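With a raw SimpleNamespace config, gating is now inferred from whether feed_forward_proj is present at all, and the activation choice is hoisted above the layer definitions. For reference, the two branches differ in the forward pass roughly like this (a sketch consistent with the layers declared above, not the file's exact __call__):

    def __call__(self, x):
        if self.gated:
            # gated variant (T5 v1.1): activated branch gates a linear branch
            hidden = self.act(self.wi_0(x)) * self.wi_1(x)
        else:
            # original T5: single projection followed by the activation
            hidden = self.act(self.wi(x))
        return self.wo(hidden)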
@@ -166,7 +203,7 @@ class DenseActivation(nn.Module):
 
 
 class TransformerEncoderLayer(nn.Module):
-    def __init__(self, config: T5Config):
+    def __init__(self, config):
         super().__init__()
         self.attention = MultiHeadAttention(config)
         self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
@@ -184,7 +221,7 @@ class TransformerEncoderLayer(nn.Module):
 
 
 class TransformerEncoder(nn.Module):
-    def __init__(self, config: T5Config):
+    def __init__(self, config):
         super().__init__()
         self.layers = [
             TransformerEncoderLayer(config) for i in range(config.num_layers)
@@ -200,7 +237,7 @@ class TransformerEncoder(nn.Module):
 
 
 class TransformerDecoderLayer(nn.Module):
-    def __init__(self, config: T5Config):
+    def __init__(self, config):
         super().__init__()
         self.self_attention = MultiHeadAttention(config)
         self.cross_attention = MultiHeadAttention(config)
@@ -233,7 +270,7 @@ class TransformerDecoderLayer(nn.Module):
 
 
 class TransformerDecoder(nn.Module):
-    def __init__(self, config: T5Config):
+    def __init__(self, config):
         super().__init__()
         n_layers = getattr(config, "num_decoder_layers", config.num_layers)
         self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
@@ -262,7 +299,7 @@ class TransformerDecoder(nn.Module):
 
 
 class OutputHead(nn.Module):
-    def __init__(self, config: T5Config):
+    def __init__(self, config):
         self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
 
     def __call__(self, inputs):
@@ -270,11 +307,11 @@ class OutputHead(nn.Module):
 
 
 class T5(nn.Module):
-    def __init__(self, config: T5Config):
+    def __init__(self, config):
         self.wte = nn.Embedding(config.vocab_size, config.d_model)
         self.encoder = TransformerEncoder(config)
         self.decoder = TransformerDecoder(config)
-        self.tie_word_embeddings = config.tie_word_embeddings
+        self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
         if not self.tie_word_embeddings:
             self.lm_head = OutputHead(config)
         self.model_dim = config.d_model
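tie_word_embeddings now defaults to True when absent from config.json. When tied, T5 reuses the input embedding matrix as the output projection and rescales the decoder output by d_model ** -0.5 first; roughly (a sketch of the standard tied head, since the file's decode body is not shown in this hunk):

    if self.tie_word_embeddings:
        y = y * (self.model_dim ** -0.5)  # compensate for the shared embedding scale
        logits = y @ self.wte.weight.T
    else:
        logits = self.lm_head(y)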
@@ -313,36 +350,82 @@ class T5(nn.Module):
     ):
         return self.decode(decoder_inputs, self.encode(inputs))[0]
 
+    @classmethod
+    def sanitize(cls, weights):
+        shared_replacement_patterns = [
+            (".block.", ".layers."),
+            (".k.", ".key_proj."),
+            (".o.", ".out_proj."),
+            (".q.", ".query_proj."),
+            (".v.", ".value_proj."),
+            ("shared.", "wte."),
+            ("lm_head.", "lm_head.linear."),
+            (".layer.0.layer_norm.", ".ln1."),
+            (".layer.1.layer_norm.", ".ln2."),
+            (".layer.2.layer_norm.", ".ln3."),
+            (".final_layer_norm.", ".ln."),
+            (
+                "layers.0.layer.0.SelfAttention.relative_attention_bias.",
+                "relative_attention_bias.embeddings.",
+            ),
+        ]
-class Tokenizer:
-    def __init__(self, config: T5Config):
-        self._decoder_start_id = config.decoder_start_token_id
-        self._tokenizer = AutoTokenizer.from_pretrained(
-            args.model,
-            legacy=False,
-            model_max_length=getattr(config, "n_positions", 512),
-        )
+        encoder_replacement_patterns = [
+            (".layer.0.SelfAttention.", ".attention."),
+            (".layer.1.DenseReluDense.", ".dense."),
+        ]
 
-    @property
-    def eos_id(self) -> int:
-        return self._tokenizer.eos_token_id
+        decoder_replacement_patterns = [
+            (".layer.0.SelfAttention.", ".self_attention."),
+            (".layer.1.EncDecAttention.", ".cross_attention."),
+            (".layer.2.DenseReluDense.", ".dense."),
+        ]
 
-    @property
-    def decoder_start_id(self) -> int:
-        return self._decoder_start_id
+        ignored_keys = [
+            "decoder.layers.0.cross_attention.relative_attention_bias.weight"
+        ]
 
-    def encode(self, s: str) -> mx.array:
-        return mx.array(
-            self._tokenizer(
-                s,
-                return_tensors="np",
-                return_attention_mask=False,
-            )["input_ids"]
-        )
+        def replace_key(key: str) -> str:
+            for old, new in shared_replacement_patterns:
+                key = key.replace(old, new)
+            if key.startswith("encoder."):
+                for old, new in encoder_replacement_patterns:
+                    key = key.replace(old, new)
+            elif key.startswith("decoder."):
+                for old, new in decoder_replacement_patterns:
+                    key = key.replace(old, new)
+            return key
 
-    def decode(self, t: List[int], with_sep: bool = True) -> str:
-        tokens = self._tokenizer.convert_ids_to_tokens(t)
-        return "".join(t.replace("▁", " " if with_sep else "") for t in tokens)
+        weights = {replace_key(k): v for k, v in weights.items()}
+        for key in ignored_keys:
+            if key in weights:
+                del weights[key]
+        return weights
 
+    @classmethod
+    def from_pretrained(
+        cls, path_or_repo: str, dtype: mx.Dtype = mx.bfloat16
+    ) -> tuple["T5", Tokenizer]:
+        from huggingface_hub import snapshot_download
+
+        path = Path(path_or_repo)
+        if not path.exists():
+            path = Path(
+                snapshot_download(
+                    repo_id=path_or_repo,
+                    allow_patterns=["*.json", "*.safetensors", "*.model"],
+                )
+            )
+
+        with open(path / "config.json", "r") as f:
+            config = SimpleNamespace(**json.load(f))
+
+        model = T5(config)
+        weights = mx.load(str(path / "model.safetensors"))
+        weights = cls.sanitize(weights)
+        weights = {k: v.astype(dtype) for k, v in weights.items()}
+        model.load_weights(list(weights.items()))
+        return model, Tokenizer(config, "t5-base")
 
 
 def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float] = 0.0):
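sanitize replaces the old offline conversion step: it renames Hugging Face checkpoint keys to this module's attribute paths and drops the unused decoder cross-attention bias. For example, the inner replace_key maps (illustrative input):

    replace_key("encoder.block.0.layer.0.SelfAttention.q.weight")
    # -> "encoder.layers.0.attention.query_proj.weight"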
@@ -363,19 +446,6 @@ def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float]
         yield y.squeeze()
 
 
-def load_model(model_name: str, dtype: str = "float16"):
-    config = T5Config.from_pretrained(args.model)
-    dtype = getattr(mx, dtype)
-    model = T5(config)
-    file_name = model_name.replace("/", "-")
-    weights = mx.load(f"{file_name}.npz")
-    weights = tree_unflatten(list(weights.items()))
-    weights = tree_map(lambda p: p.astype(dtype), weights)
-    model.update(weights)
-    mx.eval(model.parameters())
-    return model, Tokenizer(config)
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="T5 Inference script")
     parser.add_argument(
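The body of the generate loop is elided by this hunk; it decodes one token per step, and temp=0.0 conventionally means greedy decoding. A sketch of the usual sampling step in mlx-examples (assumed, not shown in this diff):

    def sample(logits: mx.array, temp: float) -> mx.array:
        if temp == 0:
            return mx.argmax(logits, axis=-1)              # greedy
        return mx.random.categorical(logits * (1 / temp))  # temperature sampling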
@@ -421,7 +491,8 @@ if __name__ == "__main__":
 
     mx.random.seed(args.seed)
 
-    model, tokenizer = load_model(args.model, args.dtype)
+    dtype = getattr(mx, args.dtype)
+    model, tokenizer = T5.from_pretrained(args.model, dtype)
 
     if args.encode_only:
         print("[INFO] Encoding with T5...", flush=True)