mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-05 00:04:38 +08:00
fix su
This commit is contained in:
@@ -62,7 +62,7 @@ class SuScaledRotaryEmbedding(nn.Module):
|
||||
def __call__(self, x, offset: int = 0):
|
||||
freqs = (
|
||||
self._long_freqs
|
||||
if (offset + L) > self.original_max_position_embeddings
|
||||
if (offset + x.shape[2]) > self.original_max_position_embeddings
|
||||
else self._short_freqs
|
||||
)
|
||||
return mx.fast.rope(
|
||||
|
Reference in New Issue
Block a user