# mlx-examples/llms/mlx_lm/models/su_rope.py

# Copyright © 2023-2024 Apple Inc.
import math
from typing import List, Union
import mlx.core as mx
import mlx.nn as nn
class SuScaledRotaryEmbedding(nn.Module):
    """Rotary position embedding with "su" frequency scaling for Phi-3 models.

    The standard RoPE frequencies are multiplied element-wise by
    ``long_factor`` and the input is multiplied by a constant scale derived
    from the ratio of the two context lengths before it is rotated.
    """

    def __init__(
        self,
        dims: int,
        base: float = 10000.0,
        max_position_embeddings: int = 131072,
        original_max_position_embeddings: int = 4096,
        short_factor: Union[List[float], float] = 1.0,
        long_factor: Union[List[float], float] = 1.0,
    ):
        """
        Phi3Su Scaled Rotary Embedding layer for Phi-3 models.

        Args:
            dims (int): The feature dimensions to be rotated.
            base (float, optional): Base for the exponential scaling.
                Default: ``10000.0``.
            max_position_embeddings (int, optional): The extended maximum
                sequence length. Only its ratio to
                ``original_max_position_embeddings`` is used, to derive the
                constant applied to the input. Default: ``131072``.
            original_max_position_embeddings (int, optional): The maximum
                sequence length that this model was trained with. This is
                used to determine the size of the original RoPE embeddings
                when using long scaling. Default: ``4096``.
            short_factor (float or list[float], optional): Scaling factors
                for sequences of length lesser than
                ``original_max_position_embeddings``. Currently unused: this
                layer always applies ``long_factor``. Default: ``1.0``.
            long_factor (float or list[float], optional): Scaling factors
                multiplied element-wise into the ``dims // 2`` RoPE
                frequencies. Default: ``1.0``.
        """
        super().__init__()

        # One frequency per rotated pair of features.
        freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
        self._freqs = mx.array(long_factor, dtype=mx.float32) * freqs
        self.original_max_position_embeddings = original_max_position_embeddings

        # The correction only applies when the context is actually extended.
        # For a ratio <= 1 the logarithm is non-positive, which would shrink
        # the input (or make the square root argument negative), so fall back
        # to the identity scale in that case.
        ratio = max_position_embeddings / original_max_position_embeddings
        if ratio <= 1.0:
            self.scale = 1.0
        else:
            self.scale = math.sqrt(
                1 + math.log(ratio) / math.log(original_max_position_embeddings)
            )

    def __call__(self, x, offset: int = 0):
        """Scale ``x`` by ``self.scale`` and apply RoPE over its last axis.

        Args:
            x: Input array; the size of its last axis is passed to
                ``mx.fast.rope`` as the number of dimensions to rotate.
            offset (int, optional): Position offset forwarded to
                ``mx.fast.rope``. Default: ``0``.

        Returns:
            The result of ``mx.fast.rope`` on the scaled input.
        """
        return mx.fast.rope(
            self.scale * x,
            x.shape[-1],
            traditional=False,
            # Explicit frequencies are supplied instead of a base.
            base=None,
            scale=1.0,
            offset=offset,
            freqs=self._freqs,
        )