From 930cd4d9504ae152d74ad104d78bf61adce58039 Mon Sep 17 00:00:00 2001 From: Juarez Bochi Date: Mon, 18 Dec 2023 18:05:40 -0500 Subject: [PATCH] Adds support for flan-t5 --- t5/convert.py | 23 ++++++++++++----- t5/t5.py | 71 +++++++++++++++++++++++++++++++++++++++------------ 2 files changed, 72 insertions(+), 22 deletions(-) diff --git a/t5/convert.py b/t5/convert.py index 77d5cfd9..8e1d327d 100644 --- a/t5/convert.py +++ b/t5/convert.py @@ -22,15 +22,13 @@ SHARED_REPLACEMENT_PATTERNS = [ ENCODER_REPLACEMENT_PATTERNS = [ (".layer.0.SelfAttention.", ".attention."), - (".layer.1.DenseReluDense.wi.", ".linear1."), - (".layer.1.DenseReluDense.wo.", ".linear2."), + (".layer.1.DenseReluDense.", ".dense."), ] DECODER_REPLACEMENT_PATTERNS = [ (".layer.0.SelfAttention.", ".self_attention."), (".layer.1.EncDecAttention.", ".cross_attention."), - (".layer.2.DenseReluDense.wi.", ".linear1."), - (".layer.2.DenseReluDense.wo.", ".linear2."), + (".layer.2.DenseReluDense.", ".dense."), ] @@ -52,7 +50,8 @@ def convert(model_name): replace_key(k): v.numpy().astype(np.float16) for k, v in model.state_dict().items() } - np.savez(f"{model_name}.npz", **weights) + file_name = model_name.replace("/", "-") + np.savez(f"{file_name}.npz", **weights) if __name__ == "__main__": @@ -63,7 +62,19 @@ if __name__ == "__main__": "--model", type=str, help="Name of the T5 model.", - choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"], + choices=[ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + "google/flan-t5-small", + "google/flan-t5-base", + "google/flan-t5-large", + "google/flan-t5-xl", + "google/flan-t5-xxl", + "google/flan-t5-ul2", + ], default="t5-small", ) args = parser.parse_args() diff --git a/t5/t5.py b/t5/t5.py index 11476da4..d4b1c1db 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -150,15 +150,44 @@ class RMSNorm(nn.Module): return self.weight * output -class TransformerEncoderLayer(nn.Module): +class DenseActivation(nn.Module): def __init__(self, config: T5Config): super().__init__() mlp_dims = config.d_ff or config.d_model * 4 + self.gated = config.feed_forward_proj.startswith("gated") + if self.gated: + self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False) + self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False) + else: + self.wi = nn.Linear(config.d_model, mlp_dims, bias=False) + self.wo = nn.Linear(mlp_dims, config.d_model, bias=False) + activation = config.feed_forward_proj.removeprefix("gated-") + if activation == "relu": + self.act = nn.relu + elif activation == "gelu": + self.act = nn.gelu + elif activation == "silu": + self.act = nn.silu + else: + raise ValueError(f"Unknown activation: {activation}") + + def __call__(self, x): + if self.gated: + hidden_act = self.act(self.wi_0(x)) + hidden_linear = self.wi_1(x) + x = hidden_act * hidden_linear + else: + x = self.act(self.wi(x)) + return self.wo(x) + + +class TransformerEncoderLayer(nn.Module): + def __init__(self, config: T5Config): + super().__init__() self.attention = MultiHeadAttention(config) self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) - self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False) - self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False) + self.dense = DenseActivation(config) def __call__(self, x, mask): y = self.ln1(x) @@ -166,9 +195,7 @@ class TransformerEncoderLayer(nn.Module): x = x + y y = self.ln2(x) - y = self.linear1(y) - y = mx.maximum(y, 0) - y = self.linear2(y) + y = self.dense(y) return x + y @@ -191,14 +218,12 @@ class TransformerEncoder(nn.Module): class TransformerDecoderLayer(nn.Module): def __init__(self, config: T5Config): super().__init__() - mlp_dims = config.d_ff or config.d_model * 4 self.self_attention = MultiHeadAttention(config) self.cross_attention = MultiHeadAttention(config) self.ln1 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) self.ln2 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) self.ln3 = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) - self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False) - self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False) + self.dense = DenseActivation(config) def __call__( self, @@ -217,9 +242,7 @@ class TransformerDecoderLayer(nn.Module): x = x + y y = self.ln3(x) - y = self.linear1(y) - y = mx.maximum(y, 0) - y = self.linear2(y) + y = self.dense(y) x = x + y return x, cache @@ -268,8 +291,9 @@ class T5(nn.Module): self.wte = nn.Embedding(config.vocab_size, config.d_model) self.encoder = TransformerEncoder(config) self.decoder = TransformerDecoder(config) - self.lm_head = OutputHead(config) self.tie_word_embeddings = config.tie_word_embeddings + if not self.tie_word_embeddings: + self.lm_head = OutputHead(config) self.model_dim = config.d_model def encode(self, inputs: mx.array): @@ -292,9 +316,12 @@ class T5(nn.Module): y, cache = self.decoder( inputs, memory=memory, mask=mask, memory_mask=None, cache=cache ) - if self.tie_word_embeddings: + if not self.tie_word_embeddings: y *= self.model_dim**-0.5 - return self.lm_head(y), cache + y = self.lm_head(y) + else: + y = y @ self.wte.weight.T + return y, cache def __call__( self, @@ -371,7 +398,19 @@ if __name__ == "__main__": "--model", type=str, help="Name of the T5 model.", - choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"], + choices=[ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + "google/flan-t5-small", + "google/flan-t5-base", + "google/flan-t5-large", + "google/flan-t5-xl", + "google/flan-t5-xxl", + "google/flan-t5-ul2", + ], default="t5-small", ) parser.add_argument(