Fix relative_attention_max_distance config

This commit is contained in:
Juarez Bochi 2023-12-16 11:18:17 -05:00
parent 2a8ee32b02
commit 64e7eaccb8
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6

View File

@@ -18,6 +18,7 @@ class ModelArgs:
     layer_norm_epsilon: float = 1e-06
     n_positions: int = 512
     relative_attention_num_buckets: int = 32
+    relative_attention_max_distance: int = 128
     num_heads: int = 8
     num_layers: int = 6
     decoder_start_token_id: int = 0
@@ -61,7 +62,6 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
     is_small = relative_position < max_exact
     # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
-    print("relative_position", relative_position)
     relative_position_if_large = max_exact + (
         mx.log(relative_position.astype(mx.float32) / max_exact)
         / np.log(max_distance / max_exact)
@@ -78,7 +78,7 @@ class RelativePositionBias(nn.Module):
     def __init__(self, config: ModelArgs, is_decoder: bool = False):
         self.bidirectional = not is_decoder
         self.num_buckets = config.relative_attention_num_buckets
-        self.max_distance = config.n_positions
+        self.max_distance = config.relative_attention_max_distance
         self.n_heads = config.num_heads
         self.embeddings = nn.Embedding(
             config.relative_attention_num_buckets,