Load config from HF to support any model

This commit is contained in:
Juarez Bochi 2023-12-18 08:42:06 -05:00
parent b2a3782a96
commit 55f204dd3a
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6
2 changed files with 29 additions and 37 deletions

View File

@ -51,7 +51,7 @@ def convert(model_name, half_precision=False):
weights = {replace_key(k): v.numpy() for k, v in model.state_dict().items()} weights = {replace_key(k): v.numpy() for k, v in model.state_dict().items()}
if half_precision: if half_precision:
weights = {k: v.astype(np.float16) for k, v in weights.items()} weights = {k: v.astype(np.float16) for k, v in weights.items()}
np.savez("weights.npz", **weights) np.savez(f"{model_name}.npz", **weights)
if __name__ == "__main__": if __name__ == "__main__":
@ -59,7 +59,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert T5 weights to MLX") parser = argparse.ArgumentParser(description="Convert T5 weights to MLX")
parser.add_argument( parser.add_argument(
"--model_name", "--model",
type=str, type=str,
help="Name of the T5 model.", help="Name of the T5 model.",
choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"], choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
@ -71,4 +71,4 @@ if __name__ == "__main__":
help="Convert weights to half precision (float16).", help="Convert weights to half precision (float16).",
) )
args = parser.parse_args() args = parser.parse_args()
convert(args.model_name, args.half_precision) convert(args.model, args.half_precision)

View File

@ -1,5 +1,4 @@
import argparse import argparse
from dataclasses import dataclass
from typing import Optional, Tuple, List from typing import Optional, Tuple, List
from time import perf_counter_ns from time import perf_counter_ns
@ -7,25 +6,7 @@ import numpy as np
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten from mlx.utils import tree_flatten, tree_unflatten
from transformers import T5Tokenizer from transformers import T5Config, T5Tokenizer
@dataclass
class ModelArgs:
d_ff: int = 2048
d_kv: int = 64
d_model: int = 512
dropout_rate: int = 0.1
layer_norm_epsilon: float = 1e-06
n_positions: int = 512
relative_attention_num_buckets: int = 32
relative_attention_max_distance: int = 128
num_heads: int = 8
num_layers: int = 6
decoder_start_token_id: int = 0
eos_token_id: int = 1
pad_token_id: int = 0
vocab_size: int = 32128
def _relative_position_bucket( def _relative_position_bucket(
@ -110,7 +91,7 @@ class RelativePositionBias(nn.Module):
class MultiHeadAttention(nn.Module): class MultiHeadAttention(nn.Module):
def __init__(self, config: ModelArgs): def __init__(self, config: T5Config):
super().__init__() super().__init__()
self.num_heads = config.num_heads self.num_heads = config.num_heads
self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False) self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False)
@ -167,7 +148,7 @@ class RMSNorm(nn.Module):
class TransformerEncoderLayer(nn.Module): class TransformerEncoderLayer(nn.Module):
def __init__(self, config: ModelArgs): def __init__(self, config: T5Config):
super().__init__() super().__init__()
mlp_dims = config.d_ff or config.d_model * 4 mlp_dims = config.d_ff or config.d_model * 4
self.attention = MultiHeadAttention(config) self.attention = MultiHeadAttention(config)
@ -189,7 +170,7 @@ class TransformerEncoderLayer(nn.Module):
class TransformerEncoder(nn.Module): class TransformerEncoder(nn.Module):
def __init__(self, config: ModelArgs): def __init__(self, config: T5Config):
super().__init__() super().__init__()
self.layers = [ self.layers = [
TransformerEncoderLayer(config) for i in range(config.num_layers) TransformerEncoderLayer(config) for i in range(config.num_layers)
@ -205,7 +186,7 @@ class TransformerEncoder(nn.Module):
class TransformerDecoderLayer(nn.Module): class TransformerDecoderLayer(nn.Module):
def __init__(self, config: ModelArgs): def __init__(self, config: T5Config):
super().__init__() super().__init__()
mlp_dims = config.d_ff or config.d_model * 4 mlp_dims = config.d_ff or config.d_model * 4
self.self_attention = MultiHeadAttention(config) self.self_attention = MultiHeadAttention(config)
@ -242,7 +223,7 @@ class TransformerDecoderLayer(nn.Module):
class TransformerDecoder(nn.Module): class TransformerDecoder(nn.Module):
def __init__(self, config: ModelArgs): def __init__(self, config: T5Config):
super().__init__() super().__init__()
self.layers = [ self.layers = [
TransformerDecoderLayer(config) for i in range(config.num_layers) TransformerDecoderLayer(config) for i in range(config.num_layers)
@ -272,7 +253,7 @@ class TransformerDecoder(nn.Module):
class OutputHead(nn.Module): class OutputHead(nn.Module):
def __init__(self, config: ModelArgs): def __init__(self, config: T5Config):
self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False) self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
def __call__(self, inputs): def __call__(self, inputs):
@ -280,7 +261,7 @@ class OutputHead(nn.Module):
class T5(nn.Module): class T5(nn.Module):
def __init__(self, config: ModelArgs): def __init__(self, config: T5Config):
self.wte = nn.Embedding(config.vocab_size, config.d_model) self.wte = nn.Embedding(config.vocab_size, config.d_model)
self.encoder = TransformerEncoder(config) self.encoder = TransformerEncoder(config)
self.decoder = TransformerDecoder(config) self.decoder = TransformerDecoder(config)
@ -334,9 +315,9 @@ def generate(
yield y.squeeze() yield y.squeeze()
def load_model(model_config): def load_model(model_name: str, config: T5Config):
model = T5(model_config) model = T5(config)
weights = mx.load("weights.npz") weights = mx.load(f"{model_name}.npz")
current_weights = tree_flatten(model.parameters()) current_weights = tree_flatten(model.parameters())
weights_to_load = list(weights.items()) weights_to_load = list(weights.items())
current_weights_dict = dict(current_weights) current_weights_dict = dict(current_weights)
@ -353,12 +334,18 @@ def load_model(model_config):
print("Loading shape: ", weights_to_load_dict[key].shape) print("Loading shape: ", weights_to_load_dict[key].shape)
model.update(tree_unflatten(weights_to_load)) model.update(tree_unflatten(weights_to_load))
mx.eval(model.parameters()) mx.eval(model.parameters())
tokenizer = T5Tokenizer.from_pretrained("t5-small", legacy=False) return model
return model, tokenizer
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="T5 Inference script") parser = argparse.ArgumentParser(description="T5 Inference script")
parser.add_argument(
"--model",
type=str,
help="Name of the T5 model.",
choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
default="t5-small",
)
parser.add_argument( parser.add_argument(
"--prompt", "--prompt",
help="", help="",
@ -388,8 +375,13 @@ if __name__ == "__main__":
mx.random.seed(args.seed) mx.random.seed(args.seed)
config = ModelArgs() config = T5Config.from_pretrained(args.model)
model, tokenizer = load_model(config) model = load_model(args.model, config)
tokenizer = T5Tokenizer.from_pretrained(
args.model,
legacy=False,
model_max_length=config.n_positions,
)
prompt = tokenizer( prompt = tokenizer(
args.prompt, args.prompt,