mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Use fast RoPE (mx.fast.rope) for Llama 3.1
This commit is contained in:
@@ -94,8 +94,7 @@ class DynamicNTKScalingRoPE(nn.Module):
|
||||
high_freq_factor - low_freq_factor
|
||||
)
|
||||
smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
|
||||
freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
|
||||
self._inv_freqs = 1 / freqs
|
||||
self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
|
||||
|
||||
def extra_repr(self):
|
||||
return (
|
||||
@@ -105,17 +104,6 @@ class DynamicNTKScalingRoPE(nn.Module):
|
||||
)
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
if "_inv_freqs" in self:
|
||||
positions = mx.arange(offset, x.shape[2] + offset)[:, None]
|
||||
freqs = positions * self._inv_freqs[None]
|
||||
emb = mx.concatenate([freqs, freqs], axis=-1)
|
||||
cos = mx.cos(emb).astype(x.dtype)
|
||||
sin = mx.sin(emb).astype(x.dtype)
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
rotated_x = mx.concatenate([-x2, x1], axis=-1)
|
||||
return (x * cos) + (rotated_x * sin)
|
||||
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
self.dims,
|
||||
@@ -123,6 +111,7 @@ class DynamicNTKScalingRoPE(nn.Module):
|
||||
base=self.base,
|
||||
scale=self.scale,
|
||||
offset=offset,
|
||||
freqs=self.get("_freqs", None),
|
||||
)
|
||||
|
||||
|
||||
|
@@ -1,3 +1,3 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
__version__ = "0.17.0"
|
||||
__version__ = "0.17.1"
|
||||
|
Reference in New Issue
Block a user