mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Use fast rope (#945)
* use fast rope * fix llama * use fast rope for llama3.1 * requires unreleased mlx * fix su * fix deepseek v2 * only one of base or freqs * nit * fix * hard code freqs
This commit is contained in:
parent
58591a1b41
commit
6731254e76
@ -68,13 +68,12 @@ def yarn_get_mscale(scale=1, mscale=1):
|
|||||||
return 0.1 * mscale * math.log(scale) + 1.0
|
return 0.1 * mscale * math.log(scale) + 1.0
|
||||||
|
|
||||||
|
|
||||||
def yarn_linear_ramp_mask(min, max, dim):
|
def yarn_linear_ramp_mask(min_val, max_val, dim):
|
||||||
if min == max:
|
if min_val == max_val:
|
||||||
max += 0.001 # Prevent singularity
|
max_val += 0.001 # Prevent singularity
|
||||||
|
|
||||||
linear_func = (mx.arange(dim, dtype=mx.float32) - min) / (max - min)
|
linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
|
||||||
ramp_func = mx.clip(linear_func, 0, 1)
|
return mx.clip(linear_func, 0, 1)
|
||||||
return ramp_func
|
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2YarnRotaryEmbedding(nn.Module):
|
class DeepseekV2YarnRotaryEmbedding(nn.Module):
|
||||||
@ -91,72 +90,36 @@ class DeepseekV2YarnRotaryEmbedding(nn.Module):
|
|||||||
mscale_all_dim=0,
|
mscale_all_dim=0,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
|
||||||
self.max_position_embeddings = max_position_embeddings
|
scaling_factor, mscale_all_dim
|
||||||
self.base = base
|
)
|
||||||
self.scaling_factor = scaling_factor
|
freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
|
||||||
self.original_max_position_embeddings = original_max_position_embeddings
|
freq_inter = scaling_factor * base ** (
|
||||||
self.beta_fast = beta_fast
|
mx.arange(0, dim, 2, dtype=mx.float32) / dim
|
||||||
self.beta_slow = beta_slow
|
|
||||||
self.mscale = mscale
|
|
||||||
self.mscale_all_dim = mscale_all_dim
|
|
||||||
|
|
||||||
self.max_seq_len_cached = None
|
|
||||||
self._cos_cached = None
|
|
||||||
self._sin_cached = None
|
|
||||||
self._inv_freq = None
|
|
||||||
self.set_cos_sin_cache(max_position_embeddings)
|
|
||||||
|
|
||||||
def set_cos_sin_cache(self, seq_len):
|
|
||||||
self.max_seq_len_cached = seq_len
|
|
||||||
dim = self.dim
|
|
||||||
freq_extra = 1.0 / (self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim))
|
|
||||||
freq_inter = 1.0 / (
|
|
||||||
self.scaling_factor
|
|
||||||
* self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
low, high = yarn_find_correction_range(
|
low, high = yarn_find_correction_range(
|
||||||
self.beta_fast,
|
beta_fast,
|
||||||
self.beta_slow,
|
beta_slow,
|
||||||
dim,
|
dim,
|
||||||
self.base,
|
base,
|
||||||
self.original_max_position_embeddings,
|
original_max_position_embeddings,
|
||||||
)
|
)
|
||||||
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
|
freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
|
||||||
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
|
self._freqs = (freq_inter * freq_extra) / (
|
||||||
self._inv_freq = inv_freq
|
freq_inter * freq_mask + freq_extra * (1 - freq_mask)
|
||||||
|
|
||||||
t = mx.arange(seq_len, dtype=mx.float32)
|
|
||||||
freqs = mx.outer(t, inv_freq)
|
|
||||||
|
|
||||||
mscale = yarn_get_mscale(self.scaling_factor, self.mscale) / yarn_get_mscale(
|
|
||||||
self.scaling_factor, self.mscale_all_dim
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._cos_cached = mx.cos(freqs) * mscale
|
|
||||||
self._sin_cached = mx.sin(freqs) * mscale
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(self, x, cos, sin):
|
|
||||||
x1 = x[..., ::2]
|
|
||||||
x2 = x[..., 1::2]
|
|
||||||
rx1 = x1 * cos - x2 * sin
|
|
||||||
rx2 = x1 * sin + x2 * cos
|
|
||||||
return mx.concatenate([rx1, rx2], axis=-1)
|
|
||||||
|
|
||||||
def __call__(self, x, offset=0):
|
def __call__(self, x, offset=0):
|
||||||
seq_len = offset + x.shape[2]
|
if self.mscale != 1.0:
|
||||||
if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
|
x = self.mscale * x
|
||||||
self.set_cos_sin_cache(seq_len=seq_len)
|
return mx.fast.rope(
|
||||||
|
|
||||||
if self._cos_cached.dtype != x.dtype:
|
|
||||||
self._cos_cached = self._cos_cached.astype(x.dtype)
|
|
||||||
self._sin_cached = self._sin_cached.astype(x.dtype)
|
|
||||||
|
|
||||||
return self.apply_rotary_pos_emb(
|
|
||||||
x,
|
x,
|
||||||
self._cos_cached[offset:seq_len],
|
x.shape[-1],
|
||||||
self._sin_cached[offset:seq_len],
|
traditional=True,
|
||||||
|
base=None,
|
||||||
|
scale=1.0,
|
||||||
|
offset=offset,
|
||||||
|
freqs=self._freqs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -65,19 +65,16 @@ class DynamicNTKScalingRoPE(nn.Module):
|
|||||||
self.dims = dims
|
self.dims = dims
|
||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
self.traditional = traditional
|
self.traditional = traditional
|
||||||
self.original_base = base
|
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.rope_type = rope_type
|
self.rope_type = rope_type
|
||||||
self.rope_scaling = rope_scaling
|
self.rope_scaling = rope_scaling
|
||||||
self.base = self.compute_base_freq()
|
self.base = base
|
||||||
|
self.compute_freqs()
|
||||||
|
|
||||||
def compute_base_freq(self):
|
def compute_freqs(self):
|
||||||
if self.rope_type == "llama3":
|
if self.rope_type != "llama3":
|
||||||
return self.compute_llama3_base_freq()
|
self._freqs = None
|
||||||
return self.original_base
|
return
|
||||||
|
|
||||||
# source: https://github.com/huggingface/transformers/blob/d5a99dfcee6e94065cb7c83cc8ab6fc5daa0cc4e/src/transformers/modeling_rope_utils.py#L318
|
|
||||||
def compute_llama3_base_freq(self):
|
|
||||||
factor = self.rope_scaling["factor"]
|
factor = self.rope_scaling["factor"]
|
||||||
low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
|
low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
|
||||||
high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
|
high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
|
||||||
@ -89,19 +86,17 @@ class DynamicNTKScalingRoPE(nn.Module):
|
|||||||
low_freq_wavelen = old_context_len / low_freq_factor
|
low_freq_wavelen = old_context_len / low_freq_factor
|
||||||
high_freq_wavelen = old_context_len / high_freq_factor
|
high_freq_wavelen = old_context_len / high_freq_factor
|
||||||
|
|
||||||
freqs = self.original_base ** (mx.arange(0, self.dims, 2) / self.dims)
|
freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims)
|
||||||
wavelens = 2 * mx.pi * freqs
|
wavelens = 2 * mx.pi * freqs
|
||||||
new_base_freqs = []
|
|
||||||
|
|
||||||
smooths = (wavelens - high_freq_wavelen) / (
|
freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
|
||||||
low_freq_wavelen - high_freq_wavelen
|
is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
|
||||||
|
smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
|
||||||
|
high_freq_factor - low_freq_factor
|
||||||
)
|
)
|
||||||
new_base_freqs = freqs * (1 - smooths) * factor + smooths
|
smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
|
||||||
new_base_freqs = mx.where(wavelens < high_freq_wavelen, freqs, new_base_freqs)
|
self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
|
||||||
new_base_freqs = mx.where(
|
self.base = None
|
||||||
wavelens > low_freq_wavelen, freqs * factor, new_base_freqs
|
|
||||||
)
|
|
||||||
return new_base_freqs.mean().item()
|
|
||||||
|
|
||||||
def extra_repr(self):
|
def extra_repr(self):
|
||||||
return (
|
return (
|
||||||
@ -111,20 +106,14 @@ class DynamicNTKScalingRoPE(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x, offset: int = 0):
|
def __call__(self, x, offset: int = 0):
|
||||||
seq_len = x.shape[1] + offset
|
|
||||||
base = self.base
|
|
||||||
if self.max_position_embeddings and seq_len > self.max_position_embeddings:
|
|
||||||
base *= (
|
|
||||||
(self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
|
|
||||||
) ** (self.dims / (self.dims - 2))
|
|
||||||
|
|
||||||
return mx.fast.rope(
|
return mx.fast.rope(
|
||||||
x,
|
x,
|
||||||
self.dims,
|
self.dims,
|
||||||
traditional=self.traditional,
|
traditional=self.traditional,
|
||||||
base=base,
|
base=self.base,
|
||||||
scale=self.scale,
|
scale=self.scale,
|
||||||
offset=offset,
|
offset=offset,
|
||||||
|
freqs=self._freqs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -59,19 +59,17 @@ class Attention(nn.Module):
|
|||||||
self.qkv_proj = nn.Linear(dim, op_size, bias=False)
|
self.qkv_proj = nn.Linear(dim, op_size, bias=False)
|
||||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||||
|
|
||||||
rope_scale = 1.0
|
|
||||||
if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
|
if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
|
||||||
self.rope = SuScaledRotaryEmbedding(
|
self.rope = SuScaledRotaryEmbedding(
|
||||||
head_dim,
|
head_dim,
|
||||||
traditional=False,
|
|
||||||
base=args.rope_theta,
|
base=args.rope_theta,
|
||||||
scale=rope_scale,
|
|
||||||
max_position_embeddings=args.max_position_embeddings,
|
max_position_embeddings=args.max_position_embeddings,
|
||||||
original_max_position_embeddings=args.original_max_position_embeddings,
|
original_max_position_embeddings=args.original_max_position_embeddings,
|
||||||
short_factor=args.rope_scaling["short_factor"],
|
short_factor=args.rope_scaling["short_factor"],
|
||||||
long_factor=args.rope_scaling["long_factor"],
|
long_factor=args.rope_scaling["long_factor"],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
rope_scale = 1.0
|
||||||
if args.rope_scaling and args.rope_scaling["type"] == "linear":
|
if args.rope_scaling and args.rope_scaling["type"] == "linear":
|
||||||
assert isinstance(args.rope_scaling["factor"], float)
|
assert isinstance(args.rope_scaling["factor"], float)
|
||||||
rope_scale = 1 / args.rope_scaling["factor"]
|
rope_scale = 1 / args.rope_scaling["factor"]
|
||||||
|
@ -4,15 +4,14 @@ import math
|
|||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
class SuScaledRotaryEmbedding:
|
class SuScaledRotaryEmbedding(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dims: int,
|
dims: int,
|
||||||
traditional: bool = False,
|
|
||||||
base: float = 10000.0,
|
base: float = 10000.0,
|
||||||
scale: float = 1.0,
|
|
||||||
max_position_embeddings: int = 131072,
|
max_position_embeddings: int = 131072,
|
||||||
original_max_position_embeddings: int = 4096,
|
original_max_position_embeddings: int = 4096,
|
||||||
short_factor: Union[List[float], float] = 1.0,
|
short_factor: Union[List[float], float] = 1.0,
|
||||||
@ -23,10 +22,7 @@ class SuScaledRotaryEmbedding:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
dims (int): The feature dimensions to be rotated.
|
dims (int): The feature dimensions to be rotated.
|
||||||
traditional (bool, optional): Unused. Default: ``False``.
|
|
||||||
base (int, optional): Base for the exponential scaling.
|
base (int, optional): Base for the exponential scaling.
|
||||||
scale (float, optional): The scale used to scale the positions.
|
|
||||||
Default: ``1.0``.
|
|
||||||
max_position_embeddings (int, optional): The maximum sequence
|
max_position_embeddings (int, optional): The maximum sequence
|
||||||
length that this model was trained with. This is used to determine
|
length that this model was trained with. This is used to determine
|
||||||
the size of the original RoPE embeddings when using long scaling.
|
the size of the original RoPE embeddings when using long scaling.
|
||||||
@ -42,40 +38,23 @@ class SuScaledRotaryEmbedding:
|
|||||||
factors for sequences of length greater than
|
factors for sequences of length greater than
|
||||||
``original_max_position_embeddings``. Default: ``1.0``.
|
``original_max_position_embeddings``. Default: ``1.0``.
|
||||||
"""
|
"""
|
||||||
self.inv_freq_short = 1.0 / (
|
super().__init__()
|
||||||
mx.array(short_factor, dtype=mx.float32)
|
freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
|
||||||
* base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
|
self._freqs = mx.array(long_factor, dtype=mx.float32) * freqs
|
||||||
)
|
|
||||||
self.inv_freq_long = 1.0 / (
|
|
||||||
scale
|
|
||||||
* mx.array(long_factor, dtype=mx.float32)
|
|
||||||
* base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
|
|
||||||
)
|
|
||||||
self.original_max_position_embeddings = original_max_position_embeddings
|
self.original_max_position_embeddings = original_max_position_embeddings
|
||||||
self.scaling_factor = math.sqrt(
|
self.scale = math.sqrt(
|
||||||
1
|
1
|
||||||
+ math.log(max_position_embeddings / original_max_position_embeddings)
|
+ math.log(max_position_embeddings / original_max_position_embeddings)
|
||||||
/ math.log(original_max_position_embeddings)
|
/ math.log(original_max_position_embeddings)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_cos_sin(self, offset, L):
|
|
||||||
position_ids = mx.arange(offset, offset + L, dtype=mx.float32)
|
|
||||||
inv_freq = (
|
|
||||||
self.inv_freq_long
|
|
||||||
if (offset + L) > self.original_max_position_embeddings
|
|
||||||
else self.inv_freq_short
|
|
||||||
)
|
|
||||||
freqs = position_ids[:, None] * inv_freq[None, :]
|
|
||||||
emb = mx.concatenate([freqs, freqs], axis=-1)
|
|
||||||
cos = mx.cos(emb) * self.scaling_factor
|
|
||||||
sin = mx.sin(emb) * self.scaling_factor
|
|
||||||
return cos, sin
|
|
||||||
|
|
||||||
def __call__(self, x, offset: int = 0):
|
def __call__(self, x, offset: int = 0):
|
||||||
def _rotate_half(_x):
|
return mx.fast.rope(
|
||||||
midpoint = _x.shape[-1] // 2
|
self.scale * x,
|
||||||
x1, x2 = _x[..., :midpoint], _x[..., midpoint:]
|
x.shape[-1],
|
||||||
return mx.concatenate([-x2, x1], axis=-1)
|
traditional=False,
|
||||||
|
base=None,
|
||||||
cos, sin = self._get_cos_sin(offset, x.shape[2])
|
scale=1.0,
|
||||||
return (x * cos) + (_rotate_half(x) * sin)
|
offset=offset,
|
||||||
|
freqs=self._freqs,
|
||||||
|
)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
mlx>=0.14.1
|
mlx>=0.17.0
|
||||||
numpy
|
numpy
|
||||||
transformers[sentencepiece]>=4.39.3
|
transformers[sentencepiece]>=4.39.3
|
||||||
protobuf
|
protobuf
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
__version__ = "0.17.0"
|
__version__ = "0.17.1"
|
||||||
|
@ -110,7 +110,7 @@ class Transformer2D(nn.Module):
|
|||||||
|
|
||||||
# Perform the input norm and projection
|
# Perform the input norm and projection
|
||||||
B, H, W, C = x.shape
|
B, H, W, C = x.shape
|
||||||
x = self.norm(x.astype(mx.float32)).astype(dtype).reshape(B, -1, C)
|
x = self.norm(x).reshape(B, -1, C)
|
||||||
x = self.proj_in(x)
|
x = self.proj_in(x)
|
||||||
|
|
||||||
# Apply the transformer
|
# Apply the transformer
|
||||||
@ -156,12 +156,12 @@ class ResnetBlock2D(nn.Module):
|
|||||||
if temb is not None:
|
if temb is not None:
|
||||||
temb = self.time_emb_proj(nn.silu(temb))
|
temb = self.time_emb_proj(nn.silu(temb))
|
||||||
|
|
||||||
y = self.norm1(x.astype(mx.float32)).astype(dtype)
|
y = self.norm1(x)
|
||||||
y = nn.silu(y)
|
y = nn.silu(y)
|
||||||
y = self.conv1(y)
|
y = self.conv1(y)
|
||||||
if temb is not None:
|
if temb is not None:
|
||||||
y = y + temb[:, None, None, :]
|
y = y + temb[:, None, None, :]
|
||||||
y = self.norm2(y.astype(mx.float32)).astype(dtype)
|
y = self.norm2(y)
|
||||||
y = nn.silu(y)
|
y = nn.silu(y)
|
||||||
y = self.conv2(y)
|
y = self.conv2(y)
|
||||||
|
|
||||||
@ -453,8 +453,7 @@ class UNetModel(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Postprocess the output
|
# Postprocess the output
|
||||||
dtype = x.dtype
|
x = self.conv_norm_out(x)
|
||||||
x = self.conv_norm_out(x.astype(mx.float32)).astype(dtype)
|
|
||||||
x = nn.silu(x)
|
x = nn.silu(x)
|
||||||
x = self.conv_out(x)
|
x = self.conv_out(x)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user