Bug fix: use bidirectional position bias only for the encoder; add offset to position bias

Author: Awni Hannun
Date: 2023-12-17 21:22:00 -08:00
Parent: 688a6e1e78
Commit: c468edc4e3


@@ -58,7 +58,7 @@ def _relative_position_bucket(
         relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
         relative_position = mx.abs(relative_position)
     else:
-        relative_position = -mx.min(relative_position, mx.zeros_like(relative_position))
+        relative_position = -mx.minimum(relative_position, mx.zeros_like(relative_position))
     # now relative_position is in the range [0, inf)
 
     # half of the buckets are for exact increments in positions
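The fix above is subtle but real: mx.min is a reduction (minimum over an axis), whereas mx.minimum is the elementwise pairwise minimum that T5's bucketing calls for, i.e. -min(x, 0) applied per entry so every relative position lands in [0, inf). A minimal standalone sketch of the corrected behavior (values hypothetical):

import mlx.core as mx

relative_position = mx.array([[-2, -1, 0, 1, 2]])

# Elementwise minimum against zero keeps negative offsets, clamps positive
# ones to 0, and the leading negation maps everything into [0, inf).
clamped = -mx.minimum(relative_position, mx.zeros_like(relative_position))
print(clamped)  # [[2, 1, 0, 0, 0]]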
@@ -79,8 +79,8 @@ def _relative_position_bucket(
 
 
 class RelativePositionBias(nn.Module):
-    def __init__(self, config: ModelArgs, is_decoder: bool = False):
-        self.bidirectional = not is_decoder
+    def __init__(self, config: ModelArgs, bidirectional: bool):
+        self.bidirectional = bidirectional
         self.num_buckets = config.relative_attention_num_buckets
         self.max_distance = config.relative_attention_max_distance
         self.n_heads = config.num_heads
@@ -88,26 +88,30 @@ class RelativePositionBias(nn.Module):
             config.relative_attention_num_buckets, config.num_heads
         )
 
-    def __call__(self, query_length, key_length):
+    def __call__(self, query_length: int, key_length: int, offset: int = 0):
         """Compute binned relative position bias"""
-        context_position = mx.arange(query_length, dtype=mx.int32)[:, None]
-        memory_position = mx.arange(key_length, dtype=mx.int32)[None, :]
+        context_position = mx.arange(offset, query_length)[:, None]
+        memory_position = mx.arange(key_length)[None, :]
+
+        # shape (query_length, key_length)
         relative_position = (
             memory_position - context_position
-        )  # shape (query_length, key_length)
+        )
+
         relative_position_bucket = _relative_position_bucket(
-            relative_position,  # shape (query_length, key_length)
+            relative_position,
             bidirectional=self.bidirectional,
             num_buckets=self.num_buckets,
             max_distance=self.max_distance,
         )
+
+        # shape (query_length, key_length, num_heads)
         values = self.embeddings(
             relative_position_bucket
-        )  # shape (query_length, key_length, num_heads)
-        values = mx.expand_dims(
-            values.transpose(2, 0, 1), 0
-        )  # shape (1, num_heads, query_length, key_length)
-        return values
+        )
+
+        # shape (num_heads, query_length, key_length)
+        return values.transpose(2, 0, 1)
 
 
 class MultiHeadAttention(nn.Module):
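The new offset argument lets the decoder request bias rows only for the query positions it is actually processing, instead of building the full table and slicing it afterwards (the old TODO in TransformerDecoder below). A quick standalone check that the two are equivalent, a sketch with hypothetical sizes:

import mlx.core as mx

T, new_tokens = 8, 1          # e.g. one cached decode step at sequence length 8
offset = T - new_tokens

cols = mx.arange(T)[None, :]
full = cols - mx.arange(T)[:, None]          # old path: (T, T) relative positions
part = cols - mx.arange(offset, T)[:, None]  # new path: (new_tokens, T)

# The offset table is exactly the trailing rows of the full table.
assert mx.array_equal(part, full[offset:]).item()

Note the return value also loses its old leading singleton batch axis: it now comes back as (num_heads, query_length, key_length), which still broadcasts against (batch, num_heads, query_length, key_length) attention scores.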
@@ -197,7 +201,7 @@ class TransformerEncoder(nn.Module):
                 TransformerEncoderLayer(config) for i in range(config.num_layers)
             ]
         self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.relative_attention_bias = RelativePositionBias(config)
+        self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
 
     def __call__(self, x):
         pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
@@ -250,7 +254,7 @@ class TransformerDecoder(nn.Module):
                 TransformerDecoderLayer(config) for i in range(config.num_layers)
             ]
         self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.relative_attention_bias = RelativePositionBias(config)
+        self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
 
     def __call__(self, x, memory, mask, memory_mask, cache=None):
         if cache is not None:
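Making the flag explicit at both call sites captures the actual asymmetry: the encoder attends in both directions, so the bucketing keeps the sign of each relative offset, while the causal decoder never attends to future keys and folds every positive offset into a single bucket. A hedged usage sketch (assuming this file is importable as t5, as in the mlx-examples layout; bucket counts are the T5 defaults passed explicitly):

import mlx.core as mx
from t5 import _relative_position_bucket

rel = mx.array([[2]])  # a key two positions to the right of the query

enc = _relative_position_bucket(
    rel, bidirectional=True, num_buckets=32, max_distance=128
)
dec = _relative_position_bucket(
    rel, bidirectional=False, num_buckets=32, max_distance=128
)
print(enc.item(), dec.item())  # encoder: a dedicated "future" bucket; decoder: 0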
@@ -260,9 +264,7 @@ class TransformerDecoder(nn.Module):
             cache = [None] * len(self.layers)
 
         T = offset + x.shape[1]
-        # TODO, add offset to RelativePositionBias class to avoid wasted work
-        pos_bias = self.relative_attention_bias(T, T)
-        pos_bias = pos_bias[:, :, -x.shape[1]:, :]
+        pos_bias = self.relative_attention_bias(T, T, offset=offset)
         if mask is not None:
             mask += pos_bias
         else:
@@ -337,9 +339,7 @@ def generate(
     y = decoder_inputs
 
     while True:
         logits, cache = model.decode(y[None], memory, cache=cache)
-        # logits, cache = model.decode(decoder_inputs[None], memory, cache=cache)
         y = sample(logits[:, -1, :])
-        #decoder_inputs = mx.concatenate([decoder_inputs, y])
         yield y.squeeze()