mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Load config from HF to support any model
This commit is contained in:
parent
b2a3782a96
commit
55f204dd3a
@ -51,7 +51,7 @@ def convert(model_name, half_precision=False):
|
||||
weights = {replace_key(k): v.numpy() for k, v in model.state_dict().items()}
|
||||
if half_precision:
|
||||
weights = {k: v.astype(np.float16) for k, v in weights.items()}
|
||||
np.savez("weights.npz", **weights)
|
||||
np.savez(f"{model_name}.npz", **weights)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -59,7 +59,7 @@ if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser(description="Convert T5 weights to MLX")
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
"--model",
|
||||
type=str,
|
||||
help="Name of the T5 model.",
|
||||
choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
|
||||
@ -71,4 +71,4 @@ if __name__ == "__main__":
|
||||
help="Convert weights to half precision (float16).",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert(args.model_name, args.half_precision)
|
||||
convert(args.model, args.half_precision)
|
||||
|
60
t5/t5.py
60
t5/t5.py
@ -1,5 +1,4 @@
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, List
|
||||
from time import perf_counter_ns
|
||||
|
||||
@ -7,25 +6,7 @@ import numpy as np
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
from transformers import T5Tokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
d_ff: int = 2048
|
||||
d_kv: int = 64
|
||||
d_model: int = 512
|
||||
dropout_rate: int = 0.1
|
||||
layer_norm_epsilon: float = 1e-06
|
||||
n_positions: int = 512
|
||||
relative_attention_num_buckets: int = 32
|
||||
relative_attention_max_distance: int = 128
|
||||
num_heads: int = 8
|
||||
num_layers: int = 6
|
||||
decoder_start_token_id: int = 0
|
||||
eos_token_id: int = 1
|
||||
pad_token_id: int = 0
|
||||
vocab_size: int = 32128
|
||||
from transformers import T5Config, T5Tokenizer
|
||||
|
||||
|
||||
def _relative_position_bucket(
|
||||
@ -110,7 +91,7 @@ class RelativePositionBias(nn.Module):
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
self.num_heads = config.num_heads
|
||||
self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
||||
@ -167,7 +148,7 @@ class RMSNorm(nn.Module):
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
mlp_dims = config.d_ff or config.d_model * 4
|
||||
self.attention = MultiHeadAttention(config)
|
||||
@ -189,7 +170,7 @@ class TransformerEncoderLayer(nn.Module):
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
self.layers = [
|
||||
TransformerEncoderLayer(config) for i in range(config.num_layers)
|
||||
@ -205,7 +186,7 @@ class TransformerEncoder(nn.Module):
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
mlp_dims = config.d_ff or config.d_model * 4
|
||||
self.self_attention = MultiHeadAttention(config)
|
||||
@ -242,7 +223,7 @@ class TransformerDecoderLayer(nn.Module):
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
self.layers = [
|
||||
TransformerDecoderLayer(config) for i in range(config.num_layers)
|
||||
@ -272,7 +253,7 @@ class TransformerDecoder(nn.Module):
|
||||
|
||||
|
||||
class OutputHead(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
def __init__(self, config: T5Config):
|
||||
self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(self, inputs):
|
||||
@ -280,7 +261,7 @@ class OutputHead(nn.Module):
|
||||
|
||||
|
||||
class T5(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
def __init__(self, config: T5Config):
|
||||
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
||||
self.encoder = TransformerEncoder(config)
|
||||
self.decoder = TransformerDecoder(config)
|
||||
@ -334,9 +315,9 @@ def generate(
|
||||
yield y.squeeze()
|
||||
|
||||
|
||||
def load_model(model_config):
|
||||
model = T5(model_config)
|
||||
weights = mx.load("weights.npz")
|
||||
def load_model(model_name: str, config: T5Config):
|
||||
model = T5(config)
|
||||
weights = mx.load(f"{model_name}.npz")
|
||||
current_weights = tree_flatten(model.parameters())
|
||||
weights_to_load = list(weights.items())
|
||||
current_weights_dict = dict(current_weights)
|
||||
@ -353,12 +334,18 @@ def load_model(model_config):
|
||||
print("Loading shape: ", weights_to_load_dict[key].shape)
|
||||
model.update(tree_unflatten(weights_to_load))
|
||||
mx.eval(model.parameters())
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small", legacy=False)
|
||||
return model, tokenizer
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="T5 Inference script")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
help="Name of the T5 model.",
|
||||
choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
|
||||
default="t5-small",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
help="",
|
||||
@ -388,8 +375,13 @@ if __name__ == "__main__":
|
||||
|
||||
mx.random.seed(args.seed)
|
||||
|
||||
config = ModelArgs()
|
||||
model, tokenizer = load_model(config)
|
||||
config = T5Config.from_pretrained(args.model)
|
||||
model = load_model(args.model, config)
|
||||
tokenizer = T5Tokenizer.from_pretrained(
|
||||
args.model,
|
||||
legacy=False,
|
||||
model_max_length=config.n_positions,
|
||||
)
|
||||
|
||||
prompt = tokenizer(
|
||||
args.prompt,
|
||||
|
Loading…
Reference in New Issue
Block a user