From 3822d6bfc3ccd8f5de8bffb04e7ad5acd2d57909 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 19 Aug 2024 14:15:39 -0700 Subject: [PATCH] use fast rope for llama3.1 --- llms/mlx_lm/models/llama.py | 15 ++------------- llms/mlx_lm/version.py | 2 +- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index 553fc0a4..7c6914b9 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -94,8 +94,7 @@ class DynamicNTKScalingRoPE(nn.Module): high_freq_factor - low_freq_factor ) smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors) - freqs = mx.where(is_medium_freq, smooth_freqs, freqs) - self._inv_freqs = 1 / freqs + self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs) def extra_repr(self): return ( @@ -105,17 +104,6 @@ class DynamicNTKScalingRoPE(nn.Module): ) def __call__(self, x, offset: int = 0): - if "_inv_freqs" in self: - positions = mx.arange(offset, x.shape[2] + offset)[:, None] - freqs = positions * self._inv_freqs[None] - emb = mx.concatenate([freqs, freqs], axis=-1) - cos = mx.cos(emb).astype(x.dtype) - sin = mx.sin(emb).astype(x.dtype) - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - rotated_x = mx.concatenate([-x2, x1], axis=-1) - return (x * cos) + (rotated_x * sin) - return mx.fast.rope( x, self.dims, @@ -123,6 +111,7 @@ class DynamicNTKScalingRoPE(nn.Module): base=self.base, scale=self.scale, offset=offset, + freqs=self.get("_freqs", None), ) diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py index f73aaa0a..41237905 100644 --- a/llms/mlx_lm/version.py +++ b/llms/mlx_lm/version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. -__version__ = "0.17.0" +__version__ = "0.17.1"