Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-30 02:53:41 +08:00)
Move position biases to attention module

commit 330f024d1c
parent d0497ddc0b
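The commit replaces the standalone RelativePositionBias wiring in the encoder and decoder with a has_relative_attention_bias flag on MultiHeadAttention, so only the first layer of each stack owns the bias embedding table, matching how the upstream T5 checkpoints store it (see the ".layers.0.layer.0.SelfAttention.relative_attention_bias." pattern removed in the first hunk). A minimal sketch of the resulting constructor usage, based only on the signatures shown in the diff below:

# Only the first encoder/decoder layer carries the relative-attention bias table.
first_layer_attn = MultiHeadAttention(config, has_relative_attention_bias=True)
other_layer_attn = MultiHeadAttention(config)  # no bias embedding is created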
@@ -13,10 +13,6 @@ SHARED_REPLACEMENT_PATTERNS = [
     (".layer.1.layer_norm.", ".ln2."),
     (".layer.2.layer_norm.", ".ln3."),
     (".final_layer_norm.", ".ln."),
-    (
-        ".layers.0.layer.0.SelfAttention.relative_attention_bias.",
-        ".position_bias.relative_attention_bias."
-    ),
 ]
 
 ENCODER_REPLACEMENT_PATTERNS = [
t5/t5.py (79 changed lines)

@@ -1,6 +1,5 @@
 import argparse
 import math
-from typing import Optional
 from dataclasses import dataclass
 
 import mlx.core as mx
@@ -79,12 +78,9 @@ class RelativePositionBias(nn.Module):
         self.num_buckets = config.relative_attention_num_buckets
         self.max_distance = config.n_positions
         self.n_heads = config.num_heads
-        self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads)
 
-    def compute_bias(self, query_length, key_length, device=None):
+    def compute_bias(self, query_length, key_length):
         """Compute binned relative position bias"""
-        if device is None:
-            device = self.relative_attention_bias.weight.device
         context_position = mx.arange(query_length, dtype=mx.long)[:, None]
         memory_position = mx.arange(key_length, dtype=mx.long)[None, :]
         relative_position = memory_position - context_position  # shape (query_length, key_length)
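The hunk above cuts off before the bucketing step that compute_bias relies on. For reference, T5 maps each signed relative offset to one of num_buckets bins: half the buckets per direction, exact bins for small offsets, and logarithmically spaced bins out to max_distance. A minimal scalar sketch of that scheme (illustrative only; the function name and plain-Python form are not taken from this file):

import math

def relative_position_bucket(relative_position, bidirectional=True,
                             num_buckets=32, max_distance=128):
    # Illustrative scalar restatement of T5-style relative-position bucketing.
    bucket = 0
    if bidirectional:
        num_buckets //= 2
        if relative_position > 0:
            bucket += num_buckets          # one half of the buckets per direction
        relative_position = abs(relative_position)
    else:
        relative_position = max(-relative_position, 0)
    max_exact = num_buckets // 2
    if relative_position < max_exact:
        return bucket + relative_position  # exact bins for small offsets
    # logarithmically spaced bins from max_exact out to max_distance
    large = max_exact + int(
        math.log(relative_position / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    )
    return bucket + min(large, num_buckets - 1)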
@@ -100,35 +96,17 @@ class RelativePositionBias(nn.Module):
 
 
 class MultiHeadAttention(nn.Module):
-    def __init__(
-        self,
-        dims: int,
-        num_heads: int,
-        query_input_dims: Optional[int] = None,
-        key_input_dims: Optional[int] = None,
-        value_input_dims: Optional[int] = None,
-        value_dims: Optional[int] = None,
-        value_output_dims: Optional[int] = None,
-        bias: bool = False,
-    ):
+    def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
         super().__init__()
-
-        if (dims % num_heads) != 0:
-            raise ValueError(
-                f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
-            )
-
-        query_input_dims = query_input_dims or dims
-        key_input_dims = key_input_dims or dims
-        value_input_dims = value_input_dims or key_input_dims
-        value_dims = value_dims or dims
-        value_output_dims = value_output_dims or dims
-
-        self.num_heads = num_heads
-        self.query_proj = nn.Linear(query_input_dims, dims, bias=bias)
-        self.key_proj = nn.Linear(key_input_dims, dims, bias=bias)
-        self.value_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
-        self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
+        self.num_heads = config.num_heads
+        self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False)
+        self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False)
+        self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False)
+        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
+        if has_relative_attention_bias:
+            self.relative_attention_bias = nn.Embedding(
+                config.relative_attention_num_buckets,
+                config.num_heads)
 
     def __call__(self, queries, keys, values, mask=None):
         queries = self.query_proj(queries)
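The body of __call__ is cut off in this hunk, so how the bias enters the attention computation is not shown here. Conceptually, the embedding owned by the first layer yields one bias value per head and per relative-position bucket, and that bias is added to the scaled dot-product scores before the softmax. A minimal sketch of that step, assuming scores of shape (batch, num_heads, L, S) and a broadcastable position_bias; variable names are illustrative, not the file's actual code:

# scores: raw attention logits, shape (batch, num_heads, L, S)
# position_bias: per-head relative bias, shape (1, num_heads, L, S)
scores = scores + position_bias
if mask is not None:
    scores = scores + mask  # additive mask, e.g. large negative values on disallowed positions
weights = mx.softmax(scores, axis=-1)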
@@ -181,10 +159,12 @@ class LayerNorm(nn.Module):
 
 
 class TransformerEncoderLayer(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
         super().__init__()
         mlp_dims = config.d_ff or config.d_model * 4
-        self.attention = MultiHeadAttention(config.d_model, config.num_heads)
+        self.attention = MultiHeadAttention(
+            config, has_relative_attention_bias=has_relative_attention_bias
+        )
         self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
@@ -208,11 +188,10 @@ class TransformerEncoder(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
         self.layers = [
-            TransformerEncoderLayer(config)
-            for _ in range(config.num_layers)
+            TransformerEncoderLayer(config, has_relative_attention_bias=i == 0)
+            for i in range(config.num_layers)
         ]
         self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.position_bias = RelativePositionBias(config)
 
     def __call__(self, x, mask):
         for layer in self.layers:
@@ -223,11 +202,13 @@ class TransformerEncoder(nn.Module):
 
 
 class TransformerDecoderLayer(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
         super().__init__()
         mlp_dims = config.d_ff or config.d_model * 4
-        self.self_attention = MultiHeadAttention(config.d_model, config.num_heads)
-        self.cross_attention = MultiHeadAttention(config.d_model, config.num_heads)
+        self.self_attention = MultiHeadAttention(
+            config, has_relative_attention_bias=has_relative_attention_bias
+        )
+        self.cross_attention = MultiHeadAttention(config)
         self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.ln3 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
@@ -256,11 +237,10 @@ class TransformerDecoder(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
         self.layers = [
-            TransformerDecoderLayer(config)
-            for _ in range(config.num_layers)
+            TransformerDecoderLayer(config, has_relative_attention_bias=i == 0)
+            for i in range(config.num_layers)
         ]
         self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.position_bias = RelativePositionBias(config)
 
     def __call__(self, x, memory, x_mask, memory_mask):
         for layer in self.layers:
@@ -318,11 +298,18 @@ def load_model():
     weights = mx.load("weights.npz")
     current_weights = tree_flatten(model.parameters())
     weights_to_load = list(weights.items())
-    current_weights_keys = set(k for k, _ in current_weights)
-    weights_to_load_keys = set(k for k, _ in weights_to_load)
+    current_weights_dict = dict(current_weights)
+    current_weights_keys = set(current_weights_dict.keys())
+    weights_to_load_dict = dict(weights_to_load)
+    weights_to_load_keys = set(weights_to_load_dict.keys())
     print("Missing weights: ", sorted(current_weights_keys - weights_to_load_keys))
     print()
     print("Weights ignored: ", sorted(weights_to_load_keys - current_weights_keys))
+    for key in current_weights_keys & weights_to_load_keys:
+        if weights_to_load_dict[key].shape != current_weights_dict[key].shape:
+            print("Shape mismatch for key: ", key)
+            print("Expected shape: ", current_weights_dict[key].shape)
+            print("Loading shape: ", weights_to_load_dict[key].shape)
     model.update(tree_unflatten(weights_to_load))
     tokenizer = AutoTokenizer.from_pretrained("t5-small", trust_remote_code=True)
     return model, tokenizer