# mlx-examples/llms/mlx_lm/models/gemma3.py
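"""Gemma 3 text model for mlx_lm: interleaved sliding-window / global attention,
per-head QK RMSNorm, dual RoPE frequency bases, and a gated (GeGLU) MLP."""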

import inspect
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import mlx.core as mx
import mlx.nn as nn

from .cache import KVCache, RotatingKVCache
from .base import BaseModelArgs, create_attention_mask


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int = 8
    head_dim: int = 256
    rms_norm_eps: float = 1.0e-6
    vocab_size: int = 262208
    num_key_value_heads: int = 4
    rope_global_base_freq: float = 1_000_000.0
    rope_local_base_freq: float = 10_000.0
    rope_traditional: bool = False
    # The attention scale is computed as query_pre_attn_scalar ** -0.5, so this
    # stores the raw scalar from the checkpoint config (e.g. 256), not its
    # inverse square root.
    query_pre_attn_scalar: float = 256
    sliding_window: int = 1024
    rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
    mm_tokens_per_image: int = 256
    sliding_window_pattern: int = 6

    @classmethod
    def from_dict(cls, params):
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )
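

# Gemma 3 interleaves attention types: every `sliding_window_pattern`-th layer is a
# global-attention layer using `rope_global_base_freq`; the layers in between attend
# only within `sliding_window` tokens and use `rope_local_base_freq` for RoPE.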
class Attention(nn.Module):
    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()
        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        self.repeats = n_heads // n_kv_heads
        self.head_dim = head_dim = args.head_dim
        self.layer_idx = layer_idx

        self.scale = args.query_pre_attn_scalar**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        self.q_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
        self.k_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps)

        self.is_sliding = (layer_idx + 1) % args.sliding_window_pattern != 0

        self.rope = nn.RoPE(
            head_dim,
            traditional=args.rope_traditional,
            base=(
                args.rope_local_base_freq
                if self.is_sliding
                else args.rope_global_base_freq
            ),
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, _ = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        queries = self.q_norm(queries)
        keys = self.k_norm(keys)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Sliding window: trim the mask to this layer's (possibly rotated) key length
        if self.is_sliding and mask is not None:
            key_len = keys.shape[-2]
            if mask.shape[-1] != key_len:
                mask = mask[..., :key_len]

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)
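

# Gated feed-forward (GeGLU): down_proj(gelu(gate_proj(x)) * up_proj(x)).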
class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        # Gemma's reference implementation uses the tanh-approximate GELU
        # ("gelu_pytorch_tanh" / jax.nn.gelu), hence nn.gelu_approx rather than
        # the exact nn.gelu.
        return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
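

# Each block uses a pre/post ("sandwich") norm pair around both sub-layers:
# input_layernorm -> attention -> post_attention_layernorm, then
# pre_feedforward_layernorm -> MLP -> post_feedforward_layernorm, each with a residual.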
class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args, layer_idx)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )
        self.pre_feedforward_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )
        self.post_feedforward_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + self.post_attention_layernorm(r)
        r = self.mlp(self.pre_feedforward_layernorm(h))
        out = h + self.post_feedforward_layernorm(r)
        return out
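

# Decoder stack: scaled token embeddings, `num_hidden_layers` transformer blocks,
# and a final RMSNorm.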
class Gemma3Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args, layer_idx=layer_idx)
            for layer_idx in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        inputs_embeds: Optional[mx.array] = None,
        mask: Optional[mx.array] = None,
        cache=None,
    ):
        if inputs_embeds is None:
            h = self.embed_tokens(inputs)
        else:
            h = inputs_embeds

        # Scale by sqrt(hidden_size); computing the scale in bfloat16 first matches
        # the reference implementation's rounding (persistent precision issue in
        # scaling).
        h *= mx.array(self.args.hidden_size**0.5, mx.bfloat16).astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        if mask is None:
            # Build the causal mask from a global-attention layer's cache (every
            # `sliding_window_pattern`-th layer) so it spans the full history;
            # sliding layers trim it to their own key length.
            j = self.args.sliding_window_pattern
            mask = create_attention_mask(h, cache[j - 1 : j])

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)
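

# Top-level wrapper adding the language-model head; `sanitize` ties the head to the
# token embeddings when the checkpoint does not store a separate lm_head weight.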
class Model(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.model_type = config.model_type
        self.model = Gemma3Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        inputs_embeds=None,
        mask: Optional[mx.array] = None,
    ):
        out = self.model(inputs, inputs_embeds, mask, cache)
        out = self.lm_head(out)
        return out

    def sanitize(self, weights):
        # Tie the output projection to the token embeddings when the checkpoint
        # omits an explicit lm_head weight.
        if "language_model.lm_head.weight" not in weights:
            weights["language_model.lm_head.weight"] = weights[
                "language_model.model.embed_tokens.weight"
            ]
        return {
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        return self.config.head_dim

    @property
    def n_kv_heads(self):
        return self.config.num_key_value_heads
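
    # Per-layer caches: a standard KVCache for the global-attention layers (every
    # `sliding_window_pattern`-th layer) and a RotatingKVCache bounded by the
    # sliding window for the local layers.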
    def make_cache(self):
        caches = []
        for i in range(self.config.num_hidden_layers):
            if (
                i % self.config.sliding_window_pattern
                == self.config.sliding_window_pattern - 1
            ):
                caches.append(KVCache())
            else:
                caches.append(
                    RotatingKVCache(max_size=self.config.sliding_window, keep=0)
                )
        return caches
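

# Minimal usage sketch (assumes the top-level mlx_lm `load`/`generate` helpers and an
# already converted Gemma 3 checkpoint; the path below is a placeholder, not a real
# repo id):
#
#   from mlx_lm import load, generate
#
#   model, tokenizer = load("path/to/gemma-3-mlx-checkpoint")
#   print(generate(model, tokenizer, prompt="Explain RoPE briefly.", max_tokens=64))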