mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Load config from HF to support any model
This commit is contained in:
parent
b2a3782a96
commit
55f204dd3a
@ -51,7 +51,7 @@ def convert(model_name, half_precision=False):
|
|||||||
weights = {replace_key(k): v.numpy() for k, v in model.state_dict().items()}
|
weights = {replace_key(k): v.numpy() for k, v in model.state_dict().items()}
|
||||||
if half_precision:
|
if half_precision:
|
||||||
weights = {k: v.astype(np.float16) for k, v in weights.items()}
|
weights = {k: v.astype(np.float16) for k, v in weights.items()}
|
||||||
np.savez("weights.npz", **weights)
|
np.savez(f"{model_name}.npz", **weights)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -59,7 +59,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Convert T5 weights to MLX")
|
parser = argparse.ArgumentParser(description="Convert T5 weights to MLX")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_name",
|
"--model",
|
||||||
type=str,
|
type=str,
|
||||||
help="Name of the T5 model.",
|
help="Name of the T5 model.",
|
||||||
choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
|
choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
|
||||||
@ -71,4 +71,4 @@ if __name__ == "__main__":
|
|||||||
help="Convert weights to half precision (float16).",
|
help="Convert weights to half precision (float16).",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
convert(args.model_name, args.half_precision)
|
convert(args.model, args.half_precision)
|
||||||
|
60
t5/t5.py
60
t5/t5.py
@ -1,5 +1,4 @@
|
|||||||
import argparse
|
import argparse
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional, Tuple, List
|
from typing import Optional, Tuple, List
|
||||||
from time import perf_counter_ns
|
from time import perf_counter_ns
|
||||||
|
|
||||||
@ -7,25 +6,7 @@ import numpy as np
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from mlx.utils import tree_flatten, tree_unflatten
|
from mlx.utils import tree_flatten, tree_unflatten
|
||||||
from transformers import T5Tokenizer
|
from transformers import T5Config, T5Tokenizer
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelArgs:
|
|
||||||
d_ff: int = 2048
|
|
||||||
d_kv: int = 64
|
|
||||||
d_model: int = 512
|
|
||||||
dropout_rate: int = 0.1
|
|
||||||
layer_norm_epsilon: float = 1e-06
|
|
||||||
n_positions: int = 512
|
|
||||||
relative_attention_num_buckets: int = 32
|
|
||||||
relative_attention_max_distance: int = 128
|
|
||||||
num_heads: int = 8
|
|
||||||
num_layers: int = 6
|
|
||||||
decoder_start_token_id: int = 0
|
|
||||||
eos_token_id: int = 1
|
|
||||||
pad_token_id: int = 0
|
|
||||||
vocab_size: int = 32128
|
|
||||||
|
|
||||||
|
|
||||||
def _relative_position_bucket(
|
def _relative_position_bucket(
|
||||||
@ -110,7 +91,7 @@ class RelativePositionBias(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(nn.Module):
|
class MultiHeadAttention(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: T5Config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_heads = config.num_heads
|
self.num_heads = config.num_heads
|
||||||
self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
||||||
@ -167,7 +148,7 @@ class RMSNorm(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TransformerEncoderLayer(nn.Module):
|
class TransformerEncoderLayer(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: T5Config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
mlp_dims = config.d_ff or config.d_model * 4
|
mlp_dims = config.d_ff or config.d_model * 4
|
||||||
self.attention = MultiHeadAttention(config)
|
self.attention = MultiHeadAttention(config)
|
||||||
@ -189,7 +170,7 @@ class TransformerEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TransformerEncoder(nn.Module):
|
class TransformerEncoder(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: T5Config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layers = [
|
self.layers = [
|
||||||
TransformerEncoderLayer(config) for i in range(config.num_layers)
|
TransformerEncoderLayer(config) for i in range(config.num_layers)
|
||||||
@ -205,7 +186,7 @@ class TransformerEncoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TransformerDecoderLayer(nn.Module):
|
class TransformerDecoderLayer(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: T5Config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
mlp_dims = config.d_ff or config.d_model * 4
|
mlp_dims = config.d_ff or config.d_model * 4
|
||||||
self.self_attention = MultiHeadAttention(config)
|
self.self_attention = MultiHeadAttention(config)
|
||||||
@ -242,7 +223,7 @@ class TransformerDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TransformerDecoder(nn.Module):
|
class TransformerDecoder(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: T5Config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layers = [
|
self.layers = [
|
||||||
TransformerDecoderLayer(config) for i in range(config.num_layers)
|
TransformerDecoderLayer(config) for i in range(config.num_layers)
|
||||||
@ -272,7 +253,7 @@ class TransformerDecoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class OutputHead(nn.Module):
|
class OutputHead(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: T5Config):
|
||||||
self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
||||||
|
|
||||||
def __call__(self, inputs):
|
def __call__(self, inputs):
|
||||||
@ -280,7 +261,7 @@ class OutputHead(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class T5(nn.Module):
|
class T5(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: T5Config):
|
||||||
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
||||||
self.encoder = TransformerEncoder(config)
|
self.encoder = TransformerEncoder(config)
|
||||||
self.decoder = TransformerDecoder(config)
|
self.decoder = TransformerDecoder(config)
|
||||||
@ -334,9 +315,9 @@ def generate(
|
|||||||
yield y.squeeze()
|
yield y.squeeze()
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_config):
|
def load_model(model_name: str, config: T5Config):
|
||||||
model = T5(model_config)
|
model = T5(config)
|
||||||
weights = mx.load("weights.npz")
|
weights = mx.load(f"{model_name}.npz")
|
||||||
current_weights = tree_flatten(model.parameters())
|
current_weights = tree_flatten(model.parameters())
|
||||||
weights_to_load = list(weights.items())
|
weights_to_load = list(weights.items())
|
||||||
current_weights_dict = dict(current_weights)
|
current_weights_dict = dict(current_weights)
|
||||||
@ -353,12 +334,18 @@ def load_model(model_config):
|
|||||||
print("Loading shape: ", weights_to_load_dict[key].shape)
|
print("Loading shape: ", weights_to_load_dict[key].shape)
|
||||||
model.update(tree_unflatten(weights_to_load))
|
model.update(tree_unflatten(weights_to_load))
|
||||||
mx.eval(model.parameters())
|
mx.eval(model.parameters())
|
||||||
tokenizer = T5Tokenizer.from_pretrained("t5-small", legacy=False)
|
return model
|
||||||
return model, tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="T5 Inference script")
|
parser = argparse.ArgumentParser(description="T5 Inference script")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
help="Name of the T5 model.",
|
||||||
|
choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
|
||||||
|
default="t5-small",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prompt",
|
"--prompt",
|
||||||
help="",
|
help="",
|
||||||
@ -388,8 +375,13 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
mx.random.seed(args.seed)
|
mx.random.seed(args.seed)
|
||||||
|
|
||||||
config = ModelArgs()
|
config = T5Config.from_pretrained(args.model)
|
||||||
model, tokenizer = load_model(config)
|
model = load_model(args.model, config)
|
||||||
|
tokenizer = T5Tokenizer.from_pretrained(
|
||||||
|
args.model,
|
||||||
|
legacy=False,
|
||||||
|
model_max_length=config.n_positions,
|
||||||
|
)
|
||||||
|
|
||||||
prompt = tokenizer(
|
prompt = tokenizer(
|
||||||
args.prompt,
|
args.prompt,
|
||||||
|
Loading…
Reference in New Issue
Block a user