Adds EXAONE architecture. (#1145)

* Adds EXAONE architecture.

* nits + format

* format

* clean up and fix rope

* clean up and fix rope

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
n8programs
2024-12-09 10:58:25 -05:00
committed by GitHub
parent 893b3f085e
commit 5687d5b99b
6 changed files with 312 additions and 224 deletions

View File

@@ -7,6 +7,7 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope
@dataclass
@@ -32,117 +33,6 @@ class ModelArgs(BaseModelArgs):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
if not "factor" in self.rope_scaling:
raise ValueError(f"rope_scaling must contain 'factor'")
rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
"rope_type"
)
if rope_type is None:
raise ValueError(
f"rope_scaling must contain either 'type' or 'rope_type'"
)
if rope_type not in ["linear", "dynamic", "llama3"]:
raise ValueError(
"rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'"
)
class DynamicNTKScalingRoPE(nn.Module):
"""Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE."""
def __init__(
self,
dims: int,
max_position_embeddings: int = 2048,
traditional: bool = False,
base: float = 10000,
scale: float = 1.0,
rope_type: str = "default",
rope_scaling: dict = None,
):
super().__init__()
self.dims = dims
self.max_position_embeddings = max_position_embeddings
self.traditional = traditional
self.scale = scale
self.rope_type = rope_type
self.rope_scaling = rope_scaling
self.base = base
self.compute_freqs()
def compute_freqs(self):
if self.rope_type != "llama3":
self._freqs = None
return
factor = self.rope_scaling["factor"]
low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
old_context_len = self.rope_scaling.get(
"original_max_position_embeddings",
8192,
)
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims)
wavelens = 2 * mx.pi * freqs
freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
self.base = None
def extra_repr(self):
return (
f"{self.dims}, traditional={self.traditional}, "
f"max_position_embeddings={self.max_position_embeddings}, "
f"scaling_factor={self.scale}, rope_type={self.rope_type}"
)
def __call__(self, x, offset: int = 0):
return mx.fast.rope(
x,
self.dims,
traditional=self.traditional,
base=self.base,
scale=self.scale,
offset=offset,
freqs=self._freqs,
)
def initialize_rope(args: ModelArgs):
head_dim = args.head_dim or args.hidden_size // args.num_attention_heads
rope_scaling = args.rope_scaling
rope_type = "default"
rope_scale = 1.0
if rope_scaling is not None:
rope_type = (
rope_scaling.get("type") or rope_scaling.get("rope_type") or "default"
)
if rope_type == "linear":
rope_scale = 1 / rope_scaling["factor"]
elif rope_type == "llama3":
rope_scale = 1.0 # The scaling is handled internally for llama3
return DynamicNTKScalingRoPE(
dims=head_dim,
max_position_embeddings=args.max_position_embeddings,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
rope_type=rope_type,
rope_scaling=rope_scaling,
)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
@@ -165,7 +55,13 @@ class Attention(nn.Module):
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
self.rope = initialize_rope(args)
self.rope = initialize_rope(
self.head_dim,
args.rope_theta,
args.rope_traditional,
args.rope_scaling,
args.max_position_embeddings,
)
def __call__(
self,