Fix relative position scale

This commit is contained in:
Juarez Bochi 2023-12-18 11:13:44 -05:00
parent 9d3ee016c9
commit 83b68a5bdb
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6

View File

@ -48,7 +48,7 @@ def _relative_position_bucket(
is_small = relative_position < max_exact is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
scale = np.log(max_distance / max_exact) * (num_buckets - max_exact) scale = (num_buckets - max_exact) / np.log(max_distance / max_exact)
relative_position_if_large = max_exact + ( relative_position_if_large = max_exact + (
mx.log(relative_position.astype(mx.float32) / max_exact) * scale mx.log(relative_position.astype(mx.float32) / max_exact) * scale
).astype(mx.int16) ).astype(mx.int16)