use fast rope for llama3.1

This commit is contained in:
Awni Hannun
2024-08-19 14:15:39 -07:00
parent 8b5c9ce6d2
commit 3822d6bfc3
2 changed files with 3 additions and 14 deletions

View File

@@ -94,8 +94,7 @@ class DynamicNTKScalingRoPE(nn.Module):
high_freq_factor - low_freq_factor
)
smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
self._inv_freqs = 1 / freqs
self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
def extra_repr(self):
return (
@@ -105,17 +104,6 @@ class DynamicNTKScalingRoPE(nn.Module):
)
def __call__(self, x, offset: int = 0):
if "_inv_freqs" in self:
positions = mx.arange(offset, x.shape[2] + offset)[:, None]
freqs = positions * self._inv_freqs[None]
emb = mx.concatenate([freqs, freqs], axis=-1)
cos = mx.cos(emb).astype(x.dtype)
sin = mx.sin(emb).astype(x.dtype)
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
rotated_x = mx.concatenate([-x2, x1], axis=-1)
return (x * cos) + (rotated_x * sin)
return mx.fast.rope(
x,
self.dims,
@@ -123,6 +111,7 @@ class DynamicNTKScalingRoPE(nn.Module):
base=self.base,
scale=self.scale,
offset=offset,
freqs=self.get("_freqs", None),
)

View File

@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.17.0"
__version__ = "0.17.1"