Adds support for flan-t5

This commit is contained in:
Juarez Bochi 2023-12-18 18:05:40 -05:00
parent 688795c665
commit 930cd4d950
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6
2 changed files with 72 additions and 22 deletions

View File

@ -22,15 +22,13 @@ SHARED_REPLACEMENT_PATTERNS = [
ENCODER_REPLACEMENT_PATTERNS = [
(".layer.0.SelfAttention.", ".attention."),
(".layer.1.DenseReluDense.wi.", ".linear1."),
(".layer.1.DenseReluDense.wo.", ".linear2."),
(".layer.1.DenseReluDense.", ".dense."),
]
DECODER_REPLACEMENT_PATTERNS = [
(".layer.0.SelfAttention.", ".self_attention."),
(".layer.1.EncDecAttention.", ".cross_attention."),
(".layer.2.DenseReluDense.wi.", ".linear1."),
(".layer.2.DenseReluDense.wo.", ".linear2."),
(".layer.2.DenseReluDense.", ".dense."),
]
@ -52,7 +50,8 @@ def convert(model_name):
replace_key(k): v.numpy().astype(np.float16)
for k, v in model.state_dict().items()
}
np.savez(f"{model_name}.npz", **weights)
file_name = model_name.replace("/", "-")
np.savez(f"{file_name}.npz", **weights)
if __name__ == "__main__":
@ -63,7 +62,19 @@ if __name__ == "__main__":
"--model",
type=str,
help="Name of the T5 model.",
choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
choices=[
"t5-small",
"t5-base",
"t5-large",
"t5-3b",
"t5-11b",
"google/flan-t5-small",
"google/flan-t5-base",
"google/flan-t5-large",
"google/flan-t5-xl",
"google/flan-t5-xxl",
"google/flan-t5-ul2",
],
default="t5-small",
)
args = parser.parse_args()

View File

@ -150,15 +150,44 @@ class RMSNorm(nn.Module):
return self.weight * output
class TransformerEncoderLayer(nn.Module):
class DenseActivation(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
mlp_dims = config.d_ff or config.d_model * 4
self.gated = config.feed_forward_proj.startswith("gated")
if self.gated:
self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
else:
self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
activation = config.feed_forward_proj.removeprefix("gated-")
if activation == "relu":
self.act = nn.relu
elif activation == "gelu":
self.act = nn.gelu
elif activation == "silu":
self.act = nn.silu
else:
raise ValueError(f"Unknown activation: {activation}")
def __call__(self, x):
if self.gated:
hidden_act = self.act(self.wi_0(x))
hidden_linear = self.wi_1(x)
x = hidden_act * hidden_linear
else:
x = self.act(self.wi(x))
return self.wo(x)
class TransformerEncoderLayer(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.attention = MultiHeadAttention(config)
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
self.dense = DenseActivation(config)
def __call__(self, x, mask):
y = self.ln1(x)
@ -166,9 +195,7 @@ class TransformerEncoderLayer(nn.Module):
x = x + y
y = self.ln2(x)
y = self.linear1(y)
y = mx.maximum(y, 0)
y = self.linear2(y)
y = self.dense(y)
return x + y
@ -191,14 +218,12 @@ class TransformerEncoder(nn.Module):
class TransformerDecoderLayer(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
mlp_dims = config.d_ff or config.d_model * 4
self.self_attention = MultiHeadAttention(config)
self.cross_attention = MultiHeadAttention(config)
self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
self.dense = DenseActivation(config)
def __call__(
self,
@ -217,9 +242,7 @@ class TransformerDecoderLayer(nn.Module):
x = x + y
y = self.ln3(x)
y = self.linear1(y)
y = mx.maximum(y, 0)
y = self.linear2(y)
y = self.dense(y)
x = x + y
return x, cache
@ -268,8 +291,9 @@ class T5(nn.Module):
self.wte = nn.Embedding(config.vocab_size, config.d_model)
self.encoder = TransformerEncoder(config)
self.decoder = TransformerDecoder(config)
self.lm_head = OutputHead(config)
self.tie_word_embeddings = config.tie_word_embeddings
if not self.tie_word_embeddings:
self.lm_head = OutputHead(config)
self.model_dim = config.d_model
def encode(self, inputs: mx.array):
@ -292,9 +316,12 @@ class T5(nn.Module):
y, cache = self.decoder(
inputs, memory=memory, mask=mask, memory_mask=None, cache=cache
)
if self.tie_word_embeddings:
if not self.tie_word_embeddings:
y *= self.model_dim**-0.5
return self.lm_head(y), cache
y = self.lm_head(y)
else:
y = y @ self.wte.weight.T
return y, cache
def __call__(
self,
@ -371,7 +398,19 @@ if __name__ == "__main__":
"--model",
type=str,
help="Name of the T5 model.",
choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
choices=[
"t5-small",
"t5-base",
"t5-large",
"t5-3b",
"t5-11b",
"google/flan-t5-small",
"google/flan-t5-base",
"google/flan-t5-large",
"google/flan-t5-xl",
"google/flan-t5-xxl",
"google/flan-t5-ul2",
],
default="t5-small",
)
parser.add_argument(