diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py
index e4a8cc7d..dd2d6d82 100644
--- a/llms/mlx_lm/models/phi3.py
+++ b/llms/mlx_lm/models/phi3.py
@@ -33,9 +33,9 @@ class ModelArgs(BaseModelArgs):
             if not all(key in self.rope_scaling for key in required_keys):
                 raise ValueError(f"rope_scaling must contain keys {required_keys}")
 
-            if self.rope_scaling["type"] not in ["su", "linear"]:
+            if self.rope_scaling["type"] not in ["longrope", "su", "linear"]:
                 print(
-                    "[WARNING] rope_scaling 'type' currently only supports 'linear' and 'su'; setting rope scaling to false."
+                    "[WARNING] rope_scaling 'type' currently only supports 'linear', 'su', and 'longrope'; setting rope scaling to false."
                 )
                 self.rope_scaling = None
 
@@ -58,7 +58,7 @@ class Attention(nn.Module):
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
 
         rope_scale = 1.0
-        if args.rope_scaling and args.rope_scaling["type"] == "su":
+        if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
            self.rope = SuScaledRotaryEmbedding(
                 head_dim,
                 traditional=False,
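
For reference, a minimal sketch (not part of the patch) of a rope_scaling dict that passes validation once "longrope" is accepted. The factor values are illustrative placeholders, and the key set mirrors the required-keys check in __post_init__:

```python
# Illustrative only: a Phi-3-style rope_scaling config using the newly
# accepted "longrope" type. Factor values are placeholders, not real
# Phi-3 scaling factors.
rope_scaling = {
    "type": "longrope",
    "short_factor": [1.0] * 48,  # hypothetical per-dim short-context factors
    "long_factor": [1.5] * 48,   # hypothetical per-dim long-context factors
}

required_keys = {"long_factor", "short_factor", "type"}
assert all(key in rope_scaling for key in required_keys)

# Before this patch, "longrope" fell through to the warning branch and
# rope_scaling was reset to None; with the patch it is treated like "su"
# and routed to SuScaledRotaryEmbedding in Attention.
assert rope_scaling["type"] in ["longrope", "su", "linear"]
```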