From 9aabf08b235f6e2097df5be49e981723d2b96cae Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 22 Aug 2024 16:48:46 -0700 Subject: [PATCH] hard code freqs --- llms/mlx_lm/models/su_rope.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/llms/mlx_lm/models/su_rope.py b/llms/mlx_lm/models/su_rope.py index c75e9610..f96b9957 100644 --- a/llms/mlx_lm/models/su_rope.py +++ b/llms/mlx_lm/models/su_rope.py @@ -40,8 +40,7 @@ class SuScaledRotaryEmbedding(nn.Module): """ super().__init__() freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) - self._short_freqs = mx.array(short_factor, dtype=mx.float32) * freqs - self._long_freqs = mx.array(long_factor, dtype=mx.float32) * freqs + self._freqs = mx.array(long_factor, dtype=mx.float32) * freqs self.original_max_position_embeddings = original_max_position_embeddings self.scale = math.sqrt( 1 @@ -50,11 +49,6 @@ class SuScaledRotaryEmbedding(nn.Module): ) def __call__(self, x, offset: int = 0): - freqs = ( - self._long_freqs - if (offset + x.shape[2]) > self.original_max_position_embeddings - else self._short_freqs - ) return mx.fast.rope( self.scale * x, x.shape[-1], @@ -62,5 +56,5 @@ class SuScaledRotaryEmbedding(nn.Module): base=None, scale=1.0, offset=offset, - freqs=freqs, + freqs=self._freqs, )