mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Use fast rope (#945)
* use fast rope
* fix llama
* use fast rope for llama3.1
* requires unreleased mlx
* fix su
* fix deepseek v2
* only one of base or freqs
* nit
* fix
* hard code freqs
This commit is contained in:
@@ -65,19 +65,16 @@ class DynamicNTKScalingRoPE(nn.Module):
|
||||
self.dims = dims
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.traditional = traditional
|
||||
self.original_base = base
|
||||
self.scale = scale
|
||||
self.rope_type = rope_type
|
||||
self.rope_scaling = rope_scaling
|
||||
self.base = self.compute_base_freq()
|
||||
self.base = base
|
||||
self.compute_freqs()
|
||||
|
||||
def compute_base_freq(self):
|
||||
if self.rope_type == "llama3":
|
||||
return self.compute_llama3_base_freq()
|
||||
return self.original_base
|
||||
|
||||
# source: https://github.com/huggingface/transformers/blob/d5a99dfcee6e94065cb7c83cc8ab6fc5daa0cc4e/src/transformers/modeling_rope_utils.py#L318
|
||||
def compute_llama3_base_freq(self):
|
||||
def compute_freqs(self):
|
||||
if self.rope_type != "llama3":
|
||||
self._freqs = None
|
||||
return
|
||||
factor = self.rope_scaling["factor"]
|
||||
low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
|
||||
high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
|
||||
@@ -89,19 +86,17 @@ class DynamicNTKScalingRoPE(nn.Module):
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
|
||||
freqs = self.original_base ** (mx.arange(0, self.dims, 2) / self.dims)
|
||||
freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims)
|
||||
wavelens = 2 * mx.pi * freqs
|
||||
new_base_freqs = []
|
||||
|
||||
smooths = (wavelens - high_freq_wavelen) / (
|
||||
low_freq_wavelen - high_freq_wavelen
|
||||
freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
|
||||
is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
|
||||
smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
|
||||
high_freq_factor - low_freq_factor
|
||||
)
|
||||
new_base_freqs = freqs * (1 - smooths) * factor + smooths
|
||||
new_base_freqs = mx.where(wavelens < high_freq_wavelen, freqs, new_base_freqs)
|
||||
new_base_freqs = mx.where(
|
||||
wavelens > low_freq_wavelen, freqs * factor, new_base_freqs
|
||||
)
|
||||
return new_base_freqs.mean().item()
|
||||
smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
|
||||
self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
|
||||
self.base = None
|
||||
|
||||
def extra_repr(self):
|
||||
return (
|
||||
@@ -111,20 +106,14 @@ class DynamicNTKScalingRoPE(nn.Module):
|
||||
)
|
||||
|
||||
def __call__(self, x, offset: int = 0):
|
||||
seq_len = x.shape[1] + offset
|
||||
base = self.base
|
||||
if self.max_position_embeddings and seq_len > self.max_position_embeddings:
|
||||
base *= (
|
||||
(self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
|
||||
) ** (self.dims / (self.dims - 2))
|
||||
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
self.dims,
|
||||
traditional=self.traditional,
|
||||
base=base,
|
||||
base=self.base,
|
||||
scale=self.scale,
|
||||
offset=offset,
|
||||
freqs=self._freqs,
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user