hard code freqs

This commit is contained in:
Awni Hannun
2024-08-22 16:48:46 -07:00
parent 0a52a9d55a
commit 9aabf08b23

View File

@@ -40,8 +40,7 @@ class SuScaledRotaryEmbedding(nn.Module):
         """
         super().__init__()
         freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
-        self._short_freqs = mx.array(short_factor, dtype=mx.float32) * freqs
-        self._long_freqs = mx.array(long_factor, dtype=mx.float32) * freqs
+        self._freqs = mx.array(long_factor, dtype=mx.float32) * freqs
         self.original_max_position_embeddings = original_max_position_embeddings
         self.scale = math.sqrt(
             1
@@ -50,11 +49,6 @@ class SuScaledRotaryEmbedding(nn.Module):
         )
 
     def __call__(self, x, offset: int = 0):
-        freqs = (
-            self._long_freqs
-            if (offset + x.shape[2]) > self.original_max_position_embeddings
-            else self._short_freqs
-        )
         return mx.fast.rope(
             self.scale * x,
             x.shape[-1],
@@ -62,5 +56,5 @@ class SuScaledRotaryEmbedding(nn.Module):
             base=None,
             scale=1.0,
             offset=offset,
-            freqs=freqs,
+            freqs=self._freqs,
         )