Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)
Su-RoPE (Rotary Position Embedding) for Phi-3 (#813)

* Su-RoPE
* nits
* Update su_rope.py
* Update su_rope.py

  Per GPT4: "The error TypeError: 'type' object is not subscriptable is caused by using the type hint list[float] in a version of Python that does not support it. This syntax is only available in Python 3.9 and later."

* Ran isort

---------

Co-authored-by: Awni Hannun <awni@apple.com>
parent a54dfd698e
commit fda41545a6
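For context on the GPT4 note in the commit message, here is a minimal sketch of the type-hint pitfall it describes; the function below is hypothetical, not part of the commit:

    from typing import List  # portable to Python 3.8 and earlier

    # On Python < 3.9, the annotation list[float] raises
    # TypeError: 'type' object is not subscriptable at function definition
    # time; typing.List[float] works on older versions as well.
    def apply_factors(factors: List[float], scale: float) -> List[float]:
        return [scale * f for f in factors]

    print(apply_factors([1.0, 2.0], 0.5))  # [0.5, 1.0]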
llms/mlx_lm/models/phi3.py

@@ -5,6 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
+from .su_rope import SuScaledRotaryEmbedding
 
 
 @dataclass
@@ -20,6 +21,8 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 10000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    max_position_embeddings: int = 131072
+    original_max_position_embeddings: int = 4096
 
     def __post_init__(self):
         if self.num_key_value_heads is None:
@@ -30,9 +33,9 @@ class ModelArgs(BaseModelArgs):
             if not all(key in self.rope_scaling for key in required_keys):
                 raise ValueError(f"rope_scaling must contain keys {required_keys}")
 
-            if self.rope_scaling["type"] != "linear":
+            if self.rope_scaling["type"] not in ["su", "linear"]:
                 print(
-                    "[WARNING] rope_scaling 'type' currently only supports 'linear' setting rope scaling to false."
+                    "[WARNING] rope_scaling 'type' currently only supports 'linear' and 'su'; setting rope scaling to false."
                 )
                 self.rope_scaling = None
 
@@ -53,17 +56,27 @@ class Attention(nn.Module):
         self.qkv_proj = nn.Linear(dim, op_size, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
 
-        rope_scale = (
-            1 / args.rope_scaling["factor"]
-            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
-            else 1
-        )
-        self.rope = nn.RoPE(
-            head_dim,
-            traditional=args.rope_traditional,
-            base=args.rope_theta,
-            scale=rope_scale,
-        )
+        rope_scale = 1.0
+        if args.rope_scaling and args.rope_scaling["type"] == "su":
+            self.rope = SuScaledRotaryEmbedding(
+                head_dim,
+                traditional=False,
+                base=args.rope_theta,
+                scale=rope_scale,
+                max_position_embeddings=args.max_position_embeddings,
+                original_max_position_embeddings=args.original_max_position_embeddings,
+                short_factor=args.rope_scaling["short_factor"],
+                long_factor=args.rope_scaling["long_factor"],
+            )
+        else:
+            if args.rope_scaling and args.rope_scaling["type"] == "linear":
+                rope_scale = 1 / args.rope_scaling["factor"]
+            self.rope = nn.RoPE(
+                head_dim,
+                traditional=args.rope_traditional,
+                base=args.rope_theta,
+                scale=rope_scale,
+            )
 
     def __call__(
         self,
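To make the new dispatch concrete, here is a small self-contained sketch of which branch a given rope_scaling config selects; the dict values and the pick_rope helper are illustrative, not taken from the commit or a real checkpoint:

    # Illustrative configs: Phi-3-style "su" scaling carries per-frequency
    # factor lists (typically head_dim // 2 entries); "linear" a single factor.
    su_scaling = {
        "type": "su",
        "short_factor": [1.0] * 48,
        "long_factor": [1.5] * 48,
    }
    linear_scaling = {"type": "linear", "factor": 4.0}

    def pick_rope(rope_scaling):
        # Mirrors the logic above: "su" keeps rope_scale = 1.0 and builds
        # SuScaledRotaryEmbedding; "linear" uses nn.RoPE with scale 1 / factor.
        if rope_scaling and rope_scaling["type"] == "su":
            return "SuScaledRotaryEmbedding", 1.0
        if rope_scaling and rope_scaling["type"] == "linear":
            return "nn.RoPE", 1 / rope_scaling["factor"]
        return "nn.RoPE", 1.0

    print(pick_rope(su_scaling))      # ('SuScaledRotaryEmbedding', 1.0)
    print(pick_rope(linear_scaling))  # ('nn.RoPE', 0.25)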
llms/mlx_lm/models/su_rope.py (new file)

@@ -0,0 +1,77 @@
+import math
+from typing import List, Union
+
+import mlx.core as mx
+
+
+class SuScaledRotaryEmbedding:
+    def __init__(
+        self,
+        dims: int,
+        traditional: bool = False,
+        base: float = 10000.0,
+        scale: float = 1.0,
+        max_position_embeddings: int = 131072,
+        original_max_position_embeddings: int = 4096,
+        short_factor: Union[List[float], float] = 1.0,
+        long_factor: Union[List[float], float] = 1.0,
+    ):
+        """
+        Su-scaled Rotary Embedding layer for Phi-3 models.
+
+        Args:
+            dims (int): The feature dimensions to be rotated.
+            traditional (bool, optional): Unused. Default: ``False``.
+            base (float, optional): Base for the exponential scaling.
+            scale (float, optional): The scale used to scale the positions.
+              Default: ``1.0``.
+            max_position_embeddings (int, optional): The maximum sequence
+              length that the scaled model supports. Default: ``131072``.
+            original_max_position_embeddings (int, optional): The maximum
+              sequence length that this model was trained with. This is used to
+              determine the size of the original RoPE embeddings when using long
+              scaling. Default: ``4096``.
+            short_factor (float or list[float], optional): List of scaling
+              factors for sequences of length less than
+              ``original_max_position_embeddings``. Default: ``1.0``.
+            long_factor (float or list[float], optional): List of scaling
+              factors for sequences of length greater than
+              ``original_max_position_embeddings``. Default: ``1.0``.
+        """
+        self.inv_freq_short = 1.0 / (
+            mx.array(short_factor, dtype=mx.float32)
+            * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
+        )
+        self.inv_freq_long = 1.0 / (
+            scale
+            * mx.array(long_factor, dtype=mx.float32)
+            * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
+        )
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.scaling_factor = math.sqrt(
+            1
+            + math.log(max_position_embeddings / original_max_position_embeddings)
+            / math.log(original_max_position_embeddings)
+        )
+
+    def _get_cos_sin(self, offset, L):
+        position_ids = mx.arange(offset, offset + L, dtype=mx.float32)
+        inv_freq = (
+            self.inv_freq_long
+            if (offset + L) > self.original_max_position_embeddings
+            else self.inv_freq_short
+        )
+        freqs = position_ids[:, None] * inv_freq[None, :]
+        emb = mx.concatenate([freqs, freqs], axis=-1)
+        cos = mx.cos(emb) * self.scaling_factor
+        sin = mx.sin(emb) * self.scaling_factor
+        return cos, sin
+
+    def __call__(self, x, offset: int = 0):
+        def _rotate_half(_x):
+            midpoint = _x.shape[-1] // 2
+            x1, x2 = _x[..., :midpoint], _x[..., midpoint:]
+            return mx.concatenate([-x2, x1], axis=-1)
+
+        cos, sin = self._get_cos_sin(offset, x.shape[2])
+        return (x * cos) + (_rotate_half(x) * sin)
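A minimal usage sketch of the new layer, assuming Phi-3-mini-like dimensions; the factor values and import path are illustrative:

    import math

    import mlx.core as mx
    from mlx_lm.models.su_rope import SuScaledRotaryEmbedding  # path assumed

    dims = 96  # e.g. a Phi-3-mini head dimension
    rope = SuScaledRotaryEmbedding(
        dims,
        base=10000.0,
        max_position_embeddings=131072,
        original_max_position_embeddings=4096,
        short_factor=[1.0] * (dims // 2),  # illustrative per-frequency factors
        long_factor=[1.5] * (dims // 2),
    )

    # Input is (batch, n_heads, seq_len, head_dim); __call__ reads seq_len
    # from x.shape[2] and broadcasts cos/sin over batch and heads.
    x = mx.random.normal((1, 32, 10, dims))
    y_short = rope(x)              # offset + L <= 4096 -> short factors
    y_long = rope(x, offset=5000)  # offset + L > 4096  -> long factors

    # The constant attention-temperature term baked into cos and sin:
    # sqrt(1 + ln(131072 / 4096) / ln(4096)) ~= 1.19
    print(math.sqrt(1 + math.log(131072 / 4096) / math.log(4096)))

Note that the factor set is chosen from offset + L alone, so positions past original_max_position_embeddings are rotated with the long factors.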