diff --git a/t5/t5.py b/t5/t5.py
index 38fea821..7196249b 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -58,7 +58,7 @@ def _relative_position_bucket(
         relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
         relative_position = mx.abs(relative_position)
     else:
-        relative_position = -mx.min(relative_position, mx.zeros_like(relative_position))
+        relative_position = -mx.minimum(relative_position, mx.zeros_like(relative_position))
 
     # now relative_position is in the range [0, inf)
     # half of the buckets are for exact increments in positions
@@ -79,8 +79,8 @@
 
 
 class RelativePositionBias(nn.Module):
-    def __init__(self, config: ModelArgs, is_decoder: bool = False):
-        self.bidirectional = not is_decoder
+    def __init__(self, config: ModelArgs, bidirectional: bool):
+        self.bidirectional = bidirectional
         self.num_buckets = config.relative_attention_num_buckets
         self.max_distance = config.relative_attention_max_distance
         self.n_heads = config.num_heads
@@ -88,26 +88,29 @@ class RelativePositionBias(nn.Module):
             config.relative_attention_num_buckets, config.num_heads
         )
 
-    def __call__(self, query_length, key_length):
+    def __call__(self, query_length: int, key_length: int, offset: int = 0):
         """Compute binned relative position bias"""
-        context_position = mx.arange(query_length, dtype=mx.int32)[:, None]
-        memory_position = mx.arange(key_length, dtype=mx.int32)[None, :]
+        context_position = mx.arange(offset, query_length)[:, None]
+        memory_position = mx.arange(key_length)[None, :]
+
+        # shape (query_length, key_length)
         relative_position = (
             memory_position - context_position
-        )  # shape (query_length, key_length)
+        )
         relative_position_bucket = _relative_position_bucket(
-            relative_position,  # shape (query_length, key_length)
+            relative_position,
             bidirectional=self.bidirectional,
             num_buckets=self.num_buckets,
             max_distance=self.max_distance,
         )
+
+        # shape (query_length, key_length, num_heads)
         values = self.embeddings(
             relative_position_bucket
-        )  # shape (query_length, key_length, num_heads)
-        values = mx.expand_dims(
-            values.transpose(2, 0, 1), 0
-        )  # shape (1, num_heads, query_length, key_length)
-        return values
+        )
+
+        # shape (num_heads, query_length, key_length)
+        return values.transpose(2, 0, 1)
 
 
 class MultiHeadAttention(nn.Module):
@@ -197,7 +201,7 @@ class TransformerEncoder(nn.Module):
             TransformerEncoderLayer(config) for i in range(config.num_layers)
         ]
         self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.relative_attention_bias = RelativePositionBias(config)
+        self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
 
     def __call__(self, x):
         pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
@@ -250,7 +254,7 @@ class TransformerDecoder(nn.Module):
            TransformerDecoderLayer(config) for i in range(config.num_layers)
         ]
         self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.relative_attention_bias = RelativePositionBias(config)
+        self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
 
     def __call__(self, x, memory, mask, memory_mask, cache=None):
         if cache is not None:
@@ -260,9 +264,7 @@
             cache = [None] * len(self.layers)
 
         T = offset + x.shape[1]
-        # TODO, add offset to RelativePositionBias class to avoid wasted work
-        pos_bias = self.relative_attention_bias(T, T)
-        pos_bias = pos_bias[:, :, -x.shape[1]:, :]
+        pos_bias = self.relative_attention_bias(T, T, offset=offset)
         if mask is not None:
             mask += pos_bias
         else:
@@ -337,9 +339,7 @@ def generate(
     y = decoder_inputs
     while True:
         logits, cache = model.decode(y[None], memory, cache=cache)
-        # logits, cache = model.decode(decoder_inputs[None], memory, cache=cache)
         y = sample(logits[:, -1, :])
-        #decoder_inputs = mx.concatenate([decoder_inputs, y])
         yield y.squeeze()
 
 
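For reference on the first hunk: `mx.min` is a reduction, while the bucketing code needs the elementwise operation, which in MLX is `mx.minimum`. A minimal standalone sketch of the difference (not part of the patch):

```python
import mlx.core as mx

x = mx.array([-2, 1, -3])

# mx.minimum is elementwise: clamp positives to zero, then negate,
# yielding the non-negative distances the bucketing code expects.
print(-mx.minimum(x, mx.zeros_like(x)))  # array([2, 0, 3], dtype=int32)

# mx.min is a reduction and would collapse the array to a scalar instead.
print(mx.min(x))  # array(-3, dtype=int32)
```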
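And a sanity check of the new `offset` argument: computing the bias only for query rows at positions `offset..T-1` should match taking the last rows of the full `(T, T)` computation, which is what the removed `pos_bias[:, :, -x.shape[1]:, :]` slice did. A self-contained sketch of the underlying index math (the values `T = 8`, `offset = 5` are illustrative, not from the patch):

```python
import mlx.core as mx

T, offset = 8, 5  # e.g. 3 new tokens after a 5-token KV cache

# Old approach: build all T query positions, then keep only the tail rows.
full_context = mx.arange(T)[:, None]
# New approach: build only the tail query positions directly.
tail_context = mx.arange(offset, T)[:, None]

memory = mx.arange(T)[None, :]
full_rel = memory - full_context  # (T, T) relative positions
tail_rel = memory - tail_context  # (T - offset, T)

# The offset computation reproduces the slice without the wasted rows.
assert mx.array_equal(full_rel[offset:], tail_rel)
```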