diff --git a/t5/convert.py b/t5/convert.py
index 0977c917..11acc638 100644
--- a/t5/convert.py
+++ b/t5/convert.py
@@ -14,6 +14,8 @@ def replace_key(key: str) -> str:
     key = key.replace(".layer.1.DenseReluDense.wo.", ".linear2.")
     key = key.replace(".final_layer_norm.", ".ln.")
     key = key.replace("shared.", "wte.")
+    key = key.replace("encoder.layers.0.attention.relative_attention_bias.",
+                      "position_bias.relative_attention_bias.")
     return key

 def convert():
diff --git a/t5/t5.py b/t5/t5.py
index b2ec7717..172f6116 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -1,4 +1,5 @@
 import argparse
+import math
 from typing import Optional
 from dataclasses import dataclass

@@ -14,17 +15,151 @@ class ModelArgs:
     d_kv: int = 64
     d_model: int = 512
     dropout_rate: int = 0.1
-    eos_token_id: int = 1
     layer_norm_epsilon: float = 1e-06
     n_positions: int = 512
+    relative_attention_num_buckets: int = 32
     num_heads: int = 8
     num_layers: int = 6
     decoder_start_token_id: int = 0
+    eos_token_id: int = 1
     pad_token_id: int = 0
-    relative_attention_num_buckets: int = 32
     vocab_size: int = 32128


+def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+    """
+    Adapted from the Hugging Face Transformers implementation (originally from Mesh TensorFlow):
+    https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
+
+    Translate relative position to a bucket number for relative attention. The relative position is defined as
+    memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+    position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+    small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+    positions >= max_distance map to the same bucket. All relative positions <= -max_distance map to the same bucket.
+    This should allow for more graceful generalization to longer sequences than the model was trained on.
+
+    Args:
+        relative_position: an int32 array
+        bidirectional: a boolean - whether the attention is bidirectional
+        num_buckets: an integer
+        max_distance: an integer
+
+    Returns:
+        an array with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
+    """
+    relative_buckets = 0
+    if bidirectional:
+        num_buckets //= 2
+        relative_buckets += (relative_position > 0).astype(mx.int32) * num_buckets
+        relative_position = mx.abs(relative_position)
+    else:
+        relative_position = -mx.minimum(relative_position, mx.zeros_like(relative_position))
+    # now relative_position is in the range [0, inf)
+
+    # half of the buckets are for exact increments in positions
+    max_exact = num_buckets // 2
+    is_small = relative_position < max_exact
+
+    # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+    relative_position_if_large = max_exact + (
+        mx.log(relative_position.astype(mx.float32) / max_exact)
+        / math.log(max_distance / max_exact)
+        * (num_buckets - max_exact)
+    ).astype(mx.int32)
+    relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1)
+    relative_buckets += mx.where(is_small, relative_position, relative_position_if_large)
+    return relative_buckets
+
+
+class RelativePositionBias(nn.Module):
+    def __init__(self, config: ModelArgs, is_decoder: bool = False):
+        super().__init__()
+        self.bidirectional = not is_decoder
+        self.num_buckets = config.relative_attention_num_buckets
+        self.max_distance = config.n_positions  # stands in for HF's relative_attention_max_distance
+        self.n_heads = config.num_heads
+        self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads)
+
+    def compute_bias(self, query_length, key_length):
+        """Compute binned relative position bias."""
+        # MLX arrays live in unified memory, so no device handling is needed.
+        context_position = mx.arange(query_length)[:, None]
+        memory_position = mx.arange(key_length)[None, :]
+        relative_position = memory_position - context_position  # shape (query_length, key_length)
+        relative_position_bucket = _relative_position_bucket(
+            relative_position,  # shape (query_length, key_length)
+            bidirectional=self.bidirectional,
+            num_buckets=self.num_buckets,
+            max_distance=self.max_distance,
+        )
+        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+        values = mx.expand_dims(values.transpose(2, 0, 1), 0)  # shape (1, num_heads, query_length, key_length)
+        return values
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(
+        self,
+        dims: int,
+        num_heads: int,
+        query_input_dims: Optional[int] = None,
+        key_input_dims: Optional[int] = None,
+        value_input_dims: Optional[int] = None,
+        value_dims: Optional[int] = None,
+        value_output_dims: Optional[int] = None,
+        bias: bool = False,
+    ):
+        super().__init__()
+
+        if (dims % num_heads) != 0:
+            raise ValueError(
+                f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
+            )
+
+        query_input_dims = query_input_dims or dims
+        key_input_dims = key_input_dims or dims
+        value_input_dims = value_input_dims or key_input_dims
+        value_dims = value_dims or dims
+        value_output_dims = value_output_dims or dims
+
+        self.num_heads = num_heads
+        self.query_proj = nn.Linear(query_input_dims, dims, bias=bias)
+        self.key_proj = nn.Linear(key_input_dims, dims, bias=bias)
+        self.value_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
+        self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
+
+    def __call__(self, queries, keys, values, mask=None):
+        queries = self.query_proj(queries)
+        keys = self.key_proj(keys)
+        values = self.value_proj(values)
+
+        num_heads = self.num_heads
+        B, L, D = queries.shape
+        _, S, _ = keys.shape
+        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
+        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+        # Dimensions are [batch x num heads x sequence x hidden dim]
+        scale = math.sqrt(1 / queries.shape[-1])
+        scores = (queries * scale) @ keys
+        if mask is not None:
+            scores = scores + mask.astype(scores.dtype)
+        scores = mx.softmax(scores, axis=-1)
+        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+        return self.out_proj(values_hat)
+
+    @staticmethod
+    def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
+        indices = mx.arange(N)
+        mask = indices[:, None] < indices[None]
+        # usually inf, but 1e9 is as good and softmax(full(-1e9)) != nan
+        # TODO: Should replace this with finfo(dtype).min
+        mask = mask.astype(dtype) * -1e9
+        return mask


 class LayerNorm(nn.Module):
@@ -49,7 +184,7 @@ class TransformerEncoderLayer(nn.Module):
     def __init__(self, config):
         super().__init__()
         mlp_dims = config.d_ff or config.d_model * 4
-        self.attention = nn.MultiHeadAttention(config.d_model, config.num_heads)
+        self.attention = MultiHeadAttention(config.d_model, config.num_heads)
         self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
@@ -90,6 +225,7 @@ class T5(nn.Module):
     def __init__(self, config: ModelArgs):
         self.wte = nn.Embedding(config.vocab_size, config.d_model)
         self.encoder = TransformerEncoder(config)
+        self.position_bias = RelativePositionBias(config)
         # self.decoder = TransformerDecoder(config)
         # self.lm_head = OutputHead(config)
@@ -103,7 +239,7 @@ class T5(nn.Module):

         mask = None
         if x.shape[1] > 1:
-            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
+            mask = MultiHeadAttention.create_additive_causal_mask(x.shape[1])
             mask = mask.astype(x.dtype)

         y = self.encoder(x, mask) #, cache)
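
A minimal usage sketch of the new relative position bias, assuming the classes above are
importable as `t5.t5` (per this repo's layout). The bias has shape
(1, num_heads, query_length, key_length) and, as in the HF reference implementation, is
meant to be added to the pre-softmax attention scores, e.g. by folding it into the
additive `mask` argument of MultiHeadAttention:

    import mlx.core as mx
    from t5.t5 import ModelArgs, MultiHeadAttention, RelativePositionBias

    config = ModelArgs()
    pos_bias = RelativePositionBias(config)  # bidirectional (encoder) bias

    # One learned scalar per (bucket, head), gathered per position pair.
    bias = pos_bias.compute_bias(query_length=8, key_length=8)
    print(bias.shape)  # (1, 8, 8, 8) with the default 8-head config

    # Broadcasts against the (B, num_heads, L, S) score array inside the attention.
    mask = MultiHeadAttention.create_additive_causal_mask(8) + bias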