mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Fix relative_attention_max_distance config
This commit is contained in:
parent
2a8ee32b02
commit
64e7eaccb8
4
t5/t5.py
4
t5/t5.py
@@ -18,6 +18,7 @@ class ModelArgs:
|
|||||||
layer_norm_epsilon: float = 1e-06
|
layer_norm_epsilon: float = 1e-06
|
||||||
n_positions: int = 512
|
n_positions: int = 512
|
||||||
relative_attention_num_buckets: int = 32
|
relative_attention_num_buckets: int = 32
|
||||||
|
relative_attention_max_distance: int = 128
|
||||||
num_heads: int = 8
|
num_heads: int = 8
|
||||||
num_layers: int = 6
|
num_layers: int = 6
|
||||||
decoder_start_token_id: int = 0
|
decoder_start_token_id: int = 0
|
||||||
@@ -61,7 +62,6 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
|
|||||||
is_small = relative_position < max_exact
|
is_small = relative_position < max_exact
|
||||||
|
|
||||||
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
||||||
print("relative_position", relative_position)
|
|
||||||
relative_position_if_large = max_exact + (
|
relative_position_if_large = max_exact + (
|
||||||
mx.log(relative_position.astype(mx.float32) / max_exact)
|
mx.log(relative_position.astype(mx.float32) / max_exact)
|
||||||
/ np.log(max_distance / max_exact)
|
/ np.log(max_distance / max_exact)
|
||||||
@@ -78,7 +78,7 @@ class RelativePositionBias(nn.Module):
|
|||||||
def __init__(self, config: ModelArgs, is_decoder: bool = False):
|
def __init__(self, config: ModelArgs, is_decoder: bool = False):
|
||||||
self.bidirectional = not is_decoder
|
self.bidirectional = not is_decoder
|
||||||
self.num_buckets = config.relative_attention_num_buckets
|
self.num_buckets = config.relative_attention_num_buckets
|
||||||
self.max_distance = config.n_positions
|
self.max_distance = config.relative_attention_max_distance
|
||||||
self.n_heads = config.num_heads
|
self.n_heads = config.num_heads
|
||||||
self.embeddings = nn.Embedding(
|
self.embeddings = nn.Embedding(
|
||||||
config.relative_attention_num_buckets,
|
config.relative_attention_num_buckets,
|
||||||
|
Loading…
Reference in New Issue
Block a user