mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
bug fix with bidirectional only for encoder, add offset to position bias
This commit is contained in:
parent
688a6e1e78
commit
c468edc4e3
40
t5/t5.py
40
t5/t5.py
@ -58,7 +58,7 @@ def _relative_position_bucket(
|
|||||||
relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
|
relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
|
||||||
relative_position = mx.abs(relative_position)
|
relative_position = mx.abs(relative_position)
|
||||||
else:
|
else:
|
||||||
relative_position = -mx.min(relative_position, mx.zeros_like(relative_position))
|
relative_position = -mx.minimum(relative_position, mx.zeros_like(relative_position))
|
||||||
# now relative_position is in the range [0, inf)
|
# now relative_position is in the range [0, inf)
|
||||||
|
|
||||||
# half of the buckets are for exact increments in positions
|
# half of the buckets are for exact increments in positions
|
||||||
@ -79,8 +79,8 @@ def _relative_position_bucket(
|
|||||||
|
|
||||||
|
|
||||||
class RelativePositionBias(nn.Module):
|
class RelativePositionBias(nn.Module):
|
||||||
def __init__(self, config: ModelArgs, is_decoder: bool = False):
|
def __init__(self, config: ModelArgs, bidirectional: bool):
|
||||||
self.bidirectional = not is_decoder
|
self.bidirectional = False #bidirectional
|
||||||
self.num_buckets = config.relative_attention_num_buckets
|
self.num_buckets = config.relative_attention_num_buckets
|
||||||
self.max_distance = config.relative_attention_max_distance
|
self.max_distance = config.relative_attention_max_distance
|
||||||
self.n_heads = config.num_heads
|
self.n_heads = config.num_heads
|
||||||
@ -88,26 +88,30 @@ class RelativePositionBias(nn.Module):
|
|||||||
config.relative_attention_num_buckets, config.num_heads
|
config.relative_attention_num_buckets, config.num_heads
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, query_length, key_length):
|
def __call__(self, query_length: int, key_length: int, offset: int = 0):
|
||||||
"""Compute binned relative position bias"""
|
"""Compute binned relative position bias"""
|
||||||
context_position = mx.arange(query_length, dtype=mx.int32)[:, None]
|
context_position = mx.arange(offset, query_length)[:, None]
|
||||||
memory_position = mx.arange(key_length, dtype=mx.int32)[None, :]
|
memory_position = mx.arange(key_length)[None, :]
|
||||||
|
|
||||||
|
|
||||||
|
# shape (query_length, key_length)
|
||||||
relative_position = (
|
relative_position = (
|
||||||
memory_position - context_position
|
memory_position - context_position
|
||||||
) # shape (query_length, key_length)
|
)
|
||||||
relative_position_bucket = _relative_position_bucket(
|
relative_position_bucket = _relative_position_bucket(
|
||||||
relative_position, # shape (query_length, key_length)
|
relative_position,
|
||||||
bidirectional=self.bidirectional,
|
bidirectional=self.bidirectional,
|
||||||
num_buckets=self.num_buckets,
|
num_buckets=self.num_buckets,
|
||||||
max_distance=self.max_distance,
|
max_distance=self.max_distance,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# shape (query_length, key_length, num_heads)
|
||||||
values = self.embeddings(
|
values = self.embeddings(
|
||||||
relative_position_bucket
|
relative_position_bucket
|
||||||
) # shape (query_length, key_length, num_heads)
|
)
|
||||||
values = mx.expand_dims(
|
|
||||||
values.transpose(2, 0, 1), 0
|
# shape (num_heads, query_length, key_length)
|
||||||
) # shape (1, num_heads, query_length, key_length)
|
return values.transpose(2, 0, 1)
|
||||||
return values
|
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(nn.Module):
|
class MultiHeadAttention(nn.Module):
|
||||||
@ -197,7 +201,7 @@ class TransformerEncoder(nn.Module):
|
|||||||
TransformerEncoderLayer(config) for i in range(config.num_layers)
|
TransformerEncoderLayer(config) for i in range(config.num_layers)
|
||||||
]
|
]
|
||||||
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
self.relative_attention_bias = RelativePositionBias(config)
|
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
|
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
|
||||||
@ -250,7 +254,7 @@ class TransformerDecoder(nn.Module):
|
|||||||
TransformerDecoderLayer(config) for i in range(config.num_layers)
|
TransformerDecoderLayer(config) for i in range(config.num_layers)
|
||||||
]
|
]
|
||||||
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
self.relative_attention_bias = RelativePositionBias(config)
|
self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
|
||||||
|
|
||||||
def __call__(self, x, memory, mask, memory_mask, cache=None):
|
def __call__(self, x, memory, mask, memory_mask, cache=None):
|
||||||
if cache is not None:
|
if cache is not None:
|
||||||
@ -260,9 +264,7 @@ class TransformerDecoder(nn.Module):
|
|||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
|
|
||||||
T = offset + x.shape[1]
|
T = offset + x.shape[1]
|
||||||
# TODO, add offset to RelativePositionBias class to avoid wasted work
|
pos_bias = self.relative_attention_bias(T, T, offset=offset)
|
||||||
pos_bias = self.relative_attention_bias(T, T)
|
|
||||||
pos_bias = pos_bias[:, :, -x.shape[1]:, :]
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
mask += pos_bias
|
mask += pos_bias
|
||||||
else:
|
else:
|
||||||
@ -337,9 +339,7 @@ def generate(
|
|||||||
y = decoder_inputs
|
y = decoder_inputs
|
||||||
while True:
|
while True:
|
||||||
logits, cache = model.decode(y[None], memory, cache=cache)
|
logits, cache = model.decode(y[None], memory, cache=cache)
|
||||||
# logits, cache = model.decode(decoder_inputs[None], memory, cache=cache)
|
|
||||||
y = sample(logits[:, -1, :])
|
y = sample(logits[:, -1, :])
|
||||||
#decoder_inputs = mx.concatenate([decoder_inputs, y])
|
|
||||||
yield y.squeeze()
|
yield y.squeeze()
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user