diff --git a/t5/t5.py b/t5/t5.py index 01119a75..d5ccdd9c 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -48,7 +48,7 @@ def _relative_position_bucket( is_small = relative_position < max_exact # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - scale = np.log(max_distance / max_exact) * (num_buckets - max_exact) + scale = (num_buckets - max_exact) / np.log(max_distance / max_exact) relative_position_if_large = max_exact + ( mx.log(relative_position.astype(mx.float32) / max_exact) * scale ).astype(mx.int16)