mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00

* Adds EXAONE architecture.
* nits + format
* format
* clean up and fix rope
* clean up and fix rope
---------
Co-authored-by: Awni Hannun <awni@apple.com>
92 lines
2.7 KiB
Python
# Copyright © 2023-2024 Apple Inc.

from typing import Optional

import mlx.core as mx
import mlx.nn as nn


class Llama3RoPE(nn.Module):
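    # Rotary position embedding with the "llama3" frequency-scaling scheme: the
    # per-dimension frequencies are rescaled once at construction and passed as
    # precomputed `freqs` to mx.fast.rope at call time.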

    def __init__(
        self,
        dims: int,
        max_position_embeddings: int = 2048,
        traditional: bool = False,
        base: float = 10000,
        scaling_config: dict = None,
    ):
        super().__init__()
        self.dims = dims
        self.max_position_embeddings = max_position_embeddings
        self.traditional = traditional

        factor = scaling_config["factor"]
        low_freq_factor = scaling_config.get("low_freq_factor", 1.0)
        high_freq_factor = scaling_config.get("high_freq_factor", 4.0)
        old_context_len = scaling_config.get(
            "original_max_position_embeddings",
            8192,
        )

        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor

        freqs = base ** (mx.arange(0, dims, 2) / dims)
        wavelens = 2 * mx.pi * freqs
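
        # Wavelengths longer than `low_freq_wavelen` are stretched by `factor`
        # (position interpolation); wavelengths shorter than `high_freq_wavelen`
        # are kept as-is, and the band in between is blended smoothly.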
        freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
        is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
        smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
            high_freq_factor - low_freq_factor
        )
        smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
        self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)

    def extra_repr(self):
        return (
            f"{self.dims}, traditional={self.traditional}, "
            f"max_position_embeddings={self.max_position_embeddings}"
        )

    def __call__(self, x, offset: int = 0):
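        # The precomputed `freqs` are supplied to mx.fast.rope, so `base` is
        # passed as None.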
        return mx.fast.rope(
            x,
            self.dims,
            traditional=self.traditional,
            base=None,
            scale=1.0,
            offset=offset,
            freqs=self._freqs,
        )


def initialize_rope(
    dims,
    base,
    traditional,
    scaling_config: Optional[dict] = None,
    max_position_embeddings: Optional[int] = None,
):
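    # The scaling config may store the RoPE variant under either a "type" or a
    # "rope_type" key; fall back to the default RoPE when neither is present.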
    if scaling_config is not None:
        rope_type = scaling_config.get("type") or scaling_config.get(
            "rope_type", "default"
        )
    else:
        rope_type = "default"

    if rope_type in ["default", "linear"]:
        scale = 1 / scaling_config["factor"] if rope_type == "linear" else 1.0
        return nn.RoPE(dims, traditional=traditional, base=base, scale=scale)

    elif rope_type == "llama3":
        return Llama3RoPE(
            dims=dims,
            max_position_embeddings=max_position_embeddings,
            traditional=traditional,
            base=base,
            scaling_config=scaling_config,
        )
    else:
        raise ValueError(f"Unsupported RoPE type {rope_type}")
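

# A minimal usage sketch (illustrative only, not part of the upstream module):
# the config values below roughly follow a Llama-3.1-style `rope_scaling` entry,
# and `head_dim` is a hypothetical name chosen for the example.
if __name__ == "__main__":
    head_dim = 128
    rope = initialize_rope(
        dims=head_dim,
        base=500000.0,
        traditional=False,
        scaling_config={
            "rope_type": "llama3",
            "factor": 8.0,
            "low_freq_factor": 1.0,
            "high_freq_factor": 4.0,
            "original_max_position_embeddings": 8192,
        },
        max_position_embeddings=131072,
    )
    # Shape convention: (batch, heads, sequence_length, head_dim).
    x = mx.zeros((1, 8, 16, head_dim))
    y = rope(x, offset=0)
    print(type(rope).__name__, y.shape)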