# Copyright © 2023-2024 Apple Inc.

from typing import Optional

import mlx.core as mx
import mlx.nn as nn


class Llama3RoPE(nn.Module):
    """Rotary position embedding with Llama 3 style frequency scaling.

    Every rotated pair of feature dimensions gets a wavelength derived from
    ``base``. Pairs are then adjusted according to that wavelength:

    - Pairs whose wavelength is at least the low-frequency cutoff are
      stretched by ``factor``.
    - Pairs whose wavelength is at most the high-frequency cutoff are left
      unchanged.
    - Pairs strictly between the two cutoffs are interpolated smoothly
      between those extremes.

    The adjusted per-pair values are precomputed once and passed to
    ``mx.fast.rope`` via its ``freqs`` argument.
    """

    def __init__(
        self,
        dims: int,
        max_position_embeddings: Optional[int] = 2048,
        traditional: bool = False,
        base: float = 10000,
        scaling_config: Optional[dict] = None,
    ):
        """
        Args:
            dims: Number of feature dimensions to rotate.
            max_position_embeddings: Stored and shown in ``extra_repr``; it is
                not used in the frequency computation.
            traditional: Forwarded unchanged to ``mx.fast.rope``.
            base: Base used to derive the per-pair wavelengths.
            scaling_config: Dict with the required key ``"factor"`` and the
                optional keys below:

                - ``"low_freq_factor"`` (default 1.0)
                - ``"high_freq_factor"`` (default 4.0)
                - ``"original_max_position_embeddings"`` (default 8192)

        Raises:
            ValueError: If ``scaling_config`` is None.
            KeyError: If ``scaling_config`` has no ``"factor"`` entry.
        """
        super().__init__()
        if scaling_config is None:
            raise ValueError("Llama3RoPE requires a scaling_config with a 'factor'.")

        self.dims = dims
        self.max_position_embeddings = max_position_embeddings
        self.traditional = traditional

        factor = scaling_config["factor"]
        low_freq_factor = scaling_config.get("low_freq_factor", 1.0)
        high_freq_factor = scaling_config.get("high_freq_factor", 4.0)
        old_context_len = scaling_config.get(
            "original_max_position_embeddings",
            8192,
        )

        # Wavelength cutoffs separating the low / medium / high frequency bands.
        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor

        # base ** (2i / dims) for each rotated pair; multiplying by 2*pi gives
        # that pair's wavelength, so these values grow with the wavelength.
        freqs = base ** (mx.arange(0, dims, 2) / dims)
        wavelens = 2 * mx.pi * freqs

        # Long-wavelength pairs are stretched by the full factor. ">=" (not ">")
        # makes a pair sitting exactly on the cutoff get the same value the
        # smoothing formula yields there (smooth_factors == 0), so the mapping
        # is continuous.
        freqs = mx.where(wavelens >= low_freq_wavelen, freqs * factor, freqs)

        # Medium band is strict on both ends, so its entries were not touched
        # by the line above. The band is also empty when
        # low_freq_factor == high_freq_factor, which keeps any inf/NaN from
        # the division below out of the result.
        is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
        smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
            high_freq_factor - low_freq_factor
        )
        smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)

        # Leading underscore: MLX's default parameter filter skips "_"-prefixed
        # attributes, so this precomputed table is not treated as a trainable
        # parameter.
        self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)

    def extra_repr(self):
        return (
            f"{self.dims}, traditional={self.traditional}, "
            f"max_position_embeddings={self.max_position_embeddings}"
        )

    def __call__(self, x, offset: int = 0):
        """Apply the scaled rotary embedding to ``x``.

        Positions start at ``offset``.
        """
        return mx.fast.rope(
            x,
            self.dims,
            traditional=self.traditional,
            base=None,  # must be None when explicit freqs are supplied
            scale=1.0,
            offset=offset,
            freqs=self._freqs,
        )


def initialize_rope(
    dims,
    base,
    traditional,
    scaling_config: Optional[dict] = None,
    max_position_embeddings: Optional[int] = None,
):
    """Build the RoPE module selected by ``scaling_config``.

    How the variant is chosen:

    - Use ``scaling_config["type"]`` if it is truthy.
    - Otherwise use ``scaling_config["rope_type"]``.
    - Otherwise use ``"default"``.
    - With no ``scaling_config`` at all, ``"default"`` is used.

    Supported variants:

    - ``"default"``: ``nn.RoPE`` with ``scale=1.0``.
    - ``"linear"``: ``nn.RoPE`` with ``scale=1 / scaling_config["factor"]``.
    - ``"llama3"``: ``Llama3RoPE`` configured from ``scaling_config``.

    Raises:
        ValueError: If the selected RoPE type is not one of the above.
    """
    if scaling_config is not None:
        rope_type = scaling_config.get("type") or scaling_config.get(
            "rope_type", "default"
        )
    else:
        rope_type = "default"

    if rope_type in ["default", "linear"]:
        scale = 1 / scaling_config["factor"] if rope_type == "linear" else 1.0
        return nn.RoPE(dims, traditional=traditional, base=base, scale=scale)

    elif rope_type == "llama3":
        return Llama3RoPE(
            dims=dims,
            max_position_embeddings=max_position_embeddings,
            traditional=traditional,
            base=base,
            scaling_config=scaling_config,
        )
    else:
        raise ValueError(f"Unsupported RoPE type {rope_type}")