mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Load position bias embeddings
This commit is contained in:
parent
62924d8135
commit
009ed0179c
@ -14,6 +14,8 @@ def replace_key(key: str) -> str:
|
|||||||
key = key.replace(".layer.1.DenseReluDense.wo.", ".linear2.")
|
key = key.replace(".layer.1.DenseReluDense.wo.", ".linear2.")
|
||||||
key = key.replace(".final_layer_norm.", ".ln.")
|
key = key.replace(".final_layer_norm.", ".ln.")
|
||||||
key = key.replace("shared.", "wte.")
|
key = key.replace("shared.", "wte.")
|
||||||
|
key = key.replace("encoder.layers.0.attention.relative_attention_bias.",
|
||||||
|
"position_bias.relative_attention_bias.")
|
||||||
return key
|
return key
|
||||||
|
|
||||||
def convert():
|
def convert():
|
||||||
|
144
t5/t5.py
144
t5/t5.py
@ -1,4 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import math
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
@ -14,17 +15,151 @@ class ModelArgs:
|
|||||||
d_kv: int = 64
|
d_kv: int = 64
|
||||||
d_model: int = 512
|
d_model: int = 512
|
||||||
dropout_rate: int = 0.1
|
dropout_rate: int = 0.1
|
||||||
eos_token_id: int = 1
|
|
||||||
layer_norm_epsilon: float = 1e-06
|
layer_norm_epsilon: float = 1e-06
|
||||||
n_positions: int = 512
|
n_positions: int = 512
|
||||||
|
relative_attention_num_buckets: int = 32
|
||||||
num_heads: int = 8
|
num_heads: int = 8
|
||||||
num_layers: int = 6
|
num_layers: int = 6
|
||||||
decoder_start_token_id: int = 0
|
decoder_start_token_id: int = 0
|
||||||
|
eos_token_id: int = 1
|
||||||
pad_token_id: int = 0
|
pad_token_id: int = 0
|
||||||
relative_attention_num_buckets: int = 32
|
|
||||||
vocab_size: int = 32128
|
vocab_size: int = 32128
|
||||||
|
|
||||||
|
|
||||||
|
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
|
||||||
|
"""
|
||||||
|
Adapted from HF Tensorflow:
|
||||||
|
https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
||||||
|
|
||||||
|
Translate relative position to a bucket number for relative attention. The relative position is defined as
|
||||||
|
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
|
||||||
|
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
|
||||||
|
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
|
||||||
|
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
|
||||||
|
This should allow for more graceful generalization to longer sequences than the model has been trained on
|
||||||
|
|
||||||
|
Args:
|
||||||
|
relative_position: an int32 Tensor
|
||||||
|
bidirectional: a boolean - whether the attention is bidirectional
|
||||||
|
num_buckets: an integer
|
||||||
|
max_distance: an integer
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
|
||||||
|
"""
|
||||||
|
relative_buckets = 0
|
||||||
|
if bidirectional:
|
||||||
|
num_buckets //= 2
|
||||||
|
relative_buckets += (relative_position > 0).to(mx.long) * num_buckets
|
||||||
|
relative_position = mx.abs(relative_position)
|
||||||
|
else:
|
||||||
|
relative_position = -mx.min(relative_position, mx.zeros_like(relative_position))
|
||||||
|
# now relative_position is in the range [0, inf)
|
||||||
|
|
||||||
|
# half of the buckets are for exact increments in positions
|
||||||
|
max_exact = num_buckets // 2
|
||||||
|
is_small = relative_position < max_exact
|
||||||
|
|
||||||
|
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
||||||
|
relative_position_if_large = max_exact + (
|
||||||
|
mx.log(relative_position.float() / max_exact)
|
||||||
|
/ mx.log(max_distance / max_exact)
|
||||||
|
* (num_buckets - max_exact)
|
||||||
|
).to(mx.long)
|
||||||
|
relative_position_if_large = mx.min(
|
||||||
|
relative_position_if_large, mx.full_like(relative_position_if_large, num_buckets - 1)
|
||||||
|
)
|
||||||
|
relative_buckets += mx.where(is_small, relative_position, relative_position_if_large)
|
||||||
|
return relative_buckets
|
||||||
|
|
||||||
|
|
||||||
|
class RelativePositionBias(nn.Module):
|
||||||
|
def __init__(self, config: ModelArgs, is_decoder: bool = False):
|
||||||
|
self.bidirectional = not is_decoder
|
||||||
|
self.num_buckets = config.relative_attention_num_buckets
|
||||||
|
self.max_distance = config.n_positions
|
||||||
|
self.n_heads = config.num_heads
|
||||||
|
self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads)
|
||||||
|
|
||||||
|
def compute_bias(self, query_length, key_length, device=None):
|
||||||
|
"""Compute binned relative position bias"""
|
||||||
|
if device is None:
|
||||||
|
device = self.relative_attention_bias.weight.device
|
||||||
|
context_position = mx.arange(query_length, dtype=mx.long)[:, None]
|
||||||
|
memory_position = mx.arange(key_length, dtype=mx.long)[None, :]
|
||||||
|
relative_position = memory_position - context_position # shape (query_length, key_length)
|
||||||
|
relative_position_bucket = _relative_position_bucket(
|
||||||
|
relative_position, # shape (query_length, key_length)
|
||||||
|
bidirectional=self.bidirectional,
|
||||||
|
num_buckets=self.relative_attention_num_buckets,
|
||||||
|
max_distance=self.relative_attention_max_distance,
|
||||||
|
)
|
||||||
|
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
|
||||||
|
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dims: int,
|
||||||
|
num_heads: int,
|
||||||
|
query_input_dims: Optional[int] = None,
|
||||||
|
key_input_dims: Optional[int] = None,
|
||||||
|
value_input_dims: Optional[int] = None,
|
||||||
|
value_dims: Optional[int] = None,
|
||||||
|
value_output_dims: Optional[int] = None,
|
||||||
|
bias: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if (dims % num_heads) != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
|
||||||
|
)
|
||||||
|
|
||||||
|
query_input_dims = query_input_dims or dims
|
||||||
|
key_input_dims = key_input_dims or dims
|
||||||
|
value_input_dims = value_input_dims or key_input_dims
|
||||||
|
value_dims = value_dims or dims
|
||||||
|
value_output_dims = value_output_dims or dims
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.query_proj = nn.Linear(query_input_dims, dims, bias=bias)
|
||||||
|
self.key_proj = nn.Linear(key_input_dims, dims, bias=bias)
|
||||||
|
self.value_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
|
||||||
|
self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
|
||||||
|
|
||||||
|
def __call__(self, queries, keys, values, mask=None):
|
||||||
|
queries = self.query_proj(queries)
|
||||||
|
keys = self.key_proj(keys)
|
||||||
|
values = self.value_proj(values)
|
||||||
|
|
||||||
|
num_heads = self.num_heads
|
||||||
|
B, L, D = queries.shape
|
||||||
|
_, S, _ = keys.shape
|
||||||
|
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
|
||||||
|
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
|
# Dimensions are [batch x num heads x sequence x hidden dim]
|
||||||
|
scale = math.sqrt(1 / queries.shape[-1])
|
||||||
|
scores = (queries * scale) @ keys
|
||||||
|
if mask is not None:
|
||||||
|
scores = scores + mask.astype(scores.dtype)
|
||||||
|
scores = mx.softmax(scores, axis=-1)
|
||||||
|
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
|
||||||
|
return self.out_proj(values_hat)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
|
||||||
|
indices = mx.arange(N)
|
||||||
|
mask = indices[:, None] < indices[None]
|
||||||
|
# usually inf but 1e9 is as good and softmax(full(1e9)) != nan
|
||||||
|
# TODO: Should replace this with finfo(dtype).min
|
||||||
|
mask = mask.astype(dtype) * -1e9
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(nn.Module):
|
class LayerNorm(nn.Module):
|
||||||
@ -49,7 +184,7 @@ class TransformerEncoderLayer(nn.Module):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
mlp_dims = config.d_ff or config.d_model * 4
|
mlp_dims = config.d_ff or config.d_model * 4
|
||||||
self.attention = nn.MultiHeadAttention(config.d_model, config.num_heads)
|
self.attention = MultiHeadAttention(config.d_model, config.num_heads)
|
||||||
self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
self.ln1 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
self.ln2 = LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
self.linear1 = nn.Linear(config.d_model, mlp_dims, bias=False)
|
||||||
@ -90,6 +225,7 @@ class T5(nn.Module):
|
|||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: ModelArgs):
|
||||||
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
||||||
self.encoder = TransformerEncoder(config)
|
self.encoder = TransformerEncoder(config)
|
||||||
|
self.position_bias = RelativePositionBias(config)
|
||||||
# self.decoder = TransformerDecoder(config)
|
# self.decoder = TransformerDecoder(config)
|
||||||
# self.lm_head = OutputHead(config)
|
# self.lm_head = OutputHead(config)
|
||||||
|
|
||||||
@ -103,7 +239,7 @@ class T5(nn.Module):
|
|||||||
|
|
||||||
mask = None
|
mask = None
|
||||||
if x.shape[1] > 1:
|
if x.shape[1] > 1:
|
||||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
mask = MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||||
mask = mask.astype(x.dtype)
|
mask = mask.astype(x.dtype)
|
||||||
|
|
||||||
y = self.encoder(x, mask) #, cache)
|
y = self.encoder(x, mask) #, cache)
|
||||||
|
Loading…
Reference in New Issue
Block a user