diff --git a/t5/convert.py b/t5/convert.py
index 11acc638..589ec961 100644
--- a/t5/convert.py
+++ b/t5/convert.py
@@ -1,23 +1,49 @@
 from transformers import T5ForConditionalGeneration
 import numpy as np
+
+SHARED_REPLACEMENT_PATTERNS = [
+    (".block.", ".layers."),
+    (".k.", ".key_proj."),
+    (".o.", ".out_proj."),
+    (".q.", ".query_proj."),
+    (".v.", ".value_proj."),
+    ("shared.", "wte."),
+    (".layer.0.layer_norm.", ".ln1."),
+    (".layer.1.layer_norm.", ".ln2."),
+    (".layer.2.layer_norm.", ".ln3."),
+    (".final_layer_norm.", ".ln."),
+    (
+        ".layers.0.layer.0.SelfAttention.relative_attention_bias.",
+        ".position_bias.relative_attention_bias."
+    ),
+]
+
+ENCODER_REPLACEMENT_PATTERNS = [
+    (".layer.0.SelfAttention.", ".attention."),
+    (".layer.1.DenseReluDense.wi.", ".linear1."),
+    (".layer.1.DenseReluDense.wo.", ".linear2."),
+]
+
+DECODER_REPLACEMENT_PATTERNS = [
+    (".layer.0.SelfAttention.", ".self_attention."),
+    (".layer.1.EncDecAttention.", ".cross_attention."),
+    (".layer.2.DenseReluDense.wi.", ".linear1."),
+    (".layer.2.DenseReluDense.wo.", ".linear2."),
+]
+
 
 def replace_key(key: str) -> str:
-    key = key.replace(".block.", ".layers.")
-    key = key.replace(".layer.0.SelfAttention.", ".attention.")
-    key = key.replace(".k.", ".key_proj.")
-    key = key.replace(".o.", ".out_proj.")
-    key = key.replace(".q.", ".query_proj.")
-    key = key.replace(".v.", ".value_proj.")
-    key = key.replace(".layer.0.layer_norm.", ".ln1.")
-    key = key.replace(".layer.1.layer_norm.", ".ln2.")
-    key = key.replace(".layer.1.DenseReluDense.wi.", ".linear1.")
-    key = key.replace(".layer.1.DenseReluDense.wo.", ".linear2.")
-    key = key.replace(".final_layer_norm.", ".ln.")
-    key = key.replace("shared.", "wte.")
-    key = key.replace("encoder.layers.0.attention.relative_attention_bias.",
-                      "position_bias.relative_attention_bias.")
+    for old, new in SHARED_REPLACEMENT_PATTERNS:
+        key = key.replace(old, new)
+    if key.startswith("encoder."):
+        for old, new in ENCODER_REPLACEMENT_PATTERNS:
+            key = key.replace(old, new)
+    elif key.startswith("decoder."):
+        for old, new in DECODER_REPLACEMENT_PATTERNS:
+            key = key.replace(old, new)
     return key
+
 
 def convert():
     model = T5ForConditionalGeneration.from_pretrained(
         "t5-small", torch_dtype="auto"
diff --git a/t5/t5.py b/t5/t5.py
index 172f6116..cc9895ff 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -181,7 +181,7 @@ class LayerNorm(nn.Module):
 
 
 class TransformerEncoderLayer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: ModelArgs):
         super().__init__()
         mlp_dims = config.d_ff or config.d_model * 4
         self.attention = MultiHeadAttention(config.d_model, config.num_heads)
@@ -212,6 +212,7 @@ class TransformerEncoder(nn.Module):
             for _ in range(config.num_layers)
         ]
         self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.position_bias = RelativePositionBias(config)
 
     def __call__(self, x, mask):
         for layer in self.layers:
@@ -221,12 +222,59 @@ class TransformerEncoder(nn.Module):
         return x
 
 
+class TransformerDecoderLayer(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        mlp_dims = config.d_ff or config.d_model * 4
+        self.self_attention = MultiHeadAttention(config.d_model, config.num_heads)
+        self.cross_attention = MultiHeadAttention(config.d_model, config.num_heads)
+        self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln3 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
+        self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)
+
+    def __call__(self, x, memory, x_mask, memory_mask):
+        y = self.ln1(x)
+        y = self.self_attention(y, y, y, x_mask)
+        x = x + y
+
+        y = self.ln2(x)
+        y = self.cross_attention(y, memory, memory, memory_mask)
+        x = x + y
+
+        y = self.ln3(x)
+        y = self.linear1(y)
+        y = mx.maximum(y, 0)
+        y = self.linear2(y)
+        x = x + y
+
+        return x
+
+
+class TransformerDecoder(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.layers = [
+            TransformerDecoderLayer(config)
+            for _ in range(config.num_layers)
+        ]
+        self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.position_bias = RelativePositionBias(config)
+
+    def __call__(self, x, memory, x_mask, memory_mask):
+        for layer in self.layers:
+            x = layer(x, memory, x_mask, memory_mask)
+        x = self.ln(x)
+
+        return x
+
+
 class T5(nn.Module):
     def __init__(self, config: ModelArgs):
         self.wte = nn.Embedding(config.vocab_size, config.d_model)
         self.encoder = TransformerEncoder(config)
-        self.position_bias = RelativePositionBias(config)
-        # self.decoder = TransformerDecoder(config)
+        self.decoder = TransformerDecoder(config)
         # self.lm_head = OutputHead(config)
 
     def __call__(