diff --git a/llms/mlx_lm/models/exaone.py b/llms/mlx_lm/models/exaone.py
new file mode 100644
index 00000000..eaed5dd8
--- /dev/null
+++ b/llms/mlx_lm/models/exaone.py
@@ -0,0 +1,163 @@
+# Copyright © 2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .rope_utils import initialize_rope
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    model_type: str
+    hidden_size: int
+    num_layers: int
+    intermediate_size: int
+    num_attention_heads: int
+    vocab_size: int
+    rope_theta: float
+    layer_norm_epsilon: float
+    num_key_value_heads: int
+    head_dim: Optional[int] = None
+    max_position_embeddings: Optional[int] = None
+    rope_traditional: bool = False
+    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    tie_word_embeddings: bool = True
+    attention_bias: bool = False
+    mlp_bias: bool = False
+
+
+class AttentionModule(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        dim = args.hidden_size
+        self.n_heads = n_heads = args.num_attention_heads
+        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+        self.head_dim = head_dim = args.head_dim or (dim // n_heads)
+        self.scale = head_dim**-0.5
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
+        self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
+
+        self.rope = initialize_rope(
+            self.head_dim,
+            args.rope_theta,
+            args.rope_traditional,
+            args.rope_scaling,
+            args.max_position_embeddings,
+        )
+
+    def __call__(
+        self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None
+    ) -> mx.array:
+        B, L, D = x.shape
+        q = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+        k = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+        v = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            q = self.rope(q, offset=cache.offset)
+            k = self.rope(k, offset=cache.offset)
+            k, v = cache.update_and_fetch(k, v)
+        else:
+            q = self.rope(q)
+            k = self.rope(k)
+
+        out = scaled_dot_product_attention(
+            q, k, v, cache=cache, scale=self.scale, mask=mask
+        )
+        out = out.transpose(0, 2, 1, 3).reshape(B, L, D)
+        return self.out_proj(out)
+
+
+class Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.attention = AttentionModule(args)
+
+
+class MLP(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        dim = args.hidden_size
+        hidden_dim = args.intermediate_size
+        self.c_fc_0 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias)
+        self.c_fc_1 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias)
+        self.c_proj = nn.Linear(hidden_dim, dim, bias=args.mlp_bias)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return self.c_proj(nn.silu(self.c_fc_0(x)) * self.c_fc_1(x))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+        self.attn = Attention(args)
+        self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+        self.mlp = MLP(args)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        h = x + self.attn.attention(self.ln_1(x), mask, cache)
+        out = h + self.mlp(self.ln_2(h))
+        return out
+
+
+class ExaoneModel(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.h = [TransformerBlock(args) for _ in range(args.num_layers)]
+        self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+    ):
+        h = self.wte(inputs)
+        mask = create_attention_mask(h, cache)
+
+        if cache is None:
+            cache = [None] * len(self.h)
+
+        for layer, c in zip(self.h, cache):
+            h = layer(h, mask, cache=c)
+
+        return self.ln_f(h)
+
+
+class Model(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.model_type = args.model_type
+        self.transformer = ExaoneModel(args)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+    ):
+        out = self.transformer(inputs, cache)
+        if self.args.tie_word_embeddings:
+            out = self.transformer.wte.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return out
+
+    @property
+    def layers(self):
+        return self.transformer.h
diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index 438278e5..290cb83e 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -7,6 +7,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .rope_utils import initialize_rope
 
 
 @dataclass
@@ -32,117 +33,6 @@ class ModelArgs(BaseModelArgs):
         if self.num_key_value_heads is None:
             self.num_key_value_heads = self.num_attention_heads
 
-        if self.rope_scaling:
-            if not "factor" in self.rope_scaling:
-                raise ValueError(f"rope_scaling must contain 'factor'")
-            rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
-                "rope_type"
-            )
-            if rope_type is None:
-                raise ValueError(
-                    f"rope_scaling must contain either 'type' or 'rope_type'"
-                )
-            if rope_type not in ["linear", "dynamic", "llama3"]:
-                raise ValueError(
-                    "rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'"
-                )
-
-
-class DynamicNTKScalingRoPE(nn.Module):
-    """Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE."""
-
-    def __init__(
-        self,
-        dims: int,
-        max_position_embeddings: int = 2048,
-        traditional: bool = False,
-        base: float = 10000,
-        scale: float = 1.0,
-        rope_type: str = "default",
-        rope_scaling: dict = None,
-    ):
-        super().__init__()
-        self.dims = dims
-        self.max_position_embeddings = max_position_embeddings
-        self.traditional = traditional
-        self.scale = scale
-        self.rope_type = rope_type
-        self.rope_scaling = rope_scaling
-        self.base = base
-        self.compute_freqs()
-
-    def compute_freqs(self):
-        if self.rope_type != "llama3":
-            self._freqs = None
-            return
-        factor = self.rope_scaling["factor"]
-        low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
-        high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
-        old_context_len = self.rope_scaling.get(
-            "original_max_position_embeddings",
-            8192,
-        )
-
-        low_freq_wavelen = old_context_len / low_freq_factor
-        high_freq_wavelen = old_context_len / high_freq_factor
-
-        freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims)
-        wavelens = 2 * mx.pi * freqs
-
-        freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
-        is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
-        smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
-            high_freq_factor - low_freq_factor
-        )
-        smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
-        self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
-        self.base = None
-
-    def extra_repr(self):
-        return (
-            f"{self.dims}, traditional={self.traditional}, "
-            f"max_position_embeddings={self.max_position_embeddings}, "
-            f"scaling_factor={self.scale}, rope_type={self.rope_type}"
-        )
-
-    def __call__(self, x, offset: int = 0):
-        return mx.fast.rope(
-            x,
-            self.dims,
-            traditional=self.traditional,
-            base=self.base,
-            scale=self.scale,
-            offset=offset,
-            freqs=self._freqs,
-        )
-
-
-def initialize_rope(args: ModelArgs):
-    head_dim = args.head_dim or args.hidden_size // args.num_attention_heads
-
-    rope_scaling = args.rope_scaling
-    rope_type = "default"
-    rope_scale = 1.0
-
-    if rope_scaling is not None:
-        rope_type = (
-            rope_scaling.get("type") or rope_scaling.get("rope_type") or "default"
-        )
-        if rope_type == "linear":
-            rope_scale = 1 / rope_scaling["factor"]
-        elif rope_type == "llama3":
-            rope_scale = 1.0  # The scaling is handled internally for llama3
-
-    return DynamicNTKScalingRoPE(
-        dims=head_dim,
-        max_position_embeddings=args.max_position_embeddings,
-        traditional=args.rope_traditional,
-        base=args.rope_theta,
-        scale=rope_scale,
-        rope_type=rope_type,
-        rope_scaling=rope_scaling,
-    )
-
 
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
@@ -165,7 +55,13 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
 
-        self.rope = initialize_rope(args)
+        self.rope = initialize_rope(
+            self.head_dim,
+            args.rope_theta,
+            args.rope_traditional,
+            args.rope_scaling,
+            args.max_position_embeddings,
+        )
 
     def __call__(
         self,
diff --git a/llms/mlx_lm/models/olmo2.py b/llms/mlx_lm/models/olmo2.py
index a28fdcc1..64d7e116 100644
--- a/llms/mlx_lm/models/olmo2.py
+++ b/llms/mlx_lm/models/olmo2.py
@@ -7,6 +7,7 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .rope_utils import initialize_rope
 
 
 @dataclass
@@ -32,117 +33,6 @@ class ModelArgs(BaseModelArgs):
         if self.num_key_value_heads is None:
             self.num_key_value_heads = self.num_attention_heads
 
-        if self.rope_scaling:
-            if not "factor" in self.rope_scaling:
-                raise ValueError(f"rope_scaling must contain 'factor'")
-            rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
-                "rope_type"
-            )
-            if rope_type is None:
-                raise ValueError(
-                    f"rope_scaling must contain either 'type' or 'rope_type'"
-                )
-            if rope_type not in ["linear", "dynamic", "llama3"]:
-                raise ValueError(
-                    "rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'"
-                )
-
-
-class DynamicNTKScalingRoPE(nn.Module):
-    """Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE."""
-
-    def __init__(
-        self,
-        dims: int,
-        max_position_embeddings: int = 2048,
-        traditional: bool = False,
-        base: float = 10000,
-        scale: float = 1.0,
-        rope_type: str = "default",
-        rope_scaling: dict = None,
-    ):
-        super().__init__()
-        self.dims = dims
-        self.max_position_embeddings = max_position_embeddings
-        self.traditional = traditional
-        self.scale = scale
-        self.rope_type = rope_type
-        self.rope_scaling = rope_scaling
-        self.base = base
-        self.compute_freqs()
-
-    def compute_freqs(self):
-        if self.rope_type != "llama3":
-            self._freqs = None
-            return
-        factor = self.rope_scaling["factor"]
-        low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
-        high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
-        old_context_len = self.rope_scaling.get(
-            "original_max_position_embeddings",
-            8192,
-        )
-
-        low_freq_wavelen = old_context_len / low_freq_factor
-        high_freq_wavelen = old_context_len / high_freq_factor
-
-        freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims)
-        wavelens = 2 * mx.pi * freqs
-
-        freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
-        is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
-        smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
-            high_freq_factor - low_freq_factor
-        )
-        smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
-        self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
-        self.base = None
-
-    def extra_repr(self):
-        return (
-            f"{self.dims}, traditional={self.traditional}, "
-            f"max_position_embeddings={self.max_position_embeddings}, "
-            f"scaling_factor={self.scale}, rope_type={self.rope_type}"
-        )
-
-    def __call__(self, x, offset: int = 0):
-        return mx.fast.rope(
-            x,
-            self.dims,
-            traditional=self.traditional,
-            base=self.base,
-            scale=self.scale,
-            offset=offset,
-            freqs=self._freqs,
-        )
-
-
-def initialize_rope(args: ModelArgs):
-    head_dim = args.head_dim or args.hidden_size // args.num_attention_heads
-
-    rope_scaling = args.rope_scaling
-    rope_type = "default"
-    rope_scale = 1.0
-
-    if rope_scaling is not None:
-        rope_type = (
-            rope_scaling.get("type") or rope_scaling.get("rope_type") or "default"
-        )
-        if rope_type == "linear":
-            rope_scale = 1 / rope_scaling["factor"]
-        elif rope_type == "llama3":
-            rope_scale = 1.0  # The scaling is handled internally for llama3
-
-    return DynamicNTKScalingRoPE(
-        dims=head_dim,
-        max_position_embeddings=args.max_position_embeddings,
-        traditional=args.rope_traditional,
-        base=args.rope_theta,
-        scale=rope_scale,
-        rope_type=rope_type,
-        rope_scaling=rope_scaling,
-    )
-
 
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
@@ -165,7 +55,14 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
 
-        self.rope = initialize_rope(args)
+        self.rope = initialize_rope(
+            self.head_dim,
+            args.rope_theta,
+            args.rope_traditional,
+            args.rope_scaling,
+            args.max_position_embeddings,
+        )
+
         self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps)
         self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps)
 
diff --git a/llms/mlx_lm/models/rope_utils.py b/llms/mlx_lm/models/rope_utils.py
new file mode 100644
index 00000000..d30b432d
--- /dev/null
+++ b/llms/mlx_lm/models/rope_utils.py
@@ -0,0 +1,91 @@
+# Copyright © 2023-2024 Apple Inc.
+
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+class Llama3RoPE(nn.Module):
+
+    def __init__(
+        self,
+        dims: int,
+        max_position_embeddings: int = 2048,
+        traditional: bool = False,
+        base: float = 10000,
+        scaling_config: dict = None,
+    ):
+        super().__init__()
+        self.dims = dims
+        self.max_position_embeddings = max_position_embeddings
+        self.traditional = traditional
+
+        factor = scaling_config["factor"]
+        low_freq_factor = scaling_config.get("low_freq_factor", 1.0)
+        high_freq_factor = scaling_config.get("high_freq_factor", 4.0)
+        old_context_len = scaling_config.get(
+            "original_max_position_embeddings",
+            8192,
+        )
+
+        low_freq_wavelen = old_context_len / low_freq_factor
+        high_freq_wavelen = old_context_len / high_freq_factor
+
+        freqs = base ** (mx.arange(0, dims, 2) / dims)
+        wavelens = 2 * mx.pi * freqs
+
+        freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
+        is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
+        smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
+            high_freq_factor - low_freq_factor
+        )
+        smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
+        self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
+
+    def extra_repr(self):
+        return (
+            f"{self.dims}, traditional={self.traditional}, "
+            f"max_position_embeddings={self.max_position_embeddings}"
+        )
+
+    def __call__(self, x, offset: int = 0):
+        return mx.fast.rope(
+            x,
+            self.dims,
+            traditional=self.traditional,
+            base=None,
+            scale=1.0,
+            offset=offset,
+            freqs=self._freqs,
+        )
+
+
+def initialize_rope(
+    dims,
+    base,
+    traditional,
+    scaling_config: Optional[dict] = None,
+    max_position_embeddings: Optional[int] = None,
+):
+    if scaling_config is not None:
+        rope_type = scaling_config.get("type") or scaling_config.get(
+            "rope_type", "default"
+        )
+    else:
+        rope_type = "default"
+
+    if rope_type in ["default", "linear"]:
+        scale = 1 / scaling_config["factor"] if rope_type == "linear" else 1.0
+        return nn.RoPE(dims, traditional=traditional, base=base, scale=scale)
+
+    elif rope_type == "llama3":
+        return Llama3RoPE(
+            dims=dims,
+            max_position_embeddings=max_position_embeddings,
+            traditional=traditional,
+            base=base,
+            scaling_config=scaling_config,
+        )
+    else:
+        raise ValueError(f"Unsupported RoPE type {rope_type}")
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 8351ed1b..6821f434 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -144,6 +144,8 @@ def linear_to_lora_layers(
                 "mixer.out_proj",
             ]
         )
+    elif model.model_type == "exaone":
+        keys = set(["attn.attention.q_proj", "attn.attention.v_proj"])
     else:
         raise ValueError(f"Lora does not support {model.model_type}")
 
diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py
index edb594d7..374a5113 100644
--- a/llms/tests/test_models.py
+++ b/llms/tests/test_models.py
@@ -2,7 +2,9 @@
 import unittest
 
 import mlx.core as mx
+import mlx.nn as nn
 from mlx.utils import tree_map
+from mlx_lm.models import rope_utils
 
 from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
 
@@ -126,6 +128,26 @@ class TestModels(unittest.TestCase):
         self.assertEqual(cache.offset, 22)
         self.assertTrue(mx.allclose(x, k[..., -2:, :]))
 
+    def test_rope(self):
+        rope = rope_utils.initialize_rope(32, base=100, traditional=False)
+        self.assertTrue(isinstance(rope, nn.RoPE))
+
+        rope = rope_utils.initialize_rope(
+            32,
+            base=100,
+            traditional=False,
+            scaling_config={"rope_type": "linear", "factor": 10.0},
+        )
+        self.assertTrue(isinstance(rope, nn.RoPE))
+
+        rope = rope_utils.initialize_rope(
+            32,
+            base=100,
+            traditional=False,
+            scaling_config={"rope_type": "llama3", "factor": 2.0},
+        )
+        self.assertTrue(isinstance(rope, rope_utils.Llama3RoPE))
+
     def model_test_runner(self, model, model_type, vocab_size, num_layers):
 
         self.assertEqual(len(model.layers), num_layers)
@@ -812,6 +834,23 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )
 
+    def test_exaone(self):
+        from mlx_lm.models import exaone
+
+        args = exaone.ModelArgs(
+            model_type="exaone",
+            hidden_size=128,
+            num_layers=4,
+            intermediate_size=256,
+            num_attention_heads=8,
+            num_key_value_heads=2,
+            vocab_size=1000,
+            layer_norm_epsilon=1e-4,
+            rope_theta=10000,
+        )
+        model = exaone.Model(args)
+        self.model_test_runner(model, args.model_type, args.vocab_size, args.num_layers)
+
 
 if __name__ == "__main__":
     unittest.main()