This commit is contained in:
Awni Hannun
2024-08-19 16:02:25 -07:00
parent fdc1c707c3
commit a3431ccc25

View File

@@ -62,7 +62,7 @@ class SuScaledRotaryEmbedding(nn.Module):
def __call__(self, x, offset: int = 0):
freqs = (
self._long_freqs
if (offset + L) > self.original_max_position_embeddings
if (offset + x.shape[2]) > self.original_max_position_embeddings
else self._short_freqs
)
return mx.fast.rope(