mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Add support for byt5 models (#161)
* Add support for byt5 models
* Remove unused import
This commit is contained in:
parent
6c574dbecf
commit
4c9db80ed2
@ -1,6 +1,6 @@
|
||||
import argparse
|
||||
|
||||
from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration
|
||||
from transformers import AutoTokenizer, T5EncoderModel, AutoModelForSeq2SeqLM
|
||||
|
||||
|
||||
def embed(t5_model: str):
|
||||
@ -25,7 +25,7 @@ def embed(t5_model: str):
|
||||
def generate(t5_model: str):
|
||||
prompt = "translate English to German: As much as six inches of rain could fall in the New York City region through Monday morning, and officials warned of flooding along the coast."
|
||||
tokenizer = AutoTokenizer.from_pretrained(t5_model)
|
||||
torch_model = T5ForConditionalGeneration.from_pretrained(t5_model)
|
||||
torch_model = AutoModelForSeq2SeqLM.from_pretrained(t5_model)
|
||||
torch_tokens = tokenizer(prompt, return_tensors="pt", padding=True).input_ids
|
||||
outputs = torch_model.generate(torch_tokens, do_sample=False, max_length=512)
|
||||
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
|
11
t5/t5.py
11
t5/t5.py
@ -6,7 +6,7 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from mlx.utils import tree_map, tree_unflatten
|
||||
from transformers import T5Config, T5Tokenizer
|
||||
from transformers import T5Config, AutoTokenizer
|
||||
|
||||
|
||||
def _relative_position_bucket(
|
||||
@ -251,8 +251,9 @@ class TransformerDecoderLayer(nn.Module):
|
||||
class TransformerDecoder(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
super().__init__()
|
||||
n_layers = getattr(config, "num_decoder_layers", config.num_layers)
|
||||
self.layers = [
|
||||
TransformerDecoderLayer(config) for i in range(config.num_layers)
|
||||
TransformerDecoderLayer(config) for i in range(n_layers)
|
||||
]
|
||||
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
|
||||
@ -332,9 +333,9 @@ class T5(nn.Module):
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
def __init__(self, model_name: str, config: T5Config):
|
||||
def __init__(self, config: T5Config):
|
||||
self._decoder_start_id = config.decoder_start_token_id
|
||||
self._tokenizer = T5Tokenizer.from_pretrained(
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.model,
|
||||
legacy=False,
|
||||
model_max_length=getattr(config, "n_positions", 512),
|
||||
@ -390,7 +391,7 @@ def load_model(model_name: str, dtype: str = "float16"):
|
||||
weights = tree_map(lambda p: p.astype(dtype), weights)
|
||||
model.update(weights)
|
||||
mx.eval(model.parameters())
|
||||
return model, Tokenizer(args.model, config)
|
||||
return model, Tokenizer(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
Reference in New Issue
Block a user