2024-08-17 06:28:39 +08:00
|
|
|
# Copyright © 2023-2024 Apple Inc.
|
|
|
|
|
2024-06-11 21:20:04 +08:00
|
|
|
import math
|
|
|
|
from typing import List, Union
|
|
|
|
|
|
|
|
import mlx.core as mx
|
2024-08-24 04:18:51 +08:00
|
|
|
import mlx.nn as nn
|
2024-06-11 21:20:04 +08:00
|
|
|
|
|
|
|
|
2024-08-24 04:18:51 +08:00
|
|
|
class SuScaledRotaryEmbedding(nn.Module):
    """Su-scaled rotary positional embedding layer for Phi-3 models."""

    def __init__(
        self,
        dims: int,
        base: float = 10000.0,
        max_position_embeddings: int = 131072,
        original_max_position_embeddings: int = 4096,
        short_factor: Union[List[float], float] = 1.0,
        long_factor: Union[List[float], float] = 1.0,
    ):
        """
        Phi3Su Scaled Rotary Embedding layer for Phi-3 models.

        Args:
            dims (int): The feature dimensions to be rotated.
            base (float, optional): Base for the exponential scaling.
              Default: ``10000.0``.
            max_position_embeddings (int, optional): The extended maximum
              sequence length. Here it is only used, relative to
              ``original_max_position_embeddings``, to compute the
              multiplier applied to the input before rotation.
              Default: ``131072``.
            original_max_position_embeddings (int, optional): The maximum
              sequence length before extension. It is stored on the module
              and used to compute the input multiplier. Default: ``4096``.
            short_factor (float or list[float], optional): Scaling factors
              for sequences no longer than
              ``original_max_position_embeddings``. Accepted for
              configuration compatibility; this implementation does not
              use it. Default: ``1.0``.
            long_factor (float or list[float], optional): Scaling factors
              multiplied element-wise into the rotary frequencies.
              Default: ``1.0``.
        """
        super().__init__()

        # One frequency per rotated feature pair: base ** (2i / dims).
        freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
        # Leading underscore matches the original attribute name; only the
        # long factors are folded into the frequencies.
        self._freqs = mx.array(long_factor, dtype=mx.float32) * freqs
        self.original_max_position_embeddings = original_max_position_embeddings

        # The multiplier only applies when the context is actually extended.
        # Without this guard a ratio below 1 yields a factor below 1 (or a
        # math domain error), where the reference formulation uses 1.0.
        extension_ratio = max_position_embeddings / original_max_position_embeddings
        if extension_ratio <= 1.0:
            self.scale = 1.0
        else:
            self.scale = math.sqrt(
                1
                + math.log(extension_ratio)
                / math.log(original_max_position_embeddings)
            )

    def __call__(self, x, offset: int = 0):
        """
        Scale ``x`` by ``self.scale`` and apply rotary embedding to it.

        Args:
            x: Input array. The size of its last axis is passed to
              ``mx.fast.rope`` as the number of rotated dimensions.
            offset (int, optional): Position offset forwarded to
              ``mx.fast.rope``. Default: ``0``.

        Returns:
            The result of ``mx.fast.rope`` on the scaled input.
        """
        # NOTE(review): the frequencies were built for ``dims`` features in
        # ``__init__``; this call assumes ``x.shape[-1] == dims`` -- confirm
        # against callers that use partial rotary dimensions.
        return mx.fast.rope(
            self.scale * x,
            x.shape[-1],
            traditional=False,
            base=None,  # frequencies are supplied explicitly via ``freqs``
            scale=1.0,
            offset=offset,
            freqs=self._freqs,
        )
|