mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Fix relative_attention_max_distance config
This commit is contained in:
parent
2a8ee32b02
commit
64e7eaccb8
4
t5/t5.py
4
t5/t5.py
@@ -18,6 +18,7 @@ class ModelArgs:
|
||||
layer_norm_epsilon: float = 1e-06
|
||||
n_positions: int = 512
|
||||
relative_attention_num_buckets: int = 32
|
||||
relative_attention_max_distance: int = 128
|
||||
num_heads: int = 8
|
||||
num_layers: int = 6
|
||||
decoder_start_token_id: int = 0
|
||||
@@ -61,7 +62,6 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
|
||||
is_small = relative_position < max_exact
|
||||
|
||||
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
||||
print("relative_position", relative_position)
|
||||
relative_position_if_large = max_exact + (
|
||||
mx.log(relative_position.astype(mx.float32) / max_exact)
|
||||
/ np.log(max_distance / max_exact)
|
||||
@@ -78,7 +78,7 @@ class RelativePositionBias(nn.Module):
|
||||
def __init__(self, config: ModelArgs, is_decoder: bool = False):
|
||||
self.bidirectional = not is_decoder
|
||||
self.num_buckets = config.relative_attention_num_buckets
|
||||
self.max_distance = config.n_positions
|
||||
self.max_distance = config.relative_attention_max_distance
|
||||
self.n_heads = config.num_heads
|
||||
self.embeddings = nn.Embedding(
|
||||
config.relative_attention_num_buckets,
|
||||
|
Loading…
Reference in New Issue
Block a user