Su-RoPE(Rotary Position Embedding) for Phi-3 (#813)

* Su-RoPE

* nits

* Update su_rope.py

* Update su_rope.py

Per GPT4: "The error TypeError: 'type' object is not subscriptable is caused by using the type hint list[float] in a version of Python that does not support it. This syntax is only available in Python 3.9 and later."

* Ran isort

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
JosefAlbers
2024-06-11 22:20:04 +09:00
committed by GitHub
parent a54dfd698e
commit fda41545a6
2 changed files with 105 additions and 13 deletions

View File

@@ -5,6 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .su_rope import SuScaledRotaryEmbedding
@dataclass
@@ -20,6 +21,8 @@ class ModelArgs(BaseModelArgs):
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
max_position_embeddings: int = 131072
original_max_position_embeddings: int = 4096
def __post_init__(self):
if self.num_key_value_heads is None:
@@ -30,9 +33,9 @@ class ModelArgs(BaseModelArgs):
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] != "linear":
if self.rope_scaling["type"] not in ["su", "linear"]:
print(
"[WARNING] rope_scaling 'type' currently only supports 'linear' setting rope scaling to false."
"[WARNING] rope_scaling 'type' currently only supports 'linear' and 'su'; setting rope scaling to false."
)
self.rope_scaling = None
@@ -53,17 +56,27 @@ class Attention(nn.Module):
self.qkv_proj = nn.Linear(dim, op_size, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
rope_scale = (
1 / args.rope_scaling["factor"]
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 1
)
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
)
rope_scale = 1.0
if args.rope_scaling and args.rope_scaling["type"] == "su":
self.rope = SuScaledRotaryEmbedding(
head_dim,
traditional=False,
base=args.rope_theta,
scale=rope_scale,
max_position_embeddings=args.max_position_embeddings,
original_max_position_embeddings=args.original_max_position_embeddings,
short_factor=args.rope_scaling["short_factor"],
long_factor=args.rope_scaling["long_factor"],
)
else:
if args.rope_scaling and args.rope_scaling["type"] == "linear":
rope_scale = 1 / args.rope_scaling["factor"]
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
)
def __call__(
self,