Add yarn option for qwen2

This commit is contained in:
Awni Hannun 2025-03-10 07:11:29 -07:00
parent d2e02b3aae
commit a81e8bcc2d
3 changed files with 101 additions and 21 deletions

View File

@ -7,6 +7,7 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope
@dataclass @dataclass
@ -18,24 +19,13 @@ class ModelArgs(BaseModelArgs):
num_attention_heads: int num_attention_heads: int
rms_norm_eps: float rms_norm_eps: float
vocab_size: int vocab_size: int
num_key_value_heads: Optional[int] = None num_key_value_heads: int
max_position_embeddings: int = 32768
rope_theta: float = 1000000 rope_theta: float = 1000000
rope_traditional: bool = False rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = True tie_word_embeddings: bool = True
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] != "linear":
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
class Attention(nn.Module): class Attention(nn.Module):
def __init__(self, args: ModelArgs): def __init__(self, args: ModelArgs):
@ -54,16 +44,12 @@ class Attention(nn.Module):
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
rope_scale = ( self.rope = initialize_rope(
1 / args.rope_scaling["factor"]
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 1
)
self.rope = nn.RoPE(
head_dim, head_dim,
traditional=args.rope_traditional,
base=args.rope_theta, base=args.rope_theta,
scale=rope_scale, traditional=args.rope_traditional,
scaling_config=args.rope_scaling,
max_position_embeddings=args.max_position_embeddings,
) )
def __call__( def __call__(

View File

@ -1,5 +1,6 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
import math
from typing import Optional from typing import Optional
import mlx.core as mx import mlx.core as mx
@ -61,6 +62,78 @@ class Llama3RoPE(nn.Module):
) )
class YarnRoPE(nn.Module):
def __init__(
self,
dims,
traditional=False,
max_position_embeddings=2048,
base=10000,
scaling_factor=1.0,
original_max_position_embeddings=4096,
beta_fast=32,
beta_slow=1,
mscale=1,
mscale_all_dim=0,
):
super().__init__()
def yarn_find_correction_dim(num_rotations):
return (
dims
* math.log(
original_max_position_embeddings / (num_rotations * 2 * math.pi)
)
) / (2 * math.log(base))
def yarn_find_correction_range():
low = math.floor(yarn_find_correction_dim(beta_fast))
high = math.ceil(yarn_find_correction_dim(beta_slow))
return max(low, 0), min(high, dims - 1)
def yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
def yarn_linear_ramp_mask(min_val, max_val, dim):
if min_val == max_val:
max_val += 0.001 # Prevent singularity
linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (
max_val - min_val
)
return mx.clip(linear_func, 0, 1)
self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
scaling_factor, mscale_all_dim
)
freq_extra = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
freq_inter = scaling_factor * base ** (
mx.arange(0, dims, 2, dtype=mx.float32) / dims
)
low, high = yarn_find_correction_range()
freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dims // 2)
self._freqs = (freq_inter * freq_extra) / (
freq_inter * freq_mask + freq_extra * (1 - freq_mask)
)
self.dims = dims
self.traditional = traditional
def __call__(self, x, offset=0):
if self.mscale != 1.0:
x[..., : self.dims] = self.mscale * x[..., : self.dims]
return mx.fast.rope(
x,
self.dims,
traditional=self.traditional,
base=None,
scale=1.0,
offset=offset,
freqs=self._freqs,
)
def initialize_rope( def initialize_rope(
dims, dims,
base, base,
@ -87,5 +160,25 @@ def initialize_rope(
base=base, base=base,
scaling_config=scaling_config, scaling_config=scaling_config,
) )
elif rope_type == "yarn":
scaling_factor = scaling_config["factor"]
rope_kwargs = {
key: scaling_config[key]
for key in [
"original_max_position_embeddings",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
]
if key in scaling_config
}
return YarnRoPE(
dims=dims,
max_position_embeddings=max_position_embeddings,
traditional=traditional,
base=base,
**rope_kwargs,
)
else: else:
raise ValueError(f"Unsupported RoPE type {rope_type}") raise ValueError(f"Unsupported RoPE type {rope_type}")

View File

@ -336,6 +336,7 @@ class TestModels(unittest.TestCase):
num_hidden_layers=4, num_hidden_layers=4,
intermediate_size=2048, intermediate_size=2048,
num_attention_heads=4, num_attention_heads=4,
num_key_value_heads=4,
rms_norm_eps=1e-5, rms_norm_eps=1e-5,
vocab_size=10_000, vocab_size=10_000,
) )