From 8b5c9ce6d2f71a562ea5a37e420b2680ceec9982 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 19 Aug 2024 14:13:18 -0700 Subject: [PATCH] fix llama --- llms/mlx_lm/models/llama.py | 50 ++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index 192e591f..553fc0a4 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -65,19 +65,15 @@ class DynamicNTKScalingRoPE(nn.Module): self.dims = dims self.max_position_embeddings = max_position_embeddings self.traditional = traditional - self.original_base = base self.scale = scale self.rope_type = rope_type self.rope_scaling = rope_scaling - self.base = self.compute_base_freq() + self.base = base + self.compute_freqs() - def compute_base_freq(self): - if self.rope_type == "llama3": - return self.compute_llama3_base_freq() - return self.original_base - - # source: https://github.com/huggingface/transformers/blob/d5a99dfcee6e94065cb7c83cc8ab6fc5daa0cc4e/src/transformers/modeling_rope_utils.py#L318 - def compute_llama3_base_freq(self): + def compute_freqs(self): + if self.rope_type != "llama3": + return factor = self.rope_scaling["factor"] low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0) high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0) @@ -89,19 +85,17 @@ class DynamicNTKScalingRoPE(nn.Module): low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor - freqs = self.original_base ** (mx.arange(0, self.dims, 2) / self.dims) + freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims) wavelens = 2 * mx.pi * freqs - new_base_freqs = [] - smooths = (wavelens - high_freq_wavelen) / ( - low_freq_wavelen - high_freq_wavelen + freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs) + is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen) + smooth_factors = (old_context_len / wavelens - low_freq_factor) / ( + high_freq_factor - low_freq_factor ) - new_base_freqs = freqs * (1 - smooths) * factor + smooths - new_base_freqs = mx.where(wavelens < high_freq_wavelen, freqs, new_base_freqs) - new_base_freqs = mx.where( - wavelens > low_freq_wavelen, freqs * factor, new_base_freqs - ) - return new_base_freqs.mean().item() + smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors) + freqs = mx.where(is_medium_freq, smooth_freqs, freqs) + self._inv_freqs = 1 / freqs def extra_repr(self): return ( @@ -111,18 +105,22 @@ class DynamicNTKScalingRoPE(nn.Module): ) def __call__(self, x, offset: int = 0): - seq_len = x.shape[1] + offset - base = self.base - if self.max_position_embeddings and seq_len > self.max_position_embeddings: - base *= ( - (self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1) - ) ** (self.dims / (self.dims - 2)) + if "_inv_freqs" in self: + positions = mx.arange(offset, x.shape[2] + offset)[:, None] + freqs = positions * self._inv_freqs[None] + emb = mx.concatenate([freqs, freqs], axis=-1) + cos = mx.cos(emb).astype(x.dtype) + sin = mx.sin(emb).astype(x.dtype) + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + rotated_x = mx.concatenate([-x2, x1], axis=-1) + return (x * cos) + (rotated_x * sin) return mx.fast.rope( x, self.dims, traditional=self.traditional, - base=base, + base=self.base, scale=self.scale, offset=offset, )