From 2fce02acd87193487eafeeec639eb8903cc96083 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Thu, 13 Mar 2025 16:14:25 +0100 Subject: [PATCH] Add support for Gemma3 (#1336) * add support for gemma3 * fix model loading * revert rmsnorm * revert is sliding pattern * revert * add tests * formatting * Update llms/mlx_lm/models/gemma3_text.py Co-authored-by: Awni Hannun * Update llms/mlx_lm/models/gemma3_text.py Co-authored-by: Awni Hannun * Update llms/mlx_lm/models/gemma3_text.py Co-authored-by: Awni Hannun * Update llms/mlx_lm/models/gemma3_text.py Co-authored-by: Awni Hannun * Update llms/mlx_lm/models/gemma3_text.py Co-authored-by: Awni Hannun * Update llms/mlx_lm/models/gemma3_text.py Co-authored-by: Awni Hannun * Update llms/mlx_lm/models/gemma3_text.py Co-authored-by: Awni Hannun * fix sliding window mask --------- Co-authored-by: Awni Hannun Co-authored-by: Awni Hannun --- llms/mlx_lm/models/gemma3_text.py | 238 ++++++++++++++++++++++++++++++ llms/tests/test_models.py | 20 +++ 2 files changed, 258 insertions(+) create mode 100644 llms/mlx_lm/models/gemma3_text.py diff --git a/llms/mlx_lm/models/gemma3_text.py b/llms/mlx_lm/models/gemma3_text.py new file mode 100644 index 00000000..be71f461 --- /dev/null +++ b/llms/mlx_lm/models/gemma3_text.py @@ -0,0 +1,238 @@ +# Copyright © 2025 Apple Inc. + +from dataclasses import dataclass +from typing import Any, Optional + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask +from .cache import KVCache, RotatingKVCache + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int = 1152 + num_hidden_layers: int = 26 + intermediate_size: int = 6912 + num_attention_heads: int = 4 + head_dim: int = 256 + rms_norm_eps: float = 1.0e-6 + vocab_size: int = 262144 + num_key_value_heads: int = 1 + rope_global_base_freq: float = 1_000_000.0 + rope_local_base_freq: float = 10_000.0 + rope_traditional: bool = False + query_pre_attn_scalar: float = 256 + sliding_window: int = 512 + sliding_window_pattern: int = 6 + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + self.repeats = n_heads // n_kv_heads + self.head_dim = head_dim = args.head_dim + self.layer_idx = layer_idx + + self.scale = args.query_pre_attn_scalar**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + + self.q_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps) + self.k_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps) + self.is_sliding = (layer_idx + 1) % args.sliding_window_pattern != 0 + + self.rope = nn.RoPE( + head_dim, + traditional=args.rope_traditional, + base=( + args.rope_local_base_freq + if self.is_sliding + else args.rope_global_base_freq + ), + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, _ = x.shape + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + queries = self.q_norm(queries) + keys = self.k_norm(keys) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + # Sliding window + if mask is not None and mask.shape[-1] != keys.shape[-2]: + mask = mask[..., -keys.shape[-2] :] + + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class RMSNorm(nn.Module): + def __init__(self, dims: int, eps: float = 1e-5): + super().__init__() + self.weight = mx.ones((dims,)) + self.eps = eps + + def __call__(self, x): + return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__() + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args, layer_idx) + self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.pre_feedforward_layernorm = RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.post_feedforward_layernorm = RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + self.post_attention_layernorm(r) + r = self.mlp(self.pre_feedforward_layernorm(h)) + out = h + self.post_feedforward_layernorm(r) + return out + + +class Gemma3Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args, layer_idx=layer_idx) + for layer_idx in range(args.num_hidden_layers) + ] + self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ): + + h = self.embed_tokens(inputs) + h *= mx.array(self.args.hidden_size**0.5, mx.bfloat16).astype(h.dtype) + + if cache is None: + cache = [None] * len(self.layers) + + if mask is None: + j = self.args.sliding_window_pattern + full_mask = create_attention_mask(h, cache[j - 1 : j]) + sliding_window_mask = create_attention_mask(h, cache) + + for i, (layer, c) in enumerate(zip(self.layers, cache)): + is_sliding = ( + i % self.args.sliding_window_pattern + == self.args.sliding_window_pattern - 1 + ) + + if mask is None and is_sliding: + mask = sliding_window_mask + elif mask is None: + mask = full_mask + + h = layer(h, mask, c) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = Gemma3Model(args) + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + mask: Optional[mx.array] = None, + ): + out = self.model(inputs, mask, cache) + out = self.lm_head(out) + return out + + def sanitize(self, weights): + if "lm_head.weight" not in weights: + weights["lm_head.weight"] = weights["model.embed_tokens.weight"] + return { + k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k + } + + @property + def layers(self): + return self.model.layers + + def make_cache(self): + caches = [] + for i in range(self.args.num_hidden_layers): + if ( + i % self.args.sliding_window_pattern + == self.args.sliding_window_pattern - 1 + ): + caches.append(KVCache()) + else: + caches.append( + RotatingKVCache(max_size=self.args.sliding_window, keep=0) + ) + return caches diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 0c0fc601..b4e7aab8 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -755,6 +755,26 @@ class TestModels(unittest.TestCase): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_gemma3_text(self): + from mlx_lm.models import gemma3_text + + args = gemma3_text.ModelArgs( + model_type="gemma3_text", + hidden_size=128, + num_hidden_layers=12, + intermediate_size=256, + num_attention_heads=4, + head_dim=32, + rms_norm_eps=1e-4, + num_key_value_heads=1, + sliding_window=1024, + sliding_window_pattern=6, + ) + model = gemma3_text.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + def test_gpt_bigcode(self): from mlx_lm.models import gpt_bigcode