diff --git a/t5/convert.py b/t5/convert.py
index c9377b5e..54c7b76b 100644
--- a/t5/convert.py
+++ b/t5/convert.py
@@ -51,7 +51,7 @@ def convert(model_name, half_precision=False):
     weights = {replace_key(k): v.numpy() for k, v in model.state_dict().items()}
     if half_precision:
         weights = {k: v.astype(np.float16) for k, v in weights.items()}
-    np.savez("weights.npz", **weights)
+    np.savez(f"{model_name}.npz", **weights)


 if __name__ == "__main__":
@@ -59,7 +59,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert T5 weights to MLX")
     parser.add_argument(
-        "--model_name",
+        "--model",
         type=str,
         help="Name of the T5 model.",
         choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
@@ -71,4 +71,4 @@ if __name__ == "__main__":
         help="Convert weights to half precision (float16).",
     )
     args = parser.parse_args()
-    convert(args.model_name, args.half_precision)
+    convert(args.model, args.half_precision)
diff --git a/t5/t5.py b/t5/t5.py
index 6f53694d..4863ba43 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -1,5 +1,4 @@
 import argparse
-from dataclasses import dataclass
 from typing import Optional, Tuple, List
 from time import perf_counter_ns

@@ -7,25 +6,7 @@ import numpy as np
 import mlx.core as mx
 import mlx.nn as nn
 from mlx.utils import tree_flatten, tree_unflatten
-from transformers import T5Tokenizer
-
-
-@dataclass
-class ModelArgs:
-    d_ff: int = 2048
-    d_kv: int = 64
-    d_model: int = 512
-    dropout_rate: int = 0.1
-    layer_norm_epsilon: float = 1e-06
-    n_positions: int = 512
-    relative_attention_num_buckets: int = 32
-    relative_attention_max_distance: int = 128
-    num_heads: int = 8
-    num_layers: int = 6
-    decoder_start_token_id: int = 0
-    eos_token_id: int = 1
-    pad_token_id: int = 0
-    vocab_size: int = 32128
+from transformers import T5Config, T5Tokenizer


 def _relative_position_bucket(
@@ -110,7 +91,7 @@ class RelativePositionBias(nn.Module):


 class MultiHeadAttention(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config: T5Config):
         super().__init__()
         self.num_heads = config.num_heads
         self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False)
@@ -167,7 +148,7 @@ class RMSNorm(nn.Module):


 class TransformerEncoderLayer(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config: T5Config):
         super().__init__()
         mlp_dims = config.d_ff or config.d_model * 4
         self.attention = MultiHeadAttention(config)
@@ -189,7 +170,7 @@ class TransformerEncoderLayer(nn.Module):


 class TransformerEncoder(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config: T5Config):
         super().__init__()
         self.layers = [
             TransformerEncoderLayer(config) for i in range(config.num_layers)
@@ -205,7 +186,7 @@ class TransformerEncoder(nn.Module):


 class TransformerDecoderLayer(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config: T5Config):
         super().__init__()
         mlp_dims = config.d_ff or config.d_model * 4
         self.self_attention = MultiHeadAttention(config)
@@ -242,7 +223,7 @@ class TransformerDecoderLayer(nn.Module):


 class TransformerDecoder(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config: T5Config):
         super().__init__()
         self.layers = [
             TransformerDecoderLayer(config) for i in range(config.num_layers)
@@ -272,7 +253,7 @@ class TransformerDecoder(nn.Module):


 class OutputHead(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config: T5Config):
         self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)

     def __call__(self, inputs):
@@ -280,7 +261,7 @@ class OutputHead(nn.Module):


 class T5(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config: T5Config):
         self.wte = nn.Embedding(config.vocab_size, config.d_model)
         self.encoder = TransformerEncoder(config)
         self.decoder = TransformerDecoder(config)
@@ -334,9 +315,9 @@ def generate(
         yield y.squeeze()


-def load_model(model_config):
-    model = T5(model_config)
-    weights = mx.load("weights.npz")
+def load_model(model_name: str, config: T5Config):
+    model = T5(config)
+    weights = mx.load(f"{model_name}.npz")
     current_weights = tree_flatten(model.parameters())
     weights_to_load = list(weights.items())
     current_weights_dict = dict(current_weights)
@@ -353,12 +334,18 @@ def load_model(model_config):
             print("Loading shape: ", weights_to_load_dict[key].shape)
     model.update(tree_unflatten(weights_to_load))
     mx.eval(model.parameters())
-    tokenizer = T5Tokenizer.from_pretrained("t5-small", legacy=False)
-    return model, tokenizer
+    return model


 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="T5 Inference script")
+    parser.add_argument(
+        "--model",
+        type=str,
+        help="Name of the T5 model.",
+        choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
+        default="t5-small",
+    )
     parser.add_argument(
         "--prompt",
         help="",
@@ -388,8 +375,13 @@ if __name__ == "__main__":

     mx.random.seed(args.seed)

-    config = ModelArgs()
-    model, tokenizer = load_model(config)
+    config = T5Config.from_pretrained(args.model)
+    model = load_model(args.model, config)
+    tokenizer = T5Tokenizer.from_pretrained(
+        args.model,
+        legacy=False,
+        model_max_length=config.n_positions,
+    )
     prompt = tokenizer(
         args.prompt,
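Usage note (not part of the diff): after this change the converter writes per-model weight files, e.g. `python convert.py --model t5-base --half_precision` produces `t5-base.npz`, and inference picks the file up via the matching `--model` flag on `t5.py`. The sketch below is a minimal illustration of the resulting caller-side flow, under two assumptions that go beyond the diff: the weights file already exists in the working directory, and `t5.py` is importable as a module named `t5` (the diff itself only exercises this path through `__main__`).

    from transformers import T5Config, T5Tokenizer

    from t5 import load_model  # assumed import path for the t5.py above

    model_name = "t5-base"  # any name accepted by --model

    # The config now comes from Hugging Face rather than the removed
    # hard-coded ModelArgs dataclass, so all listed T5 sizes work unchanged.
    config = T5Config.from_pretrained(model_name)

    # load_model reads f"{model_name}.npz" and returns only the model;
    # the tokenizer is constructed by the caller after this change.
    model = load_model(model_name, config)
    tokenizer = T5Tokenizer.from_pretrained(
        model_name,
        legacy=False,
        model_max_length=config.n_positions,
    )

Splitting the tokenizer out of load_model keeps weight loading independent of tokenization, and capping model_max_length at config.n_positions silences the length warnings the tokenizer would otherwise emit for long prompts.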