From 83b68a5bdbe602ccfbc5c6f1a6ca27093b9de129 Mon Sep 17 00:00:00 2001 From: Juarez Bochi Date: Mon, 18 Dec 2023 11:13:44 -0500 Subject: [PATCH] Fix relative position scale --- t5/t5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/t5/t5.py b/t5/t5.py index 01119a75..d5ccdd9c 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -48,7 +48,7 @@ def _relative_position_bucket( is_small = relative_position < max_exact # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - scale = np.log(max_distance / max_exact) * (num_buckets - max_exact) + scale = (num_buckets - max_exact) / np.log(max_distance / max_exact) relative_position_if_large = max_exact + ( mx.log(relative_position.astype(mx.float32) / max_exact) * scale ).astype(mx.int16)