diff --git a/t5/hf_t5.py b/t5/hf_t5.py index ce01e355..12329e4b 100644 --- a/t5/hf_t5.py +++ b/t5/hf_t5.py @@ -1,6 +1,6 @@ import argparse -from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration +from transformers import AutoTokenizer, T5EncoderModel, AutoModelForSeq2SeqLM def embed(t5_model: str): @@ -25,7 +25,7 @@ def embed(t5_model: str): def generate(t5_model: str): prompt = "translate English to German: As much as six inches of rain could fall in the New York City region through Monday morning, and officials warned of flooding along the coast." tokenizer = AutoTokenizer.from_pretrained(t5_model) - torch_model = T5ForConditionalGeneration.from_pretrained(t5_model) + torch_model = AutoModelForSeq2SeqLM.from_pretrained(t5_model) torch_tokens = tokenizer(prompt, return_tensors="pt", padding=True).input_ids outputs = torch_model.generate(torch_tokens, do_sample=False, max_length=512) print(tokenizer.decode(outputs[0], skip_special_tokens=True)) diff --git a/t5/t5.py b/t5/t5.py index 29ec4291..45855136 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -6,7 +6,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np from mlx.utils import tree_map, tree_unflatten -from transformers import T5Config, T5Tokenizer +from transformers import T5Config, AutoTokenizer def _relative_position_bucket( @@ -251,8 +251,9 @@ class TransformerDecoderLayer(nn.Module): class TransformerDecoder(nn.Module): def __init__(self, config: T5Config): super().__init__() + n_layers = getattr(config, "num_decoder_layers", config.num_layers) self.layers = [ - TransformerDecoderLayer(config) for i in range(config.num_layers) + TransformerDecoderLayer(config) for i in range(n_layers) ] self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) self.relative_attention_bias = RelativePositionBias(config, bidirectional=False) @@ -332,9 +333,9 @@ class T5(nn.Module): class Tokenizer: - def __init__(self, model_name: str, config: T5Config): + def __init__(self, config: T5Config): self._decoder_start_id = config.decoder_start_token_id - self._tokenizer = T5Tokenizer.from_pretrained( + self._tokenizer = AutoTokenizer.from_pretrained( args.model, legacy=False, model_max_length=getattr(config, "n_positions", 512), @@ -390,7 +391,7 @@ def load_model(model_name: str, dtype: str = "float16"): weights = tree_map(lambda p: p.astype(dtype), weights) model.update(weights) mx.eval(model.parameters()) - return model, Tokenizer(args.model, config) + return model, Tokenizer(config) if __name__ == "__main__":