fix llama

This commit is contained in:
Awni Hannun
2024-08-19 14:13:18 -07:00
parent 9ac9fa6798
commit 8b5c9ce6d2

View File

@@ -65,19 +65,15 @@ class DynamicNTKScalingRoPE(nn.Module):
self.dims = dims
self.max_position_embeddings = max_position_embeddings
self.traditional = traditional
self.original_base = base
self.scale = scale
self.rope_type = rope_type
self.rope_scaling = rope_scaling
self.base = self.compute_base_freq()
self.base = base
self.compute_freqs()
def compute_base_freq(self):
if self.rope_type == "llama3":
return self.compute_llama3_base_freq()
return self.original_base
# source: https://github.com/huggingface/transformers/blob/d5a99dfcee6e94065cb7c83cc8ab6fc5daa0cc4e/src/transformers/modeling_rope_utils.py#L318
def compute_llama3_base_freq(self):
def compute_freqs(self):
if self.rope_type != "llama3":
return
factor = self.rope_scaling["factor"]
low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
@@ -89,19 +85,17 @@ class DynamicNTKScalingRoPE(nn.Module):
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
freqs = self.original_base ** (mx.arange(0, self.dims, 2) / self.dims)
freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims)
wavelens = 2 * mx.pi * freqs
new_base_freqs = []
smooths = (wavelens - high_freq_wavelen) / (
low_freq_wavelen - high_freq_wavelen
freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
new_base_freqs = freqs * (1 - smooths) * factor + smooths
new_base_freqs = mx.where(wavelens < high_freq_wavelen, freqs, new_base_freqs)
new_base_freqs = mx.where(
wavelens > low_freq_wavelen, freqs * factor, new_base_freqs
)
return new_base_freqs.mean().item()
smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
self._inv_freqs = 1 / freqs
def extra_repr(self):
return (
@@ -111,18 +105,22 @@ class DynamicNTKScalingRoPE(nn.Module):
)
def __call__(self, x, offset: int = 0):
seq_len = x.shape[1] + offset
base = self.base
if self.max_position_embeddings and seq_len > self.max_position_embeddings:
base *= (
(self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
) ** (self.dims / (self.dims - 2))
if "_inv_freqs" in self:
positions = mx.arange(offset, x.shape[2] + offset)[:, None]
freqs = positions * self._inv_freqs[None]
emb = mx.concatenate([freqs, freqs], axis=-1)
cos = mx.cos(emb).astype(x.dtype)
sin = mx.sin(emb).astype(x.dtype)
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
rotated_x = mx.concatenate([-x2, x1], axis=-1)
return (x * cos) + (rotated_x * sin)
return mx.fast.rope(
x,
self.dims,
traditional=self.traditional,
base=base,
base=self.base,
scale=self.scale,
offset=offset,
)