diff --git a/t5/convert.py b/t5/convert.py
index 589ec961..7531ec4d 100644
--- a/t5/convert.py
+++ b/t5/convert.py
@@ -13,10 +13,6 @@ SHARED_REPLACEMENT_PATTERNS = [
     (".layer.1.layer_norm.", ".ln2."),
     (".layer.2.layer_norm.", ".ln3."),
     (".final_layer_norm.", ".ln."),
-    (
-        ".layers.0.layer.0.SelfAttention.relative_attention_bias.",
-        ".position_bias.relative_attention_bias."
-    ),
 ]
 
 ENCODER_REPLACEMENT_PATTERNS = [
diff --git a/t5/t5.py b/t5/t5.py
index cc9895ff..c4048c86 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -1,6 +1,5 @@
 import argparse
 import math
-from typing import Optional
 from dataclasses import dataclass
 
 import mlx.core as mx
@@ -79,12 +78,9 @@ class RelativePositionBias(nn.Module):
         self.num_buckets = config.relative_attention_num_buckets
         self.max_distance = config.n_positions
         self.n_heads = config.num_heads
-        self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads)
 
-    def compute_bias(self, query_length, key_length, device=None):
+    def compute_bias(self, query_length, key_length):
         """Compute binned relative position bias"""
-        if device is None:
-            device = self.relative_attention_bias.weight.device
         context_position = mx.arange(query_length, dtype=mx.long)[:, None]
         memory_position = mx.arange(key_length, dtype=mx.long)[None, :]
         relative_position = memory_position - context_position  # shape (query_length, key_length)
@@ -100,35 +96,17 @@ class RelativePositionBias(nn.Module):
 
 
 class MultiHeadAttention(nn.Module):
-    def __init__(
-        self,
-        dims: int,
-        num_heads: int,
-        query_input_dims: Optional[int] = None,
-        key_input_dims: Optional[int] = None,
-        value_input_dims: Optional[int] = None,
-        value_dims: Optional[int] = None,
-        value_output_dims: Optional[int] = None,
-        bias: bool = False,
-    ):
+    def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
         super().__init__()
-
-        if (dims % num_heads) != 0:
-            raise ValueError(
-                f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
-            )
-
-        query_input_dims = query_input_dims or dims
-        key_input_dims = key_input_dims or dims
-        value_input_dims = value_input_dims or key_input_dims
-        value_dims = value_dims or dims
-        value_output_dims = value_output_dims or dims
-
-        self.num_heads = num_heads
-        self.query_proj = nn.Linear(query_input_dims, dims, bias=bias)
-        self.key_proj = nn.Linear(key_input_dims, dims, bias=bias)
-        self.value_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
-        self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
+        self.num_heads = config.num_heads
+        self.query_proj = nn.Linear(config.d_model, config.d_model, bias=False)
+        self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False)
+        self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False)
+        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
+        if has_relative_attention_bias:
+            self.relative_attention_bias = nn.Embedding(
+                config.relative_attention_num_buckets,
+                config.num_heads)
 
     def __call__(self, queries, keys, values, mask=None):
         queries = self.query_proj(queries)
@@ -181,10 +159,12 @@ class LayerNorm(nn.Module):
 
 
 class TransformerEncoderLayer(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
         super().__init__()
         mlp_dims = config.d_ff or config.d_model * 4
-        self.attention = MultiHeadAttention(config.d_model, config.num_heads)
+        self.attention = MultiHeadAttention(
+            config, has_relative_attention_bias=has_relative_attention_bias
+        )
         self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
@@ -208,11 +188,10 @@ class TransformerEncoder(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
         self.layers = [
-            TransformerEncoderLayer(config)
-            for _ in range(config.num_layers)
+            TransformerEncoderLayer(config, has_relative_attention_bias=i == 0)
+            for i in range(config.num_layers)
         ]
         self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.position_bias = RelativePositionBias(config)
 
     def __call__(self, x, mask):
         for layer in self.layers:
@@ -223,11 +202,13 @@ class TransformerEncoder(nn.Module):
 
 
 class TransformerDecoderLayer(nn.Module):
-    def __init__(self, config: ModelArgs):
+    def __init__(self, config: ModelArgs, has_relative_attention_bias: bool = False):
         super().__init__()
         mlp_dims = config.d_ff or config.d_model * 4
-        self.self_attention = MultiHeadAttention(config.d_model, config.num_heads)
-        self.cross_attention = MultiHeadAttention(config.d_model, config.num_heads)
+        self.self_attention = MultiHeadAttention(
+            config, has_relative_attention_bias=has_relative_attention_bias
+        )
+        self.cross_attention = MultiHeadAttention(config)
         self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.ln3 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
@@ -256,11 +237,10 @@ class TransformerDecoder(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
         self.layers = [
-            TransformerDecoderLayer(config)
-            for _ in range(config.num_layers)
+            TransformerDecoderLayer(config, has_relative_attention_bias=i == 0)
+            for i in range(config.num_layers)
         ]
         self.ln = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.position_bias = RelativePositionBias(config)
 
     def __call__(self, x, memory, x_mask, memory_mask):
         for layer in self.layers:
@@ -318,11 +298,18 @@ def load_model():
     weights = mx.load("weights.npz")
     current_weights = tree_flatten(model.parameters())
     weights_to_load = list(weights.items())
-    current_weights_keys = set(k for k, _ in current_weights)
-    weights_to_load_keys = set(k for k, _ in weights_to_load)
+    current_weights_dict = dict(current_weights)
+    current_weights_keys = set(current_weights_dict.keys())
+    weights_to_load_dict = dict(weights_to_load)
+    weights_to_load_keys = set(weights_to_load_dict.keys())
     print("Missing weights: ", sorted(current_weights_keys - weights_to_load_keys))
     print()
     print("Weights ignored: ", sorted(weights_to_load_keys - current_weights_keys))
+    for key in current_weights_keys & weights_to_load_keys:
+        if weights_to_load_dict[key].shape != current_weights_dict[key].shape:
+            print("Shape mismatch for key: ", key)
+            print("Expected shape: ", current_weights_dict[key].shape)
+            print("Loading shape: ", weights_to_load_dict[key].shape)
     model.update(tree_unflatten(weights_to_load))
     tokenizer = AutoTokenizer.from_pretrained("t5-small", trust_remote_code=True)
     return model, tokenizer