mlx-examples/llms/mlx_lm/models/gemma3_text.py

import inspect
from dataclasses import dataclass
from typing import Any, Optional

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask
from .cache import KVCache, RotatingKVCache


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int = 1152
    num_hidden_layers: int = 26
    intermediate_size: int = 6912
    num_attention_heads: int = 4
    head_dim: int = 256
    rms_norm_eps: float = 1.0e-6
    vocab_size: int = 262144
    num_key_value_heads: int = 1
    rope_global_base_freq: float = 1_000_000.0
    rope_local_base_freq: float = 10_000.0
    rope_traditional: bool = False
    query_pre_attn_scalar: float = 256
    sliding_window: int = 512
    mm_tokens_per_image: int = 256
    sliding_window_pattern: int = 6

    @classmethod
    def from_dict(cls, params):
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )
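
# Note on the layer pattern: with sliding_window_pattern = 6, every sixth layer uses
# global attention with the 1M RoPE base, while the remaining layers use local
# sliding-window attention (a window of `sliding_window` tokens) with the 10k RoPE
# base. Attention logits are scaled by query_pre_attn_scalar ** -0.5 rather than
# head_dim ** -0.5 (the two coincide here, since both default to 256).
# mm_tokens_per_image is carried for the multimodal wrapper and is not referenced in
# this text-only file.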


class Attention(nn.Module):
    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        self.repeats = n_heads // n_kv_heads
        self.head_dim = head_dim = args.head_dim
        self.layer_idx = layer_idx

        self.scale = args.query_pre_attn_scalar**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        self.q_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
        self.k_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
        self.is_sliding = (layer_idx + 1) % args.sliding_window_pattern != 0

        self.rope = nn.RoPE(
            head_dim,
            traditional=args.rope_traditional,
            base=(
                args.rope_local_base_freq
                if self.is_sliding
                else args.rope_global_base_freq
            ),
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, _ = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        queries = self.q_norm(queries)
        keys = self.k_norm(keys)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Sliding window
        if self.is_sliding and mask is not None:
            key_len = keys.shape[-2]
            if mask.shape[-1] != key_len:
                mask = mask[..., -key_len:]

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)
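
# Shape walkthrough for the grouped-query attention above (default config): queries are
# (B, 4, L, 256) and keys/values are (B, 1, L_kv, 256); mx.fast.scaled_dot_product_attention
# handles the mismatched head counts by sharing the single KV head across all query heads.
# On sliding-window layers the RotatingKVCache caps the stored keys at `sliding_window`
# tokens, so the precomputed mask is trimmed to its last key_len columns to stay aligned
# with the cached keys.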


class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def __call__(self, x):
        return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
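
# Gemma checkpoints store RMSNorm scales offset by one (weight = scale - 1), so the
# norm is applied with 1.0 + self.weight to recover the effective scale.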


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        # Gemma's reference activation is the tanh-approximate GELU (the default of
        # jax.nn.gelu, "gelu_pytorch_tanh" in Hugging Face configs), which
        # nn.gelu_approx matches; the exact GELU should not be used here.
        return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
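
# The feed-forward block is a gated GELU ("GeGLU") MLP: gate_proj and up_proj both map
# hidden_size -> intermediate_size, their elementwise product is projected back down by
# down_proj. The tanh GELU approximation is
#     gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
# which nn.gelu_approx closely matches (MLX expresses it via a sigmoid reformulation).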


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args, layer_idx)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.pre_feedforward_layernorm = RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )
        self.post_feedforward_layernorm = RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + self.post_attention_layernorm(r)
        r = self.mlp(self.pre_feedforward_layernorm(h))
        out = h + self.post_feedforward_layernorm(r)
        return out
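
# As in Gemma 2, each block uses "sandwich" normalization: both the attention and MLP
# sub-layers are wrapped in a pre-norm and a post-norm before the residual add, i.e.
# h = x + norm(attn(norm(x))) and out = h + norm(mlp(norm(h))).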


class Gemma3Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args, layer_idx=layer_idx)
            for layer_idx in range(args.num_hidden_layers)
        ]
        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        mask: Optional[mx.array] = None,
        cache=None,
    ):
        h = self.embed_tokens(inputs)
        # Scale embeddings by sqrt(hidden_size); compute the scale in bfloat16 as the
        # reference implementation does, to avoid a persistent precision mismatch.
        h *= mx.array(self.args.hidden_size**0.5, mx.bfloat16).astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        if mask is None:
            j = self.args.sliding_window_pattern
            mask = create_attention_mask(h, cache[j - 1 : j])

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)
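
# Note: the shared attention mask is built from cache[j - 1 : j], the cache of a
# global-attention layer. That entry is a plain KVCache, so create_attention_mask sees
# the full history length; sliding-window layers then trim the mask to their own
# shorter key length inside Attention.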


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = Gemma3Model(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        mask: Optional[mx.array] = None,
    ):
        out = self.model(inputs, mask, cache)
        out = self.lm_head(out)
        return out

    def sanitize(self, weights):
        # Gemma ties the output projection to the input embeddings, so checkpoints may
        # omit lm_head.weight; fall back to the embedding matrix in that case.
        if "lm_head.weight" not in weights:
            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
        return {
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }

    @property
    def layers(self):
        return self.model.layers

    def make_cache(self):
        caches = []
        for i in range(self.args.num_hidden_layers):
            if (
                i % self.args.sliding_window_pattern
                == self.args.sliding_window_pattern - 1
            ):
                caches.append(KVCache())
            else:
                caches.append(
                    RotatingKVCache(max_size=self.args.sliding_window, keep=0)
                )
        return caches
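
# make_cache() gives each global-attention layer (every sliding_window_pattern-th layer)
# a full KVCache and every other layer a RotatingKVCache capped at `sliding_window`
# tokens, so long-context memory grows only with the global layers.
#
# Usage sketch (this module is normally loaded through mlx_lm's high-level API; the
# checkpoint id below is illustrative, not part of this file):
#
#   from mlx_lm import load, generate
#
#   model, tokenizer = load("mlx-community/gemma-3-1b-it-4bit")
#   print(generate(model, tokenizer, prompt="Why is the sky blue?", max_tokens=64))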