mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Adds support for flan-t5
This commit is contained in:
parent
688795c665
commit
930cd4d950
@ -22,15 +22,13 @@ SHARED_REPLACEMENT_PATTERNS = [
|
|||||||
|
|
||||||
ENCODER_REPLACEMENT_PATTERNS = [
|
ENCODER_REPLACEMENT_PATTERNS = [
|
||||||
(".layer.0.SelfAttention.", ".attention."),
|
(".layer.0.SelfAttention.", ".attention."),
|
||||||
(".layer.1.DenseReluDense.wi.", ".linear1."),
|
(".layer.1.DenseReluDense.", ".dense."),
|
||||||
(".layer.1.DenseReluDense.wo.", ".linear2."),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
DECODER_REPLACEMENT_PATTERNS = [
|
DECODER_REPLACEMENT_PATTERNS = [
|
||||||
(".layer.0.SelfAttention.", ".self_attention."),
|
(".layer.0.SelfAttention.", ".self_attention."),
|
||||||
(".layer.1.EncDecAttention.", ".cross_attention."),
|
(".layer.1.EncDecAttention.", ".cross_attention."),
|
||||||
(".layer.2.DenseReluDense.wi.", ".linear1."),
|
(".layer.2.DenseReluDense.", ".dense."),
|
||||||
(".layer.2.DenseReluDense.wo.", ".linear2."),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -52,7 +50,8 @@ def convert(model_name):
|
|||||||
replace_key(k): v.numpy().astype(np.float16)
|
replace_key(k): v.numpy().astype(np.float16)
|
||||||
for k, v in model.state_dict().items()
|
for k, v in model.state_dict().items()
|
||||||
}
|
}
|
||||||
np.savez(f"{model_name}.npz", **weights)
|
file_name = model_name.replace("/", "-")
|
||||||
|
np.savez(f"{file_name}.npz", **weights)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -63,7 +62,19 @@ if __name__ == "__main__":
|
|||||||
"--model",
|
"--model",
|
||||||
type=str,
|
type=str,
|
||||||
help="Name of the T5 model.",
|
help="Name of the T5 model.",
|
||||||
choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
|
choices=[
|
||||||
|
"t5-small",
|
||||||
|
"t5-base",
|
||||||
|
"t5-large",
|
||||||
|
"t5-3b",
|
||||||
|
"t5-11b",
|
||||||
|
"google/flan-t5-small",
|
||||||
|
"google/flan-t5-base",
|
||||||
|
"google/flan-t5-large",
|
||||||
|
"google/flan-t5-xl",
|
||||||
|
"google/flan-t5-xxl",
|
||||||
|
"google/flan-t5-ul2",
|
||||||
|
],
|
||||||
default="t5-small",
|
default="t5-small",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
71
t5/t5.py
71
t5/t5.py
@ -150,15 +150,44 @@ class RMSNorm(nn.Module):
|
|||||||
return self.weight * output
|
return self.weight * output
|
||||||
|
|
||||||
|
|
||||||
class TransformerEncoderLayer(nn.Module):
|
class DenseActivation(nn.Module):
|
||||||
def __init__(self, config: T5Config):
|
def __init__(self, config: T5Config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
mlp_dims = config.d_ff or config.d_model * 4
|
mlp_dims = config.d_ff or config.d_model * 4
|
||||||
|
self.gated = config.feed_forward_proj.startswith("gated")
|
||||||
|
if self.gated:
|
||||||
|
self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||||
|
self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||||
|
else:
|
||||||
|
self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||||
|
self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
|
||||||
|
activation = config.feed_forward_proj.removeprefix("gated-")
|
||||||
|
if activation == "relu":
|
||||||
|
self.act = nn.relu
|
||||||
|
elif activation == "gelu":
|
||||||
|
self.act = nn.gelu
|
||||||
|
elif activation == "silu":
|
||||||
|
self.act = nn.silu
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown activation: {activation}")
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
if self.gated:
|
||||||
|
hidden_act = self.act(self.wi_0(x))
|
||||||
|
hidden_linear = self.wi_1(x)
|
||||||
|
x = hidden_act * hidden_linear
|
||||||
|
else:
|
||||||
|
x = self.act(self.wi(x))
|
||||||
|
return self.wo(x)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoderLayer(nn.Module):
|
||||||
|
def __init__(self, config: T5Config):
|
||||||
|
super().__init__()
|
||||||
self.attention = MultiHeadAttention(config)
|
self.attention = MultiHeadAttention(config)
|
||||||
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
self.dense = DenseActivation(config)
|
||||||
self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
|
|
||||||
|
|
||||||
def __call__(self, x, mask):
|
def __call__(self, x, mask):
|
||||||
y = self.ln1(x)
|
y = self.ln1(x)
|
||||||
@ -166,9 +195,7 @@ class TransformerEncoderLayer(nn.Module):
|
|||||||
x = x + y
|
x = x + y
|
||||||
|
|
||||||
y = self.ln2(x)
|
y = self.ln2(x)
|
||||||
y = self.linear1(y)
|
y = self.dense(y)
|
||||||
y = mx.maximum(y, 0)
|
|
||||||
y = self.linear2(y)
|
|
||||||
return x + y
|
return x + y
|
||||||
|
|
||||||
|
|
||||||
@ -191,14 +218,12 @@ class TransformerEncoder(nn.Module):
|
|||||||
class TransformerDecoderLayer(nn.Module):
|
class TransformerDecoderLayer(nn.Module):
|
||||||
def __init__(self, config: T5Config):
|
def __init__(self, config: T5Config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
mlp_dims = config.d_ff or config.d_model * 4
|
|
||||||
self.self_attention = MultiHeadAttention(config)
|
self.self_attention = MultiHeadAttention(config)
|
||||||
self.cross_attention = MultiHeadAttention(config)
|
self.cross_attention = MultiHeadAttention(config)
|
||||||
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
self.dense = DenseActivation(config)
|
||||||
self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@ -217,9 +242,7 @@ class TransformerDecoderLayer(nn.Module):
|
|||||||
x = x + y
|
x = x + y
|
||||||
|
|
||||||
y = self.ln3(x)
|
y = self.ln3(x)
|
||||||
y = self.linear1(y)
|
y = self.dense(y)
|
||||||
y = mx.maximum(y, 0)
|
|
||||||
y = self.linear2(y)
|
|
||||||
x = x + y
|
x = x + y
|
||||||
|
|
||||||
return x, cache
|
return x, cache
|
||||||
@ -268,8 +291,9 @@ class T5(nn.Module):
|
|||||||
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
||||||
self.encoder = TransformerEncoder(config)
|
self.encoder = TransformerEncoder(config)
|
||||||
self.decoder = TransformerDecoder(config)
|
self.decoder = TransformerDecoder(config)
|
||||||
self.lm_head = OutputHead(config)
|
|
||||||
self.tie_word_embeddings = config.tie_word_embeddings
|
self.tie_word_embeddings = config.tie_word_embeddings
|
||||||
|
if not self.tie_word_embeddings:
|
||||||
|
self.lm_head = OutputHead(config)
|
||||||
self.model_dim = config.d_model
|
self.model_dim = config.d_model
|
||||||
|
|
||||||
def encode(self, inputs: mx.array):
|
def encode(self, inputs: mx.array):
|
||||||
@ -292,9 +316,12 @@ class T5(nn.Module):
|
|||||||
y, cache = self.decoder(
|
y, cache = self.decoder(
|
||||||
inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
|
inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
|
||||||
)
|
)
|
||||||
if self.tie_word_embeddings:
|
if not self.tie_word_embeddings:
|
||||||
y *= self.model_dim**-0.5
|
y *= self.model_dim**-0.5
|
||||||
return self.lm_head(y), cache
|
y = self.lm_head(y)
|
||||||
|
else:
|
||||||
|
y = y @ self.wte.weight.T
|
||||||
|
return y, cache
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@ -371,7 +398,19 @@ if __name__ == "__main__":
|
|||||||
"--model",
|
"--model",
|
||||||
type=str,
|
type=str,
|
||||||
help="Name of the T5 model.",
|
help="Name of the T5 model.",
|
||||||
choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
|
choices=[
|
||||||
|
"t5-small",
|
||||||
|
"t5-base",
|
||||||
|
"t5-large",
|
||||||
|
"t5-3b",
|
||||||
|
"t5-11b",
|
||||||
|
"google/flan-t5-small",
|
||||||
|
"google/flan-t5-base",
|
||||||
|
"google/flan-t5-large",
|
||||||
|
"google/flan-t5-xl",
|
||||||
|
"google/flan-t5-xxl",
|
||||||
|
"google/flan-t5-ul2",
|
||||||
|
],
|
||||||
default="t5-small",
|
default="t5-small",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
Loading…
Reference in New Issue
Block a user