mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Load decoder weights
This commit is contained in:
parent
009ed0179c
commit
d0497ddc0b
@ -1,23 +1,49 @@
|
|||||||
from transformers import T5ForConditionalGeneration
|
from transformers import T5ForConditionalGeneration
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# (old, new) substring pairs used to rename checkpoint keys.  They are meant to
# be applied in list order with ``str.replace``; order matters because the
# relative-attention-bias entry matches ".layers.0.", which only exists after
# the ".block." -> ".layers." rewrite has run.

# Renames that apply to every key, regardless of which stack it belongs to.
SHARED_REPLACEMENT_PATTERNS = [
    # Container and attention projection names.
    (".block.", ".layers."),
    (".k.", ".key_proj."),
    (".o.", ".out_proj."),
    (".q.", ".query_proj."),
    (".v.", ".value_proj."),
    # Shared token embedding.
    ("shared.", "wte."),
    # Per-sub-block and final layer norms.
    (".layer.0.layer_norm.", ".ln1."),
    (".layer.1.layer_norm.", ".ln2."),
    (".layer.2.layer_norm.", ".ln3."),
    (".final_layer_norm.", ".ln."),
    # The relative attention bias stored under the first layer's self-attention
    # is moved to a stack-level ``position_bias`` entry.
    (
        ".layers.0.layer.0.SelfAttention.relative_attention_bias.",
        ".position_bias.relative_attention_bias.",
    ),
]

# Extra renames for keys that start with "encoder.".
ENCODER_REPLACEMENT_PATTERNS = [
    (".layer.0.SelfAttention.", ".attention."),
    (".layer.1.DenseReluDense.wi.", ".linear1."),
    (".layer.1.DenseReluDense.wo.", ".linear2."),
]

# Extra renames for keys that start with "decoder."; the feed-forward block
# sits at index 2 here because index 1 is the encoder-decoder attention.
DECODER_REPLACEMENT_PATTERNS = [
    (".layer.0.SelfAttention.", ".self_attention."),
    (".layer.1.EncDecAttention.", ".cross_attention."),
    (".layer.2.DenseReluDense.wi.", ".linear1."),
    (".layer.2.DenseReluDense.wo.", ".linear2."),
]
|
||||||
|
|
||||||
def replace_key(key: str) -> str:
|
def replace_key(key: str) -> str:
|
||||||
key = key.replace(".block.", ".layers.")
|
for old, new in SHARED_REPLACEMENT_PATTERNS:
|
||||||
key = key.replace(".layer.0.SelfAttention.", ".attention.")
|
key = key.replace(old, new)
|
||||||
key = key.replace(".k.", ".key_proj.")
|
if key.startswith("encoder."):
|
||||||
key = key.replace(".o.", ".out_proj.")
|
for old, new in ENCODER_REPLACEMENT_PATTERNS:
|
||||||
key = key.replace(".q.", ".query_proj.")
|
key = key.replace(old, new)
|
||||||
key = key.replace(".v.", ".value_proj.")
|
elif key.startswith("decoder."):
|
||||||
key = key.replace(".layer.0.layer_norm.", ".ln1.")
|
for old, new in DECODER_REPLACEMENT_PATTERNS:
|
||||||
key = key.replace(".layer.1.layer_norm.", ".ln2.")
|
key = key.replace(old, new)
|
||||||
key = key.replace(".layer.1.DenseReluDense.wi.", ".linear1.")
|
|
||||||
key = key.replace(".layer.1.DenseReluDense.wo.", ".linear2.")
|
|
||||||
key = key.replace(".final_layer_norm.", ".ln.")
|
|
||||||
key = key.replace("shared.", "wte.")
|
|
||||||
key = key.replace("encoder.layers.0.attention.relative_attention_bias.",
|
|
||||||
"position_bias.relative_attention_bias.")
|
|
||||||
return key
|
return key
|
||||||
|
|
||||||
|
|
||||||
def convert():
|
def convert():
|
||||||
model = T5ForConditionalGeneration.from_pretrained(
|
model = T5ForConditionalGeneration.from_pretrained(
|
||||||
"t5-small", torch_dtype="auto"
|
"t5-small", torch_dtype="auto"
|
||||||
|
54
t5/t5.py
54
t5/t5.py
@ -181,7 +181,7 @@ class LayerNorm(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class TransformerEncoderLayer(nn.Module):
|
class TransformerEncoderLayer(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
mlp_dims = config.d_ff or config.d_model * 4
|
mlp_dims = config.d_ff or config.d_model * 4
|
||||||
self.attention = MultiHeadAttention(config.d_model, config.num_heads)
|
self.attention = MultiHeadAttention(config.d_model, config.num_heads)
|
||||||
@ -212,6 +212,7 @@ class TransformerEncoder(nn.Module):
|
|||||||
for _ in range(config.num_layers)
|
for _ in range(config.num_layers)
|
||||||
]
|
]
|
||||||
self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
|
self.position_bias = RelativePositionBias(config)
|
||||||
|
|
||||||
def __call__(self, x, mask):
|
def __call__(self, x, mask):
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
@ -221,12 +222,59 @@ class TransformerEncoder(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerDecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention, then a ReLU MLP.

    Each of the three sub-blocks normalizes its input first (``ln1``/``ln2``/
    ``ln3``) and adds its output back onto the residual stream ``x``.

    NOTE(review): the attribute names (``self_attention``, ``cross_attention``,
    ``ln1``..``ln3``, ``linear1``, ``linear2``) appear to be the targets of the
    checkpoint key renames in the conversion script, so they should not be
    renamed without updating that script -- confirm.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        # Use 4 * d_model for the MLP width when d_ff is unset (falsy).
        mlp_dims = config.d_ff or config.d_model * 4
        self.self_attention = MultiHeadAttention(config.d_model, config.num_heads)
        self.cross_attention = MultiHeadAttention(config.d_model, config.num_heads)
        self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.ln3 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
        self.linear2 = nn.Linear(mlp_dims, config.d_model, bias=False)

    def __call__(self, x, memory, x_mask, memory_mask):
        """Run the block on decoder activations ``x`` and encoder output ``memory``.

        ``x_mask`` is passed to the self-attention call and ``memory_mask`` to
        the cross-attention call. Returns the updated residual stream.
        """
        # Self-attention sub-block.
        y = self.ln1(x)
        y = self.self_attention(y, y, y, x_mask)
        x = x + y

        # Cross-attention sub-block. The first argument must be the normalized
        # activations ``y``; passing the raw residual ``x`` here would discard
        # the ``ln2`` output entirely.
        y = self.ln2(x)
        y = self.cross_attention(y, memory, memory, memory_mask)
        x = x + y

        # Feed-forward sub-block: linear -> ReLU -> linear.
        y = self.ln3(x)
        y = self.linear1(y)
        y = mx.maximum(y, 0)
        y = self.linear2(y)
        x = x + y

        return x
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerDecoder(nn.Module):
    """Stack of ``config.num_layers`` decoder layers followed by a final norm.

    A ``RelativePositionBias`` is also constructed and stored as
    ``position_bias``; the forward pass below does not read it.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        decoder_blocks = []
        for _ in range(config.num_layers):
            decoder_blocks.append(TransformerDecoderLayer(config))
        self.layers = decoder_blocks
        self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.position_bias = RelativePositionBias(config)

    def __call__(self, x, memory, x_mask, memory_mask):
        """Pass ``x`` through every layer in order, then apply the final norm.

        ``memory``, ``x_mask`` and ``memory_mask`` are forwarded unchanged to
        each layer.
        """
        hidden = x
        for block in self.layers:
            hidden = block(hidden, memory, x_mask, memory_mask)
        return self.ln(hidden)
|
||||||
|
|
||||||
|
|
||||||
class T5(nn.Module):
|
class T5(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: ModelArgs):
|
||||||
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
||||||
self.encoder = TransformerEncoder(config)
|
self.encoder = TransformerEncoder(config)
|
||||||
self.position_bias = RelativePositionBias(config)
|
self.decoder = TransformerDecoder(config)
|
||||||
# self.decoder = TransformerDecoder(config)
|
|
||||||
# self.lm_head = OutputHead(config)
|
# self.lm_head = OutputHead(config)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
|
Loading…
Reference in New Issue
Block a user