mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Move position biases to attention module
This commit is contained in:
parent
d0497ddc0b
commit
330f024d1c
@ -13,10 +13,6 @@ SHARED_REPLACEMENT_PATTERNS = [
|
||||
(".layer.1.layer_norm.", ".ln2."),
|
||||
(".layer.2.layer_norm.", ".ln3."),
|
||||
(".final_layer_norm.", ".ln."),
|
||||
(
|
||||
".layers.0.layer.0.SelfAttention.relative_attention_bias.",
|
||||
".position_bias.relative_attention_bias."
|
||||
),
|
||||
]
|
||||
|
||||
ENCODER_REPLACEMENT_PATTERNS = [
|
||||
|
79
t5/t5.py
79
t5/t5.py
@ -1,6 +1,5 @@
|
||||
import argparse
|
||||
import math
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
import mlx.core as mx
|
||||
@ -79,12 +78,9 @@ class RelativePositionBias(nn.Module):
|
||||
self.num_buckets = config.relative_attention_num_buckets
|
||||
self.max_distance = config.n_positions
|
||||
self.n_heads = config.num_heads
|
||||
self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads)
|
||||
|
||||
def compute_bias(self, query_length, key_length, device=None):
|
||||
def compute_bias(self, query_length, key_length):
|
||||
"""Compute binned relative position bias"""
|
||||
if device is None:
|
||||
device = self.relative_attention_bias.weight.device
|
||||
context_position = mx.arange(query_length, dtype=mx.long)[:, None]
|
||||
memory_position = mx.arange(key_length, dtype=mx.long)[None, :]
|
||||
relative_position = memory_position - context_position # shape (query_length, key_length)
|
||||
@ -100,35 +96,17 @@ class RelativePositionBias(nn.Module):
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dims: int,
|
||||
num_heads: int,
|
||||
query_input_dims: Optional[int] = None,
|
||||
key_input_dims: Optional[int] = None,
|
||||
value_input_dims: Optional[int] = None,
|
||||
value_dims: Optional[int] = None,
|
||||
value_output_dims: Optional[int] = None,
|
||||
bias: bool = False,
|
||||
):
|
||||
def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
|
||||
super().__init__()
|
||||
|
||||
if (dims % num_heads) != 0:
|
||||
raise ValueError(
|
||||
f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
|
||||
)
|
||||
|
||||
query_input_dims = query_input_dims or dims
|
||||
key_input_dims = key_input_dims or dims
|
||||
value_input_dims = value_input_dims or key_input_dims
|
||||
value_dims = value_dims or dims
|
||||
value_output_dims = value_output_dims or dims
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.query_proj = nn.Linear(query_input_dims, dims, bias=bias)
|
||||
self.key_proj = nn.Linear(key_input_dims, dims, bias=bias)
|
||||
self.value_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
|
||||
self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
|
||||
self.num_heads = config.num_heads
|
||||
self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
||||
self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
||||
self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
||||
self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
||||
if has_relative_attention_bias:
|
||||
self.relative_attention_bias = nn.Embedding(
|
||||
config.relative_attention_num_buckets,
|
||||
config.num_heads)
|
||||
|
||||
def __call__(self, queries, keys, values, mask=None):
|
||||
queries = self.query_proj(queries)
|
||||
@ -181,10 +159,12 @@ class LayerNorm(nn.Module):
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
|
||||
super().__init__()
|
||||
mlp_dims = config.d_ff or config.d_model * 4
|
||||
self.attention = MultiHeadAttention(config.d_model, config.num_heads)
|
||||
self.attention = MultiHeadAttention(
|
||||
config, has_relative_attention_bias=has_relative_attention_bias
|
||||
)
|
||||
self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||
@ -208,11 +188,10 @@ class TransformerEncoder(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.layers = [
|
||||
TransformerEncoderLayer(config)
|
||||
for _ in range(config.num_layers)
|
||||
TransformerEncoderLayer(config, has_relative_attention_bias=i == 0)
|
||||
for i in range(config.num_layers)
|
||||
]
|
||||
self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.position_bias = RelativePositionBias(config)
|
||||
|
||||
def __call__(self, x, mask):
|
||||
for layer in self.layers:
|
||||
@ -223,11 +202,13 @@ class TransformerEncoder(nn.Module):
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
|
||||
super().__init__()
|
||||
mlp_dims = config.d_ff or config.d_model * 4
|
||||
self.self_attention = MultiHeadAttention(config.d_model, config.num_heads)
|
||||
self.cross_attention = MultiHeadAttention(config.d_model, config.num_heads)
|
||||
self.self_attention = MultiHeadAttention(
|
||||
config, has_relative_attention_bias=has_relative_attention_bias
|
||||
)
|
||||
self.cross_attention = MultiHeadAttention(config)
|
||||
self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.ln3 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
@ -256,11 +237,10 @@ class TransformerDecoder(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
self.layers = [
|
||||
TransformerDecoderLayer(config)
|
||||
for _ in range(config.num_layers)
|
||||
TransformerDecoderLayer(config, has_relative_attention_bias=i == 0)
|
||||
for i in range(config.num_layers)
|
||||
]
|
||||
self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.position_bias = RelativePositionBias(config)
|
||||
|
||||
def __call__(self, x, memory, x_mask, memory_mask):
|
||||
for layer in self.layers:
|
||||
@ -318,11 +298,18 @@ def load_model():
|
||||
weights = mx.load("weights.npz")
|
||||
current_weights = tree_flatten(model.parameters())
|
||||
weights_to_load = list(weights.items())
|
||||
current_weights_keys = set(k for k, _ in current_weights)
|
||||
weights_to_load_keys = set(k for k, _ in weights_to_load)
|
||||
current_weights_dict = dict(current_weights)
|
||||
current_weights_keys = set(current_weights_dict.keys())
|
||||
weights_to_load_dict = dict(weights_to_load)
|
||||
weights_to_load_keys = set(weights_to_load_dict.keys())
|
||||
print("Missing weights: ", sorted(current_weights_keys - weights_to_load_keys))
|
||||
print()
|
||||
print("Weights ignored: ", sorted(weights_to_load_keys - current_weights_keys))
|
||||
for key in current_weights_keys & weights_to_load_keys:
|
||||
if weights_to_load_dict[key].shape != current_weights_dict[key].shape:
|
||||
print("Shape mismatch for key: ", key)
|
||||
print("Expected shape: ", current_weights_dict[key].shape)
|
||||
print("Loading shape: ", weights_to_load_dict[key].shape)
|
||||
model.update(tree_unflatten(weights_to_load))
|
||||
tokenizer = AutoTokenizer.from_pretrained("t5-small", trust_remote_code=True)
|
||||
return model, tokenizer
|
||||
|
Loading…
Reference in New Issue
Block a user