# Copyright © 2023-2024 Apple Inc.

from dataclasses import dataclass
from typing import Any, Dict, Optional, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope


@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration consumed by every module in this file.

    Fallbacks for optional fields:
      * ``num_key_value_heads`` -> ``num_attention_heads`` (``__post_init__``).
      * ``head_dim`` -> ``hidden_size // num_attention_heads`` when unset
        (resolved inside ``Attention``).
    """

    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    head_dim: Optional[int] = None
    max_position_embeddings: Optional[int] = None
    num_key_value_heads: Optional[int] = None
    attention_bias: bool = False
    mlp_bias: bool = False
    rope_theta: float = 10000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    tie_word_embeddings: bool = True

    def __post_init__(self):
        # Without an explicit KV-head count, use one KV head per query head.
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads


class Attention(nn.Module):
    """Self-attention with RMS-normalised queries/keys and rotary positions.

    Queries and keys are RMS-normalised over their full projected width
    (all heads together) *before* being split into heads. Rotary position
    embeddings are then applied. ``n_kv_heads`` may be smaller than
    ``n_heads``; the reshape below uses it for keys/values only.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        # An explicit head_dim wins; otherwise split hidden_size evenly.
        self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
        self.scale = head_dim**-0.5

        # Tolerate duck-typed args objects that lack the field (no bias then).
        attention_bias = getattr(args, "attention_bias", False)

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)

        self.rope = initialize_rope(
            self.head_dim,
            args.rope_theta,
            args.rope_traditional,
            args.rope_scaling,
            args.max_position_embeddings,
        )

        # Norms span all heads at once (width = heads * head_dim).
        self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps)
        self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Attend over ``x`` and return the output projection.

        Args:
            x: Rank-3 input, unpacked as ``(B, L, hidden)``.
            mask: Forwarded unchanged to ``scaled_dot_product_attention``.
            cache: Optional KV cache. When given, its ``offset`` shifts the
                rotary positions. Its ``update_and_fetch`` receives the new
                keys/values and returns the ones to attend over (presumably
                the accumulated history; confirm against the cache class).

        Returns:
            Array of shape ``(B, L, hidden_size)``.
        """
        B, L, _ = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Normalise before splitting into heads (see class docstring).
        queries = self.q_norm(queries)
        keys = self.k_norm(keys)

        # (B, L, heads * head_dim) -> (B, heads, L, head_dim)
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            # Continue rotary positions from where the cache left off.
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )
        # (B, heads, L, head_dim) -> (B, L, heads * head_dim)
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class MLP(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))`` (SwiGLU)."""

    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        hidden_dim = args.intermediate_size

        # Tolerate duck-typed args objects that lack the field (no bias then).
        mlp_bias = getattr(args, "mlp_bias", False)

        self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
    """One decoder layer with *post*-sublayer normalisation.

    Each sublayer's output is RMS-normalised and then added to the residual
    stream. There is no normalisation on the sublayer inputs:

        h   = x + norm_attn(attn(x))
        out = h + norm_ff(mlp(h))
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )
        self.post_feedforward_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = self.post_attention_layernorm(self.self_attn(x, mask, cache))
        h = x + r
        r = self.post_feedforward_layernorm(self.mlp(h))
        out = h + r
        return out
class LlamaModel(nn.Module):
    """Decoder stack: token embedding -> TransformerBlocks -> final RMSNorm."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(self.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        mask=None,
    ):
        """Embed ``inputs``, run every layer, and return normalised states.

        ``cache``, when provided, is iterated in lockstep with ``self.layers``
        (one entry per layer). Without it, each layer receives ``None``.
        If no ``mask`` is given, one is built from the embeddings and the
        caller-supplied ``cache``.
        """
        hidden = self.embed_tokens(inputs)

        if mask is None:
            # Built from the cache exactly as the caller passed it (may be None).
            mask = create_attention_mask(hidden, cache)

        per_layer_cache = [None] * len(self.layers) if cache is None else cache

        for block, block_cache in zip(self.layers, per_layer_cache):
            hidden = block(hidden, mask, cache=block_cache)

        return self.norm(hidden)


class Model(nn.Module):
    """Language model: ``LlamaModel`` followed by a vocabulary projection.

    With ``tie_word_embeddings`` the embedding table is reused as the output
    projection (via ``Embedding.as_linear``). Otherwise a separate bias-free
    ``lm_head`` is created.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = LlamaModel(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        mask=None,
    ):
        """Return output-projection results for ``inputs``."""
        hidden = self.model(inputs, cache, mask)
        if not self.args.tie_word_embeddings:
            return self.lm_head(hidden)
        return self.model.embed_tokens.as_linear(hidden)

    def sanitize(self, weights):
        """Drop precomputed rotary-frequency buffers that have no parameter here."""
        unused_marker = "self_attn.rotary_emb.inv_freq"
        kept = {}
        for name, value in weights.items():
            if unused_marker in name:
                continue
            kept[name] = value
        return kept

    @property
    def layers(self):
        """The decoder's list of ``TransformerBlock`` layers."""
        return self.model.layers