Move position biases to attention module

This commit is contained in:
Juarez Bochi 2023-12-15 11:30:17 -05:00
parent d0497ddc0b
commit 330f024d1c
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6
2 changed files with 33 additions and 50 deletions

View File

@ -13,10 +13,6 @@ SHARED_REPLACEMENT_PATTERNS = [
(".layer.1.layer_norm.", ".ln2."), (".layer.1.layer_norm.", ".ln2."),
(".layer.2.layer_norm.", ".ln3."), (".layer.2.layer_norm.", ".ln3."),
(".final_layer_norm.", ".ln."), (".final_layer_norm.", ".ln."),
(
".layers.0.layer.0.SelfAttention.relative_attention_bias.",
".position_bias.relative_attention_bias."
),
] ]
ENCODER_REPLACEMENT_PATTERNS = [ ENCODER_REPLACEMENT_PATTERNS = [

View File

@ -1,6 +1,5 @@
import argparse import argparse
import math import math
from typing import Optional
from dataclasses import dataclass from dataclasses import dataclass
import mlx.core as mx import mlx.core as mx
@ -79,12 +78,9 @@ class RelativePositionBias(nn.Module):
self.num_buckets = config.relative_attention_num_buckets self.num_buckets = config.relative_attention_num_buckets
self.max_distance = config.n_positions self.max_distance = config.n_positions
self.n_heads = config.num_heads self.n_heads = config.num_heads
self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads)
def compute_bias(self, query_length, key_length, device=None): def compute_bias(self, query_length, key_length):
"""Compute binned relative position bias""" """Compute binned relative position bias"""
if device is None:
device = self.relative_attention_bias.weight.device
context_position = mx.arange(query_length, dtype=mx.long)[:, None] context_position = mx.arange(query_length, dtype=mx.long)[:, None]
memory_position = mx.arange(key_length, dtype=mx.long)[None, :] memory_position = mx.arange(key_length, dtype=mx.long)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length) relative_position = memory_position - context_position # shape (query_length, key_length)
@ -100,35 +96,17 @@ class RelativePositionBias(nn.Module):
class MultiHeadAttention(nn.Module): class MultiHeadAttention(nn.Module):
def __init__( def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
self,
dims: int,
num_heads: int,
query_input_dims: Optional[int] = None,
key_input_dims: Optional[int] = None,
value_input_dims: Optional[int] = None,
value_dims: Optional[int] = None,
value_output_dims: Optional[int] = None,
bias: bool = False,
):
super().__init__() super().__init__()
self.num_heads = config.num_heads
if (dims % num_heads) != 0: self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False)
raise ValueError( self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False)
f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0" self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False)
) self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
if has_relative_attention_bias:
query_input_dims = query_input_dims or dims self.relative_attention_bias = nn.Embedding(
key_input_dims = key_input_dims or dims config.relative_attention_num_buckets,
value_input_dims = value_input_dims or key_input_dims config.num_heads)
value_dims = value_dims or dims
value_output_dims = value_output_dims or dims
self.num_heads = num_heads
self.query_proj = nn.Linear(query_input_dims, dims, bias=bias)
self.key_proj = nn.Linear(key_input_dims, dims, bias=bias)
self.value_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
def __call__(self, queries, keys, values, mask=None): def __call__(self, queries, keys, values, mask=None):
queries = self.query_proj(queries) queries = self.query_proj(queries)
@ -181,10 +159,12 @@ class LayerNorm(nn.Module):
class TransformerEncoderLayer(nn.Module): class TransformerEncoderLayer(nn.Module):
def __init__(self, config: ModelArgs): def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
super().__init__() super().__init__()
mlp_dims = config.d_ff or config.d_model * 4 mlp_dims = config.d_ff or config.d_model * 4
self.attention = MultiHeadAttention(config.d_model, config.num_heads) self.attention = MultiHeadAttention(
config, has_relative_attention_bias=has_relative_attention_bias
)
self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False) self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
@ -208,11 +188,10 @@ class TransformerEncoder(nn.Module):
def __init__(self, config: ModelArgs): def __init__(self, config: ModelArgs):
super().__init__() super().__init__()
self.layers = [ self.layers = [
TransformerEncoderLayer(config) TransformerEncoderLayer(config, has_relative_attention_bias=i == 0)
for _ in range(config.num_layers) for i in range(config.num_layers)
] ]
self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.position_bias = RelativePositionBias(config)
def __call__(self, x, mask): def __call__(self, x, mask):
for layer in self.layers: for layer in self.layers:
@ -223,11 +202,13 @@ class TransformerEncoder(nn.Module):
class TransformerDecoderLayer(nn.Module): class TransformerDecoderLayer(nn.Module):
def __init__(self, config: ModelArgs): def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
super().__init__() super().__init__()
mlp_dims = config.d_ff or config.d_model * 4 mlp_dims = config.d_ff or config.d_model * 4
self.self_attention = MultiHeadAttention(config.d_model, config.num_heads) self.self_attention = MultiHeadAttention(
self.cross_attention = MultiHeadAttention(config.d_model, config.num_heads) config, has_relative_attention_bias=has_relative_attention_bias
)
self.cross_attention = MultiHeadAttention(config)
self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln3 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.ln3 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
@ -256,11 +237,10 @@ class TransformerDecoder(nn.Module):
def __init__(self, config: ModelArgs): def __init__(self, config: ModelArgs):
super().__init__() super().__init__()
self.layers = [ self.layers = [
TransformerDecoderLayer(config) TransformerDecoderLayer(config, has_relative_attention_bias=i == 0)
for _ in range(config.num_layers) for i in range(config.num_layers)
] ]
self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.position_bias = RelativePositionBias(config)
def __call__(self, x, memory, x_mask, memory_mask): def __call__(self, x, memory, x_mask, memory_mask):
for layer in self.layers: for layer in self.layers:
@ -318,11 +298,18 @@ def load_model():
weights = mx.load("weights.npz") weights = mx.load("weights.npz")
current_weights = tree_flatten(model.parameters()) current_weights = tree_flatten(model.parameters())
weights_to_load = list(weights.items()) weights_to_load = list(weights.items())
current_weights_keys = set(k for k, _ in current_weights) current_weights_dict = dict(current_weights)
weights_to_load_keys = set(k for k, _ in weights_to_load) current_weights_keys = set(current_weights_dict.keys())
weights_to_load_dict = dict(weights_to_load)
weights_to_load_keys = set(weights_to_load_dict.keys())
print("Missing weights: ", sorted(current_weights_keys - weights_to_load_keys)) print("Missing weights: ", sorted(current_weights_keys - weights_to_load_keys))
print() print()
print("Weights ignored: ", sorted(weights_to_load_keys - current_weights_keys)) print("Weights ignored: ", sorted(weights_to_load_keys - current_weights_keys))
for key in current_weights_keys & weights_to_load_keys:
if weights_to_load_dict[key].shape != current_weights_dict[key].shape:
print("Shape mismatch for key: ", key)
print("Expected shape: ", current_weights_dict[key].shape)
print("Loading shape: ", weights_to_load_dict[key].shape)
model.update(tree_unflatten(weights_to_load)) model.update(tree_unflatten(weights_to_load))
tokenizer = AutoTokenizer.from_pretrained("t5-small", trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained("t5-small", trust_remote_code=True)
return model, tokenizer return model, tokenizer