hard code freqs

This commit is contained in:
Awni Hannun
2024-08-22 16:48:46 -07:00
parent 0a52a9d55a
commit 9aabf08b23

View File

@@ -40,8 +40,7 @@ class SuScaledRotaryEmbedding(nn.Module):
"""
super().__init__()
freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
self._short_freqs = mx.array(short_factor, dtype=mx.float32) * freqs
self._long_freqs = mx.array(long_factor, dtype=mx.float32) * freqs
self._freqs = mx.array(long_factor, dtype=mx.float32) * freqs
self.original_max_position_embeddings = original_max_position_embeddings
self.scale = math.sqrt(
1
@@ -50,11 +49,6 @@ class SuScaledRotaryEmbedding(nn.Module):
)
def __call__(self, x, offset: int = 0):
freqs = (
self._long_freqs
if (offset + x.shape[2]) > self.original_max_position_embeddings
else self._short_freqs
)
return mx.fast.rope(
self.scale * x,
x.shape[-1],
@@ -62,5 +56,5 @@ class SuScaledRotaryEmbedding(nn.Module):
base=None,
scale=1.0,
offset=offset,
freqs=freqs,
freqs=self._freqs,
)