mlx-examples/llms/mlx_lm/models/gemma3_text.py

import inspect
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .cache import KVCache, RotatingKVCache
from .base import BaseModelArgs, create_attention_mask
@dataclass
class ModelArgs(BaseModelArgs):
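    # Attention layout: every sliding_window_pattern-th layer attends globally
    # and uses rope_global_base_freq; the remaining layers use sliding-window
    # attention over `sliding_window` tokens with rope_local_base_freq.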
model_type: str
    hidden_size: int = 1152
num_hidden_layers: int = 26
intermediate_size: int = 6912
num_attention_heads: int = 4
head_dim: int = 256
rms_norm_eps: float = 1.0e-6
vocab_size: int = 262144
num_key_value_heads: int = 1
rope_global_base_freq: float = 1_000_000.0
rope_local_base_freq: float = 10_000.0
rope_traditional: bool = False
query_pre_attn_scalar: float = 256
sliding_window: int = 512
rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
mm_tokens_per_image: int = 256
sliding_window_pattern: int = 6
@classmethod
def from_dict(cls, params):
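        # Keep only the keys that correspond to dataclass fields; extra
        # entries in a checkpoint's config.json are ignored.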
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
class Attention(nn.Module):
def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.repeats = n_heads // n_kv_heads
self.head_dim = head_dim = args.head_dim
self.layer_idx = layer_idx
self.scale = args.query_pre_attn_scalar**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
self.q_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
self.k_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
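        # Every sliding_window_pattern-th layer (1-based) is a global-attention
        # layer and uses the global RoPE base; all other layers use
        # sliding-window attention with the local RoPE base.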
        self.is_sliding = (layer_idx + 1) % args.sliding_window_pattern != 0
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=(
args.rope_local_base_freq
if self.is_sliding
else args.rope_global_base_freq
),
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, _ = x.shape
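        # Project and reshape to (B, n_heads, L, head_dim) for queries and
        # (B, n_kv_heads, L, head_dim) for keys/values (grouped-query attention).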
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
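        # Gemma 3 applies per-head RMSNorm to queries and keys (QK-norm) before RoPE.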
queries = self.q_norm(queries)
keys = self.k_norm(keys)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Sliding window
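        # The shared mask is sized for the global-attention cache; a rotating
        # cache may hold fewer keys, so trim the mask to the current key length.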
if self.is_sliding and mask is not None:
key_len = keys.shape[-2]
if mask.shape[-1] != key_len:
mask = mask[..., :key_len]
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
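        # Gated feed-forward: down_proj(activation(gate_proj(x)) * up_proj(x)).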
        # Note: this should not be the GELU approximation; it should match jax.nn.gelu.
return self.down_proj(nn.gelu_fast_approx(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args, layer_idx)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.pre_feedforward_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.post_feedforward_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
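        # Each sub-layer is normalized on the way in (input_layernorm /
        # pre_feedforward_layernorm) and its output is normalized again
        # (post_attention_layernorm / post_feedforward_layernorm) before being
        # added to the residual stream.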
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + self.post_attention_layernorm(r)
r = self.mlp(self.pre_feedforward_layernorm(h))
out = h + self.post_feedforward_layernorm(r)
return out
class Gemma3Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args, layer_idx=layer_idx)
for layer_idx in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
        mask: Optional[mx.array] = None,
cache=None,
):
h = self.embed_tokens(inputs)
        # Scale embeddings by sqrt(hidden_size); note: there is a persistent
        # precision issue in this scaling.
        h *= self.args.hidden_size**0.5
if cache is None:
cache = [None] * len(self.layers)
if mask is None:
            # Derive the causal mask from the first global-attention layer's
            # cache (index sliding_window_pattern - 1, a standard KVCache).
j = self.args.sliding_window_pattern
mask = create_attention_mask(h, cache[j - 1 : j])
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = Gemma3Model(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
mask: Optional[mx.array] = None,
):
out = self.model(inputs, mask, cache)
out = self.lm_head(out)
return out
def sanitize(self, weights):
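        # Gemma 3 ties the output head to the input embeddings: if the
        # checkpoint has no separate lm_head weight, reuse embed_tokens.
        # Precomputed rotary inv_freq buffers are dropped since RoPE is
        # computed on the fly.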
if "lm_head.weight" not in weights:
weights["lm_head.weight"] = weights[
"model.embed_tokens.weight"
]
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.head_dim
@property
def n_kv_heads(self):
return self.args.num_key_value_heads
def make_cache(self):
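        # Global-attention layers (every sliding_window_pattern-th layer) get a
        # standard, unbounded KVCache; sliding-window layers get a RotatingKVCache
        # capped at the window size (keep=0: no prefix tokens are pinned).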
caches = []
for i in range(self.args.num_hidden_layers):
if (
i % self.args.sliding_window_pattern
== self.args.sliding_window_pattern - 1
):
                caches.append(KVCache())
else:
caches.append(
RotatingKVCache(max_size=self.args.sliding_window, keep=0)
)
return caches
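
# A minimal usage sketch: load a converted Gemma 3 text checkpoint through the
# mlx_lm top-level API and generate a few tokens. The repo id below is only an
# illustrative community conversion.
if __name__ == "__main__":
    from mlx_lm import load, generate

    model, tokenizer = load("mlx-community/gemma-3-1b-it-4bit")
    print(generate(model, tokenizer, prompt="Hello from Gemma 3!", max_tokens=32))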