Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-08-30 02:53:41 +08:00
Adds support for flan-t5

commit 930cd4d950
parent 688795c665
t5/convert.py

@@ -22,15 +22,13 @@ SHARED_REPLACEMENT_PATTERNS = [

 ENCODER_REPLACEMENT_PATTERNS = [
     (".layer.0.SelfAttention.", ".attention."),
-    (".layer.1.DenseReluDense.wi.", ".linear1."),
-    (".layer.1.DenseReluDense.wo.", ".linear2."),
+    (".layer.1.DenseReluDense.", ".dense."),
 ]

 DECODER_REPLACEMENT_PATTERNS = [
     (".layer.0.SelfAttention.", ".self_attention."),
     (".layer.1.EncDecAttention.", ".cross_attention."),
-    (".layer.2.DenseReluDense.wi.", ".linear1."),
-    (".layer.2.DenseReluDense.wo.", ".linear2."),
+    (".layer.2.DenseReluDense.", ".dense."),
 ]
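Note: the collapsed patterns work because the new DenseActivation module keeps the checkpoint's own leaf names (wi, wi_0, wi_1, wo), so only the module prefix needs rewriting. A minimal sketch, assuming replace_key does plain substring substitution over flat state_dict keys (its real body lives in the convert script, and the example key is illustrative):

ENCODER_REPLACEMENT_PATTERNS = [
    (".layer.0.SelfAttention.", ".attention."),
    (".layer.1.DenseReluDense.", ".dense."),
]

def replace_key(key: str) -> str:
    # Apply each (old, new) substring rewrite in order.
    for old, new in ENCODER_REPLACEMENT_PATTERNS:
        key = key.replace(old, new)
    return key

# A gated flan-t5 key routes straight into DenseActivation's wi_0:
print(replace_key("encoder.block.0.layer.1.DenseReluDense.wi_0.weight"))
# encoder.block.0.dense.wi_0.weight (block-prefix rewrites omitted here)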
@@ -52,7 +50,8 @@ def convert(model_name):
         replace_key(k): v.numpy().astype(np.float16)
         for k, v in model.state_dict().items()
     }
-    np.savez(f"{model_name}.npz", **weights)
+    file_name = model_name.replace("/", "-")
+    np.savez(f"{file_name}.npz", **weights)


 if __name__ == "__main__":
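The rename matters for the new checkpoints: flan-t5 names are Hugging Face ids containing a slash, which np.savez would otherwise treat as a directory component of the output path. A minimal sketch:

model_name = "google/flan-t5-small"
file_name = model_name.replace("/", "-")
print(f"{file_name}.npz")  # google-flan-t5-small.npz, a valid flat filename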
@@ -63,7 +62,19 @@ if __name__ == "__main__":
         "--model",
         type=str,
         help="Name of the T5 model.",
-        choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
+        choices=[
+            "t5-small",
+            "t5-base",
+            "t5-large",
+            "t5-3b",
+            "t5-11b",
+            "google/flan-t5-small",
+            "google/flan-t5-base",
+            "google/flan-t5-large",
+            "google/flan-t5-xl",
+            "google/flan-t5-xxl",
+            "google/flan-t5-ul2",
+        ],
         default="t5-small",
     )
     args = parser.parse_args()
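A self-contained stand-in for the expanded parser, just to show the HF-style ids pass argparse validation (flag and help text mirror the diff; the abbreviated choices list is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model",
    type=str,
    help="Name of the T5 model.",
    choices=["t5-small", "google/flan-t5-small", "google/flan-t5-base"],
    default="t5-small",
)
args = parser.parse_args(["--model", "google/flan-t5-base"])
print(args.model)  # google/flan-t5-base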
t5/t5.py (71 changed lines)
@@ -150,15 +150,44 @@ class RMSNorm(nn.Module):
         return self.weight * output


-class TransformerEncoderLayer(nn.Module):
+class DenseActivation(nn.Module):
     def __init__(self, config: T5Config):
         super().__init__()
         mlp_dims = config.d_ff or config.d_model * 4
+        self.gated = config.feed_forward_proj.startswith("gated")
+        if self.gated:
+            self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
+            self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
+        else:
+            self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
+        self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
+        activation = config.feed_forward_proj.removeprefix("gated-")
+        if activation == "relu":
+            self.act = nn.relu
+        elif activation == "gelu":
+            self.act = nn.gelu
+        elif activation == "silu":
+            self.act = nn.silu
+        else:
+            raise ValueError(f"Unknown activation: {activation}")
+
+    def __call__(self, x):
+        if self.gated:
+            hidden_act = self.act(self.wi_0(x))
+            hidden_linear = self.wi_1(x)
+            x = hidden_act * hidden_linear
+        else:
+            x = self.act(self.wi(x))
+        return self.wo(x)
+
+
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self, config: T5Config):
+        super().__init__()
         self.attention = MultiHeadAttention(config)
         self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
-        self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
+        self.dense = DenseActivation(config)

     def __call__(self, x, mask):
         y = self.ln1(x)
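For context: flan-t5 checkpoints set feed_forward_proj to "gated-gelu", so the feed-forward block computes wo(gelu(wi_0 x) * wi_1 x) instead of the original T5's wo(relu(wi x)). A minimal numpy sketch of the gated path (shapes and the tanh GeLU approximation are illustrative, not the mlx kernels):

import numpy as np

def gelu(x):
    # tanh approximation of GeLU, close enough for illustration
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

d_model, d_ff = 4, 8
rng = np.random.default_rng(0)
wi_0 = rng.normal(size=(d_model, d_ff))  # gate projection
wi_1 = rng.normal(size=(d_model, d_ff))  # linear projection
wo = rng.normal(size=(d_ff, d_model))    # output projection

x = rng.normal(size=(1, d_model))
y = (gelu(x @ wi_0) * (x @ wi_1)) @ wo   # gated-gelu feed-forward
print(y.shape)  # (1, 4)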
@@ -166,9 +195,7 @@ class TransformerEncoderLayer(nn.Module):
         x = x + y

         y = self.ln2(x)
-        y = self.linear1(y)
-        y = mx.maximum(y, 0)
-        y = self.linear2(y)
+        y = self.dense(y)
         return x + y
@@ -191,14 +218,12 @@ class TransformerEncoder(nn.Module):
 class TransformerDecoderLayer(nn.Module):
     def __init__(self, config: T5Config):
         super().__init__()
-        mlp_dims = config.d_ff or config.d_model * 4
         self.self_attention = MultiHeadAttention(config)
         self.cross_attention = MultiHeadAttention(config)
         self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
-        self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
+        self.dense = DenseActivation(config)

     def __call__(
         self,
@@ -217,9 +242,7 @@ class TransformerDecoderLayer(nn.Module):
         x = x + y

         y = self.ln3(x)
-        y = self.linear1(y)
-        y = mx.maximum(y, 0)
-        y = self.linear2(y)
+        y = self.dense(y)
         x = x + y

         return x, cache
@@ -268,8 +291,9 @@ class T5(nn.Module):
         self.wte = nn.Embedding(config.vocab_size, config.d_model)
         self.encoder = TransformerEncoder(config)
         self.decoder = TransformerDecoder(config)
-        self.lm_head = OutputHead(config)
         self.tie_word_embeddings = config.tie_word_embeddings
+        if not self.tie_word_embeddings:
+            self.lm_head = OutputHead(config)
         self.model_dim = config.d_model

     def encode(self, inputs: mx.array):
@@ -292,9 +316,12 @@ class T5(nn.Module):
         y, cache = self.decoder(
             inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
         )
-        if self.tie_word_embeddings:
-            y *= self.model_dim**-0.5
-        return self.lm_head(y), cache
+        if not self.tie_word_embeddings:
+            y = self.lm_head(y)
+        else:
+            y *= self.model_dim**-0.5
+            y = y @ self.wte.weight.T
+        return y, cache

     def __call__(
         self,
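The split mirrors the checkpoints: original t5 ties the output head to the input embedding and rescales the decoder output by d_model**-0.5 before the tied projection, while flan-t5 ships a separate lm_head and skips the scaling. A hedged numpy sketch of the tied path (shapes are illustrative):

import numpy as np

d_model, vocab_size = 4, 10
rng = np.random.default_rng(0)
wte = rng.normal(size=(vocab_size, d_model))  # shared embedding table
y = rng.normal(size=(1, d_model))             # final decoder hidden state

# tie_word_embeddings=True: scale, then project with the embedding matrix
logits = (y * d_model**-0.5) @ wte.T
print(logits.shape)  # (1, 10)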
@@ -371,7 +398,19 @@ if __name__ == "__main__":
         "--model",
         type=str,
         help="Name of the T5 model.",
-        choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
+        choices=[
+            "t5-small",
+            "t5-base",
+            "t5-large",
+            "t5-3b",
+            "t5-11b",
+            "google/flan-t5-small",
+            "google/flan-t5-base",
+            "google/flan-t5-large",
+            "google/flan-t5-xl",
+            "google/flan-t5-xxl",
+            "google/flan-t5-ul2",
+        ],
         default="t5-small",
     )
     parser.add_argument(