mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-12 23:34:37 +08:00
hard code freqs
This commit is contained in:
@@ -40,8 +40,7 @@ class SuScaledRotaryEmbedding(nn.Module):
|
|||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
|
freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
|
||||||
self._short_freqs = mx.array(short_factor, dtype=mx.float32) * freqs
|
self._freqs = mx.array(long_factor, dtype=mx.float32) * freqs
|
||||||
self._long_freqs = mx.array(long_factor, dtype=mx.float32) * freqs
|
|
||||||
self.original_max_position_embeddings = original_max_position_embeddings
|
self.original_max_position_embeddings = original_max_position_embeddings
|
||||||
self.scale = math.sqrt(
|
self.scale = math.sqrt(
|
||||||
1
|
1
|
||||||
@@ -50,11 +49,6 @@ class SuScaledRotaryEmbedding(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x, offset: int = 0):
|
def __call__(self, x, offset: int = 0):
|
||||||
freqs = (
|
|
||||||
self._long_freqs
|
|
||||||
if (offset + x.shape[2]) > self.original_max_position_embeddings
|
|
||||||
else self._short_freqs
|
|
||||||
)
|
|
||||||
return mx.fast.rope(
|
return mx.fast.rope(
|
||||||
self.scale * x,
|
self.scale * x,
|
||||||
x.shape[-1],
|
x.shape[-1],
|
||||||
@@ -62,5 +56,5 @@ class SuScaledRotaryEmbedding(nn.Module):
|
|||||||
base=None,
|
base=None,
|
||||||
scale=1.0,
|
scale=1.0,
|
||||||
offset=offset,
|
offset=offset,
|
||||||
freqs=freqs,
|
freqs=self._freqs,
|
||||||
)
|
)
|
||||||
|
Reference in New Issue
Block a user