Mirror of https://github.com/ml-explore/mlx-examples.git
Add support for byt5 models (#161)

* Add support for byt5 models
* Remove unused import

parent: 6c574dbecf
commit: 4c9db80ed2
t5/hf_t5.py

@@ -1,6 +1,6 @@
 import argparse

-from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration
+from transformers import AutoTokenizer, T5EncoderModel, AutoModelForSeq2SeqLM


 def embed(t5_model: str):
@@ -25,7 +25,7 @@ def embed(t5_model: str):
 def generate(t5_model: str):
     prompt = "translate English to German: As much as six inches of rain could fall in the New York City region through Monday morning, and officials warned of flooding along the coast."
     tokenizer = AutoTokenizer.from_pretrained(t5_model)
-    torch_model = T5ForConditionalGeneration.from_pretrained(t5_model)
+    torch_model = AutoModelForSeq2SeqLM.from_pretrained(t5_model)
     torch_tokens = tokenizer(prompt, return_tensors="pt", padding=True).input_ids
     outputs = torch_model.generate(torch_tokens, do_sample=False, max_length=512)
     print(tokenizer.decode(outputs[0], skip_special_tokens=True))
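Note: switching the Hugging Face reference script from T5ForConditionalGeneration to AutoModelForSeq2SeqLM lets the same code load any T5-family checkpoint, including ByT5, because the auto class resolves the concrete architecture from the checkpoint's config. A minimal sketch of what this enables (google/byt5-small is an illustrative checkpoint name, not taken from this diff):

    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    # AutoModelForSeq2SeqLM reads the architecture from the checkpoint
    # config, so one script now covers both t5-* and byt5-* models.
    model_name = "google/byt5-small"  # illustrative checkpoint, not from the diff
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    tokens = tokenizer("translate English to German: Hello.", return_tensors="pt").input_ids
    outputs = model.generate(tokens, do_sample=False, max_length=64)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))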
t5/t5.py (11 changed lines)
@@ -6,7 +6,7 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 from mlx.utils import tree_map, tree_unflatten
-from transformers import T5Config, T5Tokenizer
+from transformers import T5Config, AutoTokenizer


 def _relative_position_bucket(
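Note: the tokenizer import changes for the same reason. ByT5 checkpoints ship a byte-level ByT5Tokenizer rather than a SentencePiece model, so T5Tokenizer cannot load them, while AutoTokenizer dispatches to the right class. A rough sketch of the difference, again assuming google/byt5-small as an example checkpoint:

    from transformers import AutoTokenizer

    # For t5-* checkpoints this resolves to a T5 tokenizer; for byt5-* it
    # resolves to ByT5Tokenizer, which maps UTF-8 bytes (plus a small
    # offset for the special tokens) directly to ids.
    tok = AutoTokenizer.from_pretrained("google/byt5-small")  # example checkpoint
    print(tok("hi").input_ids)  # byte values offset by the special-token count, then EOS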
@@ -251,8 +251,9 @@ class TransformerDecoderLayer(nn.Module):
 class TransformerDecoder(nn.Module):
     def __init__(self, config: T5Config):
         super().__init__()
+        n_layers = getattr(config, "num_decoder_layers", config.num_layers)
         self.layers = [
-            TransformerDecoderLayer(config) for i in range(config.num_layers)
+            TransformerDecoderLayer(config) for i in range(n_layers)
         ]
         self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
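Note: this is the core architectural fix. ByT5 uses an asymmetric stack whose decoder is shallower than its encoder, so its config carries a separate num_decoder_layers; the getattr fallback keeps classic T5 configs working, where both stacks share num_layers. A small sketch of the values involved (the checkpoint name is illustrative):

    from transformers import T5Config

    config = T5Config.from_pretrained("google/byt5-small")  # example checkpoint
    # Classic T5 configs use num_layers for both stacks; ByT5 configs also
    # set num_decoder_layers because the decoder is much shallower.
    n_layers = getattr(config, "num_decoder_layers", config.num_layers)
    print(config.num_layers, n_layers)  # encoder depth vs. decoder depth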
@@ -332,9 +333,9 @@ class T5(nn.Module):


 class Tokenizer:
-    def __init__(self, model_name: str, config: T5Config):
+    def __init__(self, config: T5Config):
         self._decoder_start_id = config.decoder_start_token_id
-        self._tokenizer = T5Tokenizer.from_pretrained(
+        self._tokenizer = AutoTokenizer.from_pretrained(
             args.model,
             legacy=False,
             model_max_length=getattr(config, "n_positions", 512),
@@ -390,7 +391,7 @@ def load_model(model_name: str, dtype: str = "float16"):
     weights = tree_map(lambda p: p.astype(dtype), weights)
     model.update(weights)
     mx.eval(model.parameters())
-    return model, Tokenizer(args.model, config)
+    return model, Tokenizer(config)


 if __name__ == "__main__":
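Note: the return value changes to match. load_model now builds the Tokenizer from the config alone, dropping the unused model_name parameter; the wrapper picks up the checkpoint name from the script's parsed arguments (visible as args.model in the hunk above). A hedged sketch of how the pieces fit when the script runs, with everything except --model assumed rather than taken from the diff:

    import argparse

    # Hypothetical reconstruction of the script's entry point; only the
    # --model argument is confirmed by args.model in the diff above.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, help="T5 or ByT5 checkpoint name")
    args = parser.parse_args()

    # Tokenizer(config) resolves the tokenizer via AutoTokenizer using the
    # global args.model, so ByT5 checkpoints load without a SentencePiece file.
    model, tokenizer = load_model(args.model)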