mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
bug fix with bidirectional only for encoder, add offset to position bias
This commit is contained in:
parent
688a6e1e78
commit
c468edc4e3
40
t5/t5.py
40
t5/t5.py
@ -58,7 +58,7 @@ def _relative_position_bucket(
|
||||
relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
|
||||
relative_position = mx.abs(relative_position)
|
||||
else:
|
||||
relative_position = -mx.min(relative_position, mx.zeros_like(relative_position))
|
||||
relative_position = -mx.minimum(relative_position, mx.zeros_like(relative_position))
|
||||
# now relative_position is in the range [0, inf)
|
||||
|
||||
# half of the buckets are for exact increments in positions
|
||||
@ -79,8 +79,8 @@ def _relative_position_bucket(
|
||||
|
||||
|
||||
class RelativePositionBias(nn.Module):
|
||||
def __init__(self, config: ModelArgs, is_decoder: bool = False):
|
||||
self.bidirectional = not is_decoder
|
||||
def __init__(self, config: ModelArgs, bidirectional: bool):
|
||||
self.bidirectional = False #bidirectional
|
||||
self.num_buckets = config.relative_attention_num_buckets
|
||||
self.max_distance = config.relative_attention_max_distance
|
||||
self.n_heads = config.num_heads
|
||||
@ -88,26 +88,30 @@ class RelativePositionBias(nn.Module):
|
||||
config.relative_attention_num_buckets, config.num_heads
|
||||
)
|
||||
|
||||
def __call__(self, query_length, key_length):
|
||||
def __call__(self, query_length: int, key_length: int, offset: int = 0):
|
||||
"""Compute binned relative position bias"""
|
||||
context_position = mx.arange(query_length, dtype=mx.int32)[:, None]
|
||||
memory_position = mx.arange(key_length, dtype=mx.int32)[None, :]
|
||||
context_position = mx.arange(offset, query_length)[:, None]
|
||||
memory_position = mx.arange(key_length)[None, :]
|
||||
|
||||
|
||||
# shape (query_length, key_length)
|
||||
relative_position = (
|
||||
memory_position - context_position
|
||||
) # shape (query_length, key_length)
|
||||
)
|
||||
relative_position_bucket = _relative_position_bucket(
|
||||
relative_position, # shape (query_length, key_length)
|
||||
relative_position,
|
||||
bidirectional=self.bidirectional,
|
||||
num_buckets=self.num_buckets,
|
||||
max_distance=self.max_distance,
|
||||
)
|
||||
|
||||
# shape (query_length, key_length, num_heads)
|
||||
values = self.embeddings(
|
||||
relative_position_bucket
|
||||
) # shape (query_length, key_length, num_heads)
|
||||
values = mx.expand_dims(
|
||||
values.transpose(2, 0, 1), 0
|
||||
) # shape (1, num_heads, query_length, key_length)
|
||||
return values
|
||||
)
|
||||
|
||||
# shape (num_heads, query_length, key_length)
|
||||
return values.transpose(2, 0, 1)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
@ -197,7 +201,7 @@ class TransformerEncoder(nn.Module):
|
||||
TransformerEncoderLayer(config) for i in range(config.num_layers)
|
||||
]
|
||||
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.relative_attention_bias = RelativePositionBias(config)
|
||||
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
|
||||
|
||||
def __call__(self, x):
|
||||
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
|
||||
@ -250,7 +254,7 @@ class TransformerDecoder(nn.Module):
|
||||
TransformerDecoderLayer(config) for i in range(config.num_layers)
|
||||
]
|
||||
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||
self.relative_attention_bias = RelativePositionBias(config)
|
||||
self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
|
||||
|
||||
def __call__(self, x, memory, mask, memory_mask, cache=None):
|
||||
if cache is not None:
|
||||
@ -260,9 +264,7 @@ class TransformerDecoder(nn.Module):
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
T = offset + x.shape[1]
|
||||
# TODO, add offset to RelativePositionBias class to avoid wasted work
|
||||
pos_bias = self.relative_attention_bias(T, T)
|
||||
pos_bias = pos_bias[:, :, -x.shape[1]:, :]
|
||||
pos_bias = self.relative_attention_bias(T, T, offset=offset)
|
||||
if mask is not None:
|
||||
mask += pos_bias
|
||||
else:
|
||||
@ -337,9 +339,7 @@ def generate(
|
||||
y = decoder_inputs
|
||||
while True:
|
||||
logits, cache = model.decode(y[None], memory, cache=cache)
|
||||
# logits, cache = model.decode(decoder_inputs[None], memory, cache=cache)
|
||||
y = sample(logits[:, -1, :])
|
||||
#decoder_inputs = mx.concatenate([decoder_inputs, y])
|
||||
yield y.squeeze()
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user