mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Use fast rope (#945)
* use fast rope * fix llama * use fast rope for llama3.1 * requires unreleased mlx * fix su * fix deepseek v2 * only one of base or freqs * nit * fix * hard code freqs
This commit is contained in:
@@ -68,13 +68,12 @@ def yarn_get_mscale(scale=1, mscale=1):
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
def yarn_linear_ramp_mask(min, max, dim):
|
||||
if min == max:
|
||||
max += 0.001 # Prevent singularity
|
||||
def yarn_linear_ramp_mask(min_val, max_val, dim):
|
||||
if min_val == max_val:
|
||||
max_val += 0.001 # Prevent singularity
|
||||
|
||||
linear_func = (mx.arange(dim, dtype=mx.float32) - min) / (max - min)
|
||||
ramp_func = mx.clip(linear_func, 0, 1)
|
||||
return ramp_func
|
||||
linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
|
||||
return mx.clip(linear_func, 0, 1)
|
||||
|
||||
|
||||
class DeepseekV2YarnRotaryEmbedding(nn.Module):
|
||||
@@ -91,72 +90,36 @@ class DeepseekV2YarnRotaryEmbedding(nn.Module):
|
||||
mscale_all_dim=0,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
self.scaling_factor = scaling_factor
|
||||
self.original_max_position_embeddings = original_max_position_embeddings
|
||||
self.beta_fast = beta_fast
|
||||
self.beta_slow = beta_slow
|
||||
self.mscale = mscale
|
||||
self.mscale_all_dim = mscale_all_dim
|
||||
|
||||
self.max_seq_len_cached = None
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
self._inv_freq = None
|
||||
self.set_cos_sin_cache(max_position_embeddings)
|
||||
|
||||
def set_cos_sin_cache(self, seq_len):
|
||||
self.max_seq_len_cached = seq_len
|
||||
dim = self.dim
|
||||
freq_extra = 1.0 / (self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim))
|
||||
freq_inter = 1.0 / (
|
||||
self.scaling_factor
|
||||
* self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
|
||||
self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
|
||||
scaling_factor, mscale_all_dim
|
||||
)
|
||||
freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
|
||||
freq_inter = scaling_factor * base ** (
|
||||
mx.arange(0, dim, 2, dtype=mx.float32) / dim
|
||||
)
|
||||
|
||||
low, high = yarn_find_correction_range(
|
||||
self.beta_fast,
|
||||
self.beta_slow,
|
||||
beta_fast,
|
||||
beta_slow,
|
||||
dim,
|
||||
self.base,
|
||||
self.original_max_position_embeddings,
|
||||
base,
|
||||
original_max_position_embeddings,
|
||||
)
|
||||
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
|
||||
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
|
||||
self._inv_freq = inv_freq
|
||||
|
||||
t = mx.arange(seq_len, dtype=mx.float32)
|
||||
freqs = mx.outer(t, inv_freq)
|
||||
|
||||
mscale = yarn_get_mscale(self.scaling_factor, self.mscale) / yarn_get_mscale(
|
||||
self.scaling_factor, self.mscale_all_dim
|
||||
freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
|
||||
self._freqs = (freq_inter * freq_extra) / (
|
||||
freq_inter * freq_mask + freq_extra * (1 - freq_mask)
|
||||
)
|
||||
|
||||
self._cos_cached = mx.cos(freqs) * mscale
|
||||
self._sin_cached = mx.sin(freqs) * mscale
|
||||
|
||||
def apply_rotary_pos_emb(self, x, cos, sin):
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
rx1 = x1 * cos - x2 * sin
|
||||
rx2 = x1 * sin + x2 * cos
|
||||
return mx.concatenate([rx1, rx2], axis=-1)
|
||||
|
||||
def __call__(self, x, offset=0):
|
||||
seq_len = offset + x.shape[2]
|
||||
if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
|
||||
self.set_cos_sin_cache(seq_len=seq_len)
|
||||
|
||||
if self._cos_cached.dtype != x.dtype:
|
||||
self._cos_cached = self._cos_cached.astype(x.dtype)
|
||||
self._sin_cached = self._sin_cached.astype(x.dtype)
|
||||
|
||||
return self.apply_rotary_pos_emb(
|
||||
if self.mscale != 1.0:
|
||||
x = self.mscale * x
|
||||
return mx.fast.rope(
|
||||
x,
|
||||
self._cos_cached[offset:seq_len],
|
||||
self._sin_cached[offset:seq_len],
|
||||
x.shape[-1],
|
||||
traditional=True,
|
||||
base=None,
|
||||
scale=1.0,
|
||||
offset=offset,
|
||||
freqs=self._freqs,
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user