From fda41545a6d85f951a6967a1002e8bef1e9f436b Mon Sep 17 00:00:00 2001
From: JosefAlbers <146810011+JosefAlbers@users.noreply.github.com>
Date: Tue, 11 Jun 2024 22:20:04 +0900
Subject: [PATCH] Su-RoPE(Rotary Position Embedding) for Phi-3 (#813)

* Su-RoPE

* nits

* Update su_rope.py

* Update su_rope.py

Per GPT4: "The error TypeError: 'type' object is not subscriptable is
caused by using the type hint list[float] in a version of Python that
does not support it. This syntax is only available in Python 3.9 and
later."

* Ran isort

---------

Co-authored-by: Awni Hannun
---
 llms/mlx_lm/models/phi3.py    | 39 +++++++++++------
 llms/mlx_lm/models/su_rope.py | 79 +++++++++++++++++++++++++++++++++++
 2 files changed, 105 insertions(+), 13 deletions(-)
 create mode 100644 llms/mlx_lm/models/su_rope.py

diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py
index 3282dff2..b30456fd 100644
--- a/llms/mlx_lm/models/phi3.py
+++ b/llms/mlx_lm/models/phi3.py
@@ -5,6 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
+from .su_rope import SuScaledRotaryEmbedding
 
 
 @dataclass
@@ -20,6 +21,8 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 10000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    max_position_embeddings: int = 131072
+    original_max_position_embeddings: int = 4096
 
     def __post_init__(self):
         if self.num_key_value_heads is None:
@@ -30,9 +33,9 @@ class ModelArgs(BaseModelArgs):
             if not all(key in self.rope_scaling for key in required_keys):
                 raise ValueError(f"rope_scaling must contain keys {required_keys}")
 
-            if self.rope_scaling["type"] != "linear":
+            if self.rope_scaling["type"] not in ["su", "linear"]:
                 print(
-                    "[WARNING] rope_scaling 'type' currently only supports 'linear' setting rope scaling to false."
+                    "[WARNING] rope_scaling 'type' currently only supports 'linear' and 'su'; setting rope scaling to false."
                 )
                 self.rope_scaling = None
 
@@ -53,17 +56,27 @@ class Attention(nn.Module):
         self.qkv_proj = nn.Linear(dim, op_size, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
 
-        rope_scale = (
-            1 / args.rope_scaling["factor"]
-            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
-            else 1
-        )
-        self.rope = nn.RoPE(
-            head_dim,
-            traditional=args.rope_traditional,
-            base=args.rope_theta,
-            scale=rope_scale,
-        )
+        rope_scale = 1.0
+        if args.rope_scaling and args.rope_scaling["type"] == "su":
+            self.rope = SuScaledRotaryEmbedding(
+                head_dim,
+                traditional=False,
+                base=args.rope_theta,
+                scale=rope_scale,
+                max_position_embeddings=args.max_position_embeddings,
+                original_max_position_embeddings=args.original_max_position_embeddings,
+                short_factor=args.rope_scaling["short_factor"],
+                long_factor=args.rope_scaling["long_factor"],
+            )
+        else:
+            if args.rope_scaling and args.rope_scaling["type"] == "linear":
+                rope_scale = 1 / args.rope_scaling["factor"]
+            self.rope = nn.RoPE(
+                head_dim,
+                traditional=args.rope_traditional,
+                base=args.rope_theta,
+                scale=rope_scale,
+            )
 
     def __call__(
         self,
diff --git a/llms/mlx_lm/models/su_rope.py b/llms/mlx_lm/models/su_rope.py
new file mode 100644
index 00000000..cdf6ceaf
--- /dev/null
+++ b/llms/mlx_lm/models/su_rope.py
@@ -0,0 +1,79 @@
+import math
+from typing import List, Union
+
+import mlx.core as mx
+
+
+class SuScaledRotaryEmbedding:
+    def __init__(
+        self,
+        dims: int,
+        traditional: bool = False,
+        base: float = 10000.0,
+        scale: float = 1.0,
+        max_position_embeddings: int = 131072,
+        original_max_position_embeddings: int = 4096,
+        short_factor: Union[List[float], float] = 1.0,
+        long_factor: Union[List[float], float] = 1.0,
+    ):
+        """
+        Phi3Su Scaled Rotary Embedding layer for Phi-3 models.
+
+        Args:
+            dims (int): The feature dimensions to be rotated.
+            traditional (bool, optional): Unused. Default: ``False``.
+            base (float, optional): Base for the exponential scaling.
+            scale (float, optional): The scale used to scale the positions.
+              Default: ``1.0``.
+            max_position_embeddings (int, optional): The maximum sequence
+              length supported with long scaling, i.e. the extended context
+              length. Used to compute the scaling of the cos/sin embeddings.
+              Default: ``131072``.
+            original_max_position_embeddings (int, optional): The maximum
+              sequence length that this model was trained with. This is used to
+              determine the size of the original RoPE embeddings when using long
+              scaling. Default: ``4096``.
+            short_factor (float or list[float], optional): List of scaling
+              factors for sequences shorter than
+              ``original_max_position_embeddings``. Default: ``1.0``.
+            long_factor (float or list[float], optional): List of scaling
+              factors for sequences longer than
+              ``original_max_position_embeddings``. Default: ``1.0``.
+ """ + self.inv_freq_short = 1.0 / ( + mx.array(short_factor, dtype=mx.float32) + * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) + ) + self.inv_freq_long = 1.0 / ( + scale + * mx.array(long_factor, dtype=mx.float32) + * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) + ) + self.original_max_position_embeddings = original_max_position_embeddings + self.scaling_factor = math.sqrt( + 1 + + math.log(max_position_embeddings / original_max_position_embeddings) + / math.log(original_max_position_embeddings) + ) + + def _get_cos_sin(self, offset, L): + position_ids = mx.arange(offset, offset + L, dtype=mx.float32) + inv_freq = ( + self.inv_freq_long + if (offset + L) > self.original_max_position_embeddings + else self.inv_freq_short + ) + freqs = position_ids[:, None] * inv_freq[None, :] + emb = mx.concatenate([freqs, freqs], axis=-1) + cos = mx.cos(emb) * self.scaling_factor + sin = mx.sin(emb) * self.scaling_factor + return cos, sin + + def __call__(self, x, offset: int = 0): + def _rotate_half(_x): + midpoint = _x.shape[-1] // 2 + x1, x2 = _x[..., :midpoint], _x[..., midpoint:] + return mx.concatenate([-x2, x1], axis=-1) + + cos, sin = self._get_cos_sin(offset, x.shape[2]) + return (x * cos) + (_rotate_half(x) * sin)