Mirror of https://github.com/ml-explore/mlx-examples.git

Add yarn option for qwen2
commit a81e8bcc2d (parent d2e02b3aae)

In the qwen2 model:

@@ -7,6 +7,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .rope_utils import initialize_rope
 
 
 @dataclass

@@ -18,24 +19,13 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: Optional[int] = None
+    num_key_value_heads: int
+    max_position_embeddings: int = 32768
     rope_theta: float = 1000000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
     tie_word_embeddings: bool = True
 
-    def __post_init__(self):
-        if self.num_key_value_heads is None:
-            self.num_key_value_heads = self.num_attention_heads
-
-        if self.rope_scaling:
-            required_keys = {"factor", "type"}
-            if not all(key in self.rope_scaling for key in required_keys):
-                raise ValueError(f"rope_scaling must contain keys {required_keys}")
-
-            if self.rope_scaling["type"] != "linear":
-                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
-
 
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
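
With the __post_init__ validation gone, rope_scaling is no longer limited to the "linear" type; it simply rides along on ModelArgs until the attention layer hands it to the RoPE initializer. Below is a minimal sketch of constructing args with a YaRN-style config. The field names come from this diff, while the concrete values, the extra fields such as model_type and hidden_size, and the mlx_lm.models.qwen2 import path are assumptions made for illustration only.

# Sketch only: values and import path are hypothetical.
from mlx_lm.models.qwen2 import ModelArgs

args = ModelArgs(
    model_type="qwen2",
    hidden_size=128,
    num_hidden_layers=4,
    intermediate_size=256,
    num_attention_heads=4,
    num_key_value_heads=4,          # now required: no default and no fallback
    rms_norm_eps=1e-5,
    vocab_size=1000,
    max_position_embeddings=32768,  # new field, defaults to 32768
    rope_scaling={"type": "yarn", "factor": 4.0},
)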

@@ -54,16 +44,12 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
 
-        rope_scale = (
-            1 / args.rope_scaling["factor"]
-            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
-            else 1
-        )
-        self.rope = nn.RoPE(
+        self.rope = initialize_rope(
             head_dim,
-            traditional=args.rope_traditional,
             base=args.rope_theta,
-            scale=rope_scale,
+            traditional=args.rope_traditional,
+            scaling_config=args.rope_scaling,
+            max_position_embeddings=args.max_position_embeddings,
         )
 
     def __call__(
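
The inline linear-scaling arithmetic removed above is expected to be handled inside initialize_rope. Its default and linear branches are not part of this diff, so the following call-site sketch rests on the assumption that, as in mlx-lm's rope_utils, a missing or "linear" scaling_config falls back to a plain nn.RoPE with scale = 1 / factor; the import path and values are illustrative.

# Assumption: initialize_rope falls back to nn.RoPE for linear / no scaling.
from mlx_lm.models.rope_utils import initialize_rope

rope = initialize_rope(
    dims=64,                                           # head_dim
    base=1000000,                                      # args.rope_theta
    traditional=False,
    scaling_config={"type": "linear", "factor": 2.0},  # old code used scale = 1 / 2.0
    max_position_embeddings=32768,
)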

In rope_utils (shared RoPE helpers):

@@ -1,5 +1,6 @@
 # Copyright © 2023-2024 Apple Inc.
 
+import math
 from typing import Optional
 
 import mlx.core as mx

@@ -61,6 +62,78 @@ class Llama3RoPE(nn.Module):
         )
 
 
+class YarnRoPE(nn.Module):
+    def __init__(
+        self,
+        dims,
+        traditional=False,
+        max_position_embeddings=2048,
+        base=10000,
+        scaling_factor=1.0,
+        original_max_position_embeddings=4096,
+        beta_fast=32,
+        beta_slow=1,
+        mscale=1,
+        mscale_all_dim=0,
+    ):
+        super().__init__()
+
+        def yarn_find_correction_dim(num_rotations):
+            return (
+                dims
+                * math.log(
+                    original_max_position_embeddings / (num_rotations * 2 * math.pi)
+                )
+            ) / (2 * math.log(base))
+
+        def yarn_find_correction_range():
+            low = math.floor(yarn_find_correction_dim(beta_fast))
+            high = math.ceil(yarn_find_correction_dim(beta_slow))
+            return max(low, 0), min(high, dims - 1)
+
+        def yarn_get_mscale(scale=1, mscale=1):
+            if scale <= 1:
+                return 1.0
+            return 0.1 * mscale * math.log(scale) + 1.0
+
+        def yarn_linear_ramp_mask(min_val, max_val, dim):
+            if min_val == max_val:
+                max_val += 0.001  # Prevent singularity
+
+            linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (
+                max_val - min_val
+            )
+            return mx.clip(linear_func, 0, 1)
+
+        self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
+            scaling_factor, mscale_all_dim
+        )
+        freq_extra = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
+        freq_inter = scaling_factor * base ** (
+            mx.arange(0, dims, 2, dtype=mx.float32) / dims
+        )
+        low, high = yarn_find_correction_range()
+        freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dims // 2)
+        self._freqs = (freq_inter * freq_extra) / (
+            freq_inter * freq_mask + freq_extra * (1 - freq_mask)
+        )
+        self.dims = dims
+        self.traditional = traditional
+
+    def __call__(self, x, offset=0):
+        if self.mscale != 1.0:
+            x[..., : self.dims] = self.mscale * x[..., : self.dims]
+        return mx.fast.rope(
+            x,
+            self.dims,
+            traditional=self.traditional,
+            base=None,
+            scale=1.0,
+            offset=offset,
+            freqs=self._freqs,
+        )
+
+
 def initialize_rope(
     dims,
     base,
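
The new class can be exercised on its own before it is wired into a model. This is an illustrative smoke test, not part of the commit; the (batch, heads, sequence, head_dim) layout, the parameter values, and the mlx_lm.models.rope_utils import path are assumptions.

# Hypothetical standalone check of YarnRoPE.
import mlx.core as mx
from mlx_lm.models.rope_utils import YarnRoPE

rope = YarnRoPE(
    dims=64,
    base=1000000,
    max_position_embeddings=131072,
    scaling_factor=4.0,
    original_max_position_embeddings=32768,
)

x = mx.random.normal((1, 4, 16, 64))  # (batch, heads, seq_len, head_dim)
y = rope(x, offset=0)
print(y.shape)                        # rotation preserves the shape: (1, 4, 16, 64)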

@@ -87,5 +160,25 @@ def initialize_rope(
             base=base,
             scaling_config=scaling_config,
         )
+    elif rope_type == "yarn":
+        scaling_factor = scaling_config["factor"]
+        rope_kwargs = {
+            key: scaling_config[key]
+            for key in [
+                "original_max_position_embeddings",
+                "beta_fast",
+                "beta_slow",
+                "mscale",
+                "mscale_all_dim",
+            ]
+            if key in scaling_config
+        }
+        return YarnRoPE(
+            dims=dims,
+            max_position_embeddings=max_position_embeddings,
+            traditional=traditional,
+            base=base,
+            **rope_kwargs,
+        )
     else:
         raise ValueError(f"Unsupported RoPE type {rope_type}")
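
For reference, the kind of scaling config that takes the new branch mirrors the rope_scaling entry of a long-context model configuration. The values below are hypothetical; per this hunk, only "factor" is read unconditionally and the listed optional keys are forwarded to YarnRoPE when present, while the lookup of the type string itself happens earlier in initialize_rope and is not shown here.

# Hypothetical config routed to the yarn branch; import path is assumed.
from mlx_lm.models.rope_utils import initialize_rope

scaling_config = {
    "type": "yarn",
    "factor": 4.0,
    "original_max_position_embeddings": 32768,
    "beta_fast": 32,
    "beta_slow": 1,
}

rope = initialize_rope(
    dims=128,
    base=1000000,
    traditional=False,
    scaling_config=scaling_config,
    max_position_embeddings=131072,
)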

In the model tests (TestModels):

@@ -336,6 +336,7 @@ class TestModels(unittest.TestCase):
             num_hidden_layers=4,
             intermediate_size=2048,
             num_attention_heads=4,
+            num_key_value_heads=4,
             rms_norm_eps=1e-5,
             vocab_size=10_000,
         )
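
The test now passes num_key_value_heads explicitly because the field lost both its default and the __post_init__ fallback to num_attention_heads. A tiny standalone sketch of why omitting it fails, using a stand-in dataclass rather than the real ModelArgs:

# Stand-in for the relevant part of ModelArgs; not the real class.
from dataclasses import dataclass

@dataclass
class Args:
    num_attention_heads: int
    num_key_value_heads: int  # no default anymore

Args(num_attention_heads=4, num_key_value_heads=4)  # constructs fine
try:
    Args(num_attention_heads=4)
except TypeError as err:
    print(err)  # missing required argument: 'num_key_value_heads'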