diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py
index f8facdb1..112ade7d 100644
--- a/llms/mlx_lm/models/phi3.py
+++ b/llms/mlx_lm/models/phi3.py
@@ -59,19 +59,17 @@ class Attention(nn.Module):
         self.qkv_proj = nn.Linear(dim, op_size, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
 
-        rope_scale = 1.0
         if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
             self.rope = SuScaledRotaryEmbedding(
                 head_dim,
-                traditional=False,
                 base=args.rope_theta,
-                scale=rope_scale,
                 max_position_embeddings=args.max_position_embeddings,
                 original_max_position_embeddings=args.original_max_position_embeddings,
                 short_factor=args.rope_scaling["short_factor"],
                 long_factor=args.rope_scaling["long_factor"],
             )
         else:
+            rope_scale = 1.0
             if args.rope_scaling and args.rope_scaling["type"] == "linear":
                 assert isinstance(args.rope_scaling["factor"], float)
                 rope_scale = 1 / args.rope_scaling["factor"]
diff --git a/llms/mlx_lm/models/su_rope.py b/llms/mlx_lm/models/su_rope.py
index 0efa5a0c..c75e9610 100644
--- a/llms/mlx_lm/models/su_rope.py
+++ b/llms/mlx_lm/models/su_rope.py
@@ -11,9 +11,7 @@ class SuScaledRotaryEmbedding(nn.Module):
     def __init__(
         self,
         dims: int,
-        traditional: bool = False,
         base: float = 10000.0,
-        scale: float = 1.0,
         max_position_embeddings: int = 131072,
         original_max_position_embeddings: int = 4096,
         short_factor: Union[List[float], float] = 1.0,
@@ -24,10 +22,7 @@ class SuScaledRotaryEmbedding(nn.Module):
 
         Args:
             dims (int): The feature dimensions to be rotated.
-            traditional (bool, optional): Unused. Default: ``False``.
             base (int, optional): Base for the exponential scaling.
-            scale (float, optional): The scale used to scale the positions.
-              Default: ``1.0``.
             max_position_embeddings (int, optional): The maximum sequence
               length that this model was trained with. This is used to determine
               the size of the original RoPE embeddings when using long scaling.
@@ -44,14 +39,9 @@ class SuScaledRotaryEmbedding(nn.Module):
               ``original_max_position_embeddings``. Default: ``1.0``.
         """
         super().__init__()
-        self._short_freqs = mx.array(short_factor, dtype=mx.float32) * base ** (
-            mx.arange(0, dims, 2, dtype=mx.float32) / dims
-        )
-        self._long_freqs = (
-            scale
-            * mx.array(long_factor, dtype=mx.float32)
-            * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
-        )
+        freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
+        self._short_freqs = mx.array(short_factor, dtype=mx.float32) * freqs
+        self._long_freqs = mx.array(long_factor, dtype=mx.float32) * freqs
         self.original_max_position_embeddings = original_max_position_embeddings
         self.scale = math.sqrt(
             1
@@ -66,11 +56,11 @@ class SuScaledRotaryEmbedding(nn.Module):
             else self._short_freqs
         )
         return mx.fast.rope(
-            x,
+            self.scale * x,
             x.shape[-1],
             traditional=False,
             base=None,
-            scale=self.scale,
+            scale=1.0,
             offset=offset,
             freqs=freqs,
         )