Move position biases to attention module

Juarez Bochi 2023-12-15 11:30:17 -05:00
parent d0497ddc0b
commit 330f024d1c
2 changed files with 33 additions and 50 deletions

@@ -13,10 +13,6 @@ SHARED_REPLACEMENT_PATTERNS = [
     (".layer.1.layer_norm.", ".ln2."),
     (".layer.2.layer_norm.", ".ln3."),
     (".final_layer_norm.", ".ln."),
-    (
-        ".layers.0.layer.0.SelfAttention.relative_attention_bias.",
-        ".position_bias.relative_attention_bias."
-    ),
 ]
 ENCODER_REPLACEMENT_PATTERNS = [

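The remaining hunks are against the model definition. On the conversion side, dropping the ".layers.0.layer.0.SelfAttention.relative_attention_bias." pattern means the layer-0 bias key is no longer rerouted to a model-level position_bias module; it lands directly on the embedding now owned by the first attention layer. A minimal sketch of how such (old, new) substring pairs are typically applied to a checkpoint key (the helper name is hypothetical; the real conversion script may be structured differently):

def replace_key(key: str, patterns) -> str:
    # Hypothetical helper: apply each (old, new) substring pair in order.
    for old, new in patterns:
        key = key.replace(old, new)
    return key

print(replace_key("decoder.block.0.layer.1.layer_norm.weight",
                  [(".layer.1.layer_norm.", ".ln2.")]))
# -> "decoder.block.0.ln2.weight"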
@@ -1,6 +1,5 @@
 import argparse
 import math
-from typing import Optional
 from dataclasses import dataclass
 import mlx.core as mx
@@ -79,12 +78,9 @@ class RelativePositionBias(nn.Module):
         self.num_buckets = config.relative_attention_num_buckets
         self.max_distance = config.n_positions
         self.n_heads = config.num_heads
-        self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads)
 
-    def compute_bias(self, query_length, key_length, device=None):
+    def compute_bias(self, query_length, key_length):
         """Compute binned relative position bias"""
-        if device is None:
-            device = self.relative_attention_bias.weight.device
         context_position = mx.arange(query_length, dtype=mx.long)[:, None]
         memory_position = mx.arange(key_length, dtype=mx.long)[None, :]
         relative_position = memory_position - context_position  # shape (query_length, key_length)
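MLX uses unified memory, so there is no per-array device to thread through; removing the device argument drops dead PyTorch-style plumbing. For orientation, a rough sketch of what compute_bias produces, assuming an HF-style bucketing helper and a [num_buckets, n_heads] embedding table (the names here are illustrative, not this file's):

import mlx.core as mx
import mlx.nn as nn

def compute_bias_sketch(bias_embedding: nn.Embedding, bucket_fn, query_length, key_length):
    # Pairwise (key - query) offsets, bucketed, then looked up per head.
    context_position = mx.arange(query_length)[:, None]
    memory_position = mx.arange(key_length)[None, :]
    relative_position = memory_position - context_position     # (q_len, k_len)
    buckets = bucket_fn(relative_position)                      # integer bucket ids
    values = bias_embedding(buckets)                            # (q_len, k_len, n_heads)
    return mx.expand_dims(mx.transpose(values, (2, 0, 1)), 0)   # (1, n_heads, q_len, k_len)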
@@ -100,35 +96,17 @@
 class MultiHeadAttention(nn.Module):
-    def __init__(
-        self,
-        dims: int,
-        num_heads: int,
-        query_input_dims: Optional[int] = None,
-        key_input_dims: Optional[int] = None,
-        value_input_dims: Optional[int] = None,
-        value_dims: Optional[int] = None,
-        value_output_dims: Optional[int] = None,
-        bias: bool = False,
-    ):
+    def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
         super().__init__()
-        if (dims % num_heads) != 0:
-            raise ValueError(
-                f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
-            )
-        query_input_dims = query_input_dims or dims
-        key_input_dims = key_input_dims or dims
-        value_input_dims = value_input_dims or key_input_dims
-        value_dims = value_dims or dims
-        value_output_dims = value_output_dims or dims
-        self.num_heads = num_heads
-        self.query_proj = nn.Linear(query_input_dims, dims, bias=bias)
-        self.key_proj = nn.Linear(key_input_dims, dims, bias=bias)
-        self.value_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
-        self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
+        self.num_heads = config.num_heads
+        self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False)
+        self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False)
+        self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False)
+        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
+        if has_relative_attention_bias:
+            self.relative_attention_bias = nn.Embedding(
+                config.relative_attention_num_buckets,
+                config.num_heads)
 
     def __call__(self, queries, keys, values, mask=None):
         queries = self.query_proj(queries)
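MultiHeadAttention now takes the whole config and, when has_relative_attention_bias is set, owns the bias embedding directly, which is what lets the conversion pattern above be deleted. The updated __call__ body is not shown in this excerpt; as an illustration only, the usual T5-style way a relative position bias enters the scores (added before softmax, with no 1/sqrt(d) scaling):

import mlx.core as mx

def t5_attention_sketch(queries, keys, values, position_bias=None, mask=None):
    # queries/keys/values: (batch, n_heads, seq, head_dim)
    scores = queries @ mx.transpose(keys, (0, 1, 3, 2))
    if position_bias is not None:
        scores = scores + position_bias   # (1, n_heads, L_q, L_k) broadcasts over batch
    if mask is not None:
        scores = scores + mask
    return mx.softmax(scores, axis=-1) @ values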
@@ -181,10 +159,12 @@ class LayerNorm(nn.Module):
 class TransformerEncoderLayer(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
         super().__init__()
         mlp_dims = config.d_ff or config.d_model * 4
-        self.attention = MultiHeadAttention(config.d_model, config.num_heads)
+        self.attention = MultiHeadAttention(
+            config, has_relative_attention_bias=has_relative_attention_bias
+        )
         self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
@@ -208,11 +188,10 @@ class TransformerEncoder(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
         self.layers = [
-            TransformerEncoderLayer(config)
-            for _ in range(config.num_layers)
+            TransformerEncoderLayer(config, has_relative_attention_bias=i == 0)
+            for i in range(config.num_layers)
         ]
         self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.position_bias = RelativePositionBias(config)
 
     def __call__(self, x, mask):
         for layer in self.layers:
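Only layer 0 gets has_relative_attention_bias=True, matching T5, where the first layer owns the bucket embedding and every layer reuses the resulting bias tensor. How the encoder threads that tensor through its layers is not visible in this hunk; a sketch of the usual driver, with the layer call signature assumed:

def encode_sketch(layers, ln, x, mask, compute_bias):
    # Hypothetical: build the bias once (from layer 0's table) and share it.
    position_bias = compute_bias(x.shape[1], x.shape[1])
    for layer in layers:
        x = layer(x, mask, position_bias)
    return ln(x)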
@@ -223,11 +202,13 @@ class TransformerEncoder(nn.Module):
 class TransformerDecoderLayer(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
         super().__init__()
         mlp_dims = config.d_ff or config.d_model * 4
-        self.self_attention = MultiHeadAttention(config.d_model, config.num_heads)
-        self.cross_attention = MultiHeadAttention(config.d_model, config.num_heads)
+        self.self_attention = MultiHeadAttention(
+            config, has_relative_attention_bias=has_relative_attention_bias
+        )
+        self.cross_attention = MultiHeadAttention(config)
         self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.ln3 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
@@ -256,11 +237,10 @@ class TransformerDecoder(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
         self.layers = [
-            TransformerDecoderLayer(config)
-            for _ in range(config.num_layers)
+            TransformerDecoderLayer(config, has_relative_attention_bias=i == 0)
+            for i in range(config.num_layers)
         ]
         self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.position_bias = RelativePositionBias(config)
 
     def __call__(self, x, memory, x_mask, memory_mask):
         for layer in self.layers:
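The decoder mirrors the encoder: only the first layer's self-attention owns a bias table, and cross-attention never gets one, i.e. restating the constructor calls above:

self_attention = MultiHeadAttention(config, has_relative_attention_bias=True)  # layer 0 only
cross_attention = MultiHeadAttention(config)                                   # no relative bias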
@@ -318,11 +298,18 @@ def load_model():
     weights = mx.load("weights.npz")
     current_weights = tree_flatten(model.parameters())
     weights_to_load = list(weights.items())
-    current_weights_keys = set(k for k, _ in current_weights)
-    weights_to_load_keys = set(k for k, _ in weights_to_load)
+    current_weights_dict = dict(current_weights)
+    current_weights_keys = set(current_weights_dict.keys())
+    weights_to_load_dict = dict(weights_to_load)
+    weights_to_load_keys = set(weights_to_load_dict.keys())
     print("Missing weights: ", sorted(current_weights_keys - weights_to_load_keys))
     print()
     print("Weights ignored: ", sorted(weights_to_load_keys - current_weights_keys))
+    for key in current_weights_keys & weights_to_load_keys:
+        if weights_to_load_dict[key].shape != current_weights_dict[key].shape:
+            print("Shape mismatch for key: ", key)
+            print("Expected shape: ", current_weights_dict[key].shape)
+            print("Loading shape: ", weights_to_load_dict[key].shape)
     model.update(tree_unflatten(weights_to_load))
     tokenizer = AutoTokenizer.from_pretrained("t5-small", trust_remote_code=True)
     return model, tokenizer
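The added loop makes load_model report shape mismatches between the converted checkpoint and the model, not just missing or unused keys. The same check as a standalone helper, for reuse outside this script (a sketch; weights is any name-to-array mapping such as the dict returned by mx.load):

from mlx.utils import tree_flatten

def check_weights(model, weights):
    current = dict(tree_flatten(model.parameters()))
    missing = sorted(set(current) - set(weights))
    ignored = sorted(set(weights) - set(current))
    mismatched = [
        (k, current[k].shape, weights[k].shape)
        for k in sorted(set(current) & set(weights))
        if current[k].shape != weights[k].shape
    ]
    return missing, ignored, mismatched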