Add support for byt5 models (#161)

* Add support for byt5 models

* Remove unused import
Juarez Bochi 2023-12-21 11:46:36 -05:00 committed by GitHub
parent 6c574dbecf
commit 4c9db80ed2
2 changed files with 8 additions and 7 deletions


@@ -1,6 +1,6 @@
 import argparse
-from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration
+from transformers import AutoTokenizer, T5EncoderModel, AutoModelForSeq2SeqLM
 def embed(t5_model: str):
@@ -25,7 +25,7 @@ def embed(t5_model: str):
 def generate(t5_model: str):
     prompt = "translate English to German: As much as six inches of rain could fall in the New York City region through Monday morning, and officials warned of flooding along the coast."
     tokenizer = AutoTokenizer.from_pretrained(t5_model)
-    torch_model = T5ForConditionalGeneration.from_pretrained(t5_model)
+    torch_model = AutoModelForSeq2SeqLM.from_pretrained(t5_model)
     torch_tokens = tokenizer(prompt, return_tensors="pt", padding=True).input_ids
     outputs = torch_model.generate(torch_tokens, do_sample=False, max_length=512)
     print(tokenizer.decode(outputs[0], skip_special_tokens=True))
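
For background (not stated in the diff): ByT5 checkpoints reuse the T5 architecture but ship a byte-level tokenizer instead of a SentencePiece vocabulary, so letting the Auto classes pick the concrete implementations keeps one script working for both families. A minimal sketch, with google/byt5-small as an illustrative checkpoint name (not one referenced in this commit):

    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    # Illustrative checkpoint only; a t5-* or byt5-* name works the same way.
    name = "google/byt5-small"

    tok = AutoTokenizer.from_pretrained(name)            # ByT5Tokenizer for byt5-*
    model = AutoModelForSeq2SeqLM.from_pretrained(name)  # T5ForConditionalGeneration

    # The tokenizer class is the part that actually differs between families.
    print(type(tok).__name__, type(model).__name__)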


@@ -6,7 +6,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 from mlx.utils import tree_map, tree_unflatten
-from transformers import T5Config, T5Tokenizer
+from transformers import T5Config, AutoTokenizer
 def _relative_position_bucket(
@@ -251,8 +251,9 @@ class TransformerDecoderLayer(nn.Module):
 class TransformerDecoder(nn.Module):
     def __init__(self, config: T5Config):
         super().__init__()
+        n_layers = getattr(config, "num_decoder_layers", config.num_layers)
         self.layers = [
-            TransformerDecoderLayer(config) for i in range(config.num_layers)
+            TransformerDecoderLayer(config) for i in range(n_layers)
         ]
         self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
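
A note on the fallback above (my reading, not stated in the commit): ByT5 configurations pair a deep encoder with a shallower decoder, so num_decoder_layers differs from num_layers, whereas classic T5 configs keep the two equal. A small sketch, again assuming google/byt5-small as an example checkpoint:

    from transformers import T5Config

    # Illustrative checkpoint; ByT5 uses more encoder than decoder layers.
    cfg = T5Config.from_pretrained("google/byt5-small")

    # Same fallback as in TransformerDecoder.__init__ above: prefer
    # num_decoder_layers when the config provides it.
    n_layers = getattr(cfg, "num_decoder_layers", cfg.num_layers)
    print(cfg.num_layers, n_layers)  # encoder depth vs. decoder depth
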
@@ -332,9 +333,9 @@ class T5(nn.Module):
 class Tokenizer:
-    def __init__(self, model_name: str, config: T5Config):
+    def __init__(self, config: T5Config):
         self._decoder_start_id = config.decoder_start_token_id
-        self._tokenizer = T5Tokenizer.from_pretrained(
+        self._tokenizer = AutoTokenizer.from_pretrained(
             args.model,
             legacy=False,
             model_max_length=getattr(config, "n_positions", 512),
@@ -390,7 +391,7 @@ def load_model(model_name: str, dtype: str = "float16"):
     weights = tree_map(lambda p: p.astype(dtype), weights)
     model.update(weights)
     mx.eval(model.parameters())
-    return model, Tokenizer(args.model, config)
+    return model, Tokenizer(config)
 if __name__ == "__main__":
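
As a closing illustration of the AutoTokenizer swap in the second file (not part of the commit): ByT5's tokenizer has no vocabulary file at all; it maps UTF-8 bytes directly to ids, offset past the special tokens, which is why T5Tokenizer, which expects a SentencePiece file, would not load these checkpoints. A rough sketch, with google/byt5-small again standing in for any ByT5 model:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("google/byt5-small")  # ByT5Tokenizer

    ids = tok("hi").input_ids
    # Byte values are shifted by 3 to make room for pad/eos/unk,
    # and an end-of-sequence token (id 1) is appended.
    print(ids)                                    # expected: [107, 108, 1]
    print([b + 3 for b in "hi".encode("utf-8")])  # [107, 108]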