From 64e7eaccb8666fb4f5cbcdd6c935f93d3d908467 Mon Sep 17 00:00:00 2001
From: Juarez Bochi
Date: Sat, 16 Dec 2023 11:18:17 -0500
Subject: [PATCH] Fix relative_attention_max_distance config

---
 t5/t5.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/t5/t5.py b/t5/t5.py
index e1a42ff0..155393bb 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -18,6 +18,7 @@ class ModelArgs:
     layer_norm_epsilon: float = 1e-06
     n_positions: int = 512
     relative_attention_num_buckets: int = 32
+    relative_attention_max_distance: int = 128
     num_heads: int = 8
     num_layers: int = 6
     decoder_start_token_id: int = 0
@@ -61,7 +62,6 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
     is_small = relative_position < max_exact

     # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
-    print("relative_position", relative_position)
     relative_position_if_large = max_exact + (
         mx.log(relative_position.astype(mx.float32) / max_exact)
         / np.log(max_distance / max_exact)
@@ -78,7 +78,7 @@ class RelativePositionBias(nn.Module):
     def __init__(self, config: ModelArgs, is_decoder: bool = False):
         self.bidirectional = not is_decoder
         self.num_buckets = config.relative_attention_num_buckets
-        self.max_distance = config.n_positions
+        self.max_distance = config.relative_attention_max_distance
         self.n_heads = config.num_heads
         self.embeddings = nn.Embedding(
             config.relative_attention_num_buckets,
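
Note (not part of the patch): the bug this fixes is that max_distance was taken from n_positions (512) instead of relative_attention_max_distance (128), which stretches the logarithmic bucket boundaries so offsets map to different buckets than the pretrained bias table expects. Below is a minimal self-contained NumPy sketch of T5-style relative-position bucketing, following the formula visible in the hunk above; the helper name relative_position_bucket and the demo offsets are illustrative assumptions, not code from this repository.

import numpy as np

def relative_position_bucket(relative_position, bidirectional=True,
                             num_buckets=32, max_distance=128):
    # Illustrative sketch of T5-style bucketing (mirrors the patched code).
    buckets = np.zeros(relative_position.shape, dtype=np.int64)
    if bidirectional:
        # Half the buckets for positive offsets, half for negative ones.
        num_buckets //= 2
        buckets = buckets + (relative_position > 0).astype(np.int64) * num_buckets
        relative_position = np.abs(relative_position)
    else:
        relative_position = -np.minimum(relative_position, 0)
    # First half of the remaining buckets: one bucket per exact offset.
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    # Second half: logarithmically sized bins for offsets up to max_distance.
    safe = np.maximum(relative_position, 1)  # avoid log(0); masked by is_small
    if_large = max_exact + (
        np.log(safe / max_exact) / np.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).astype(np.int64)
    if_large = np.minimum(if_large, num_buckets - 1)
    return buckets + np.where(is_small, relative_position, if_large)

offsets = np.array([0, 1, 7, 8, 64, 127, 128, 500])
print(relative_position_bucket(offsets, max_distance=128))
print(relative_position_bucket(offsets, max_distance=512))

With max_distance=128, every offset at or beyond 128 saturates into the last bucket; with max_distance=512 the same offsets land in earlier buckets, so a bias table trained with 128 would be read at the wrong entries. That is why the config value must be carried separately from n_positions, as the patch does.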