# Copyright © 2024 Apple Inc.

from dataclasses import dataclass
from typing import Any, Dict, Optional, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope


@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration values consumed by the modules defined in this file."""

    model_type: str
    # Width of the token embeddings and of every block's input/output.
    hidden_size: int
    # Number of TransformerBlock instances stacked in ExaoneModel.
    num_layers: int
    # Inner width of the gated MLP.
    intermediate_size: int
    # Number of query heads.
    num_attention_heads: int
    vocab_size: int
    # Forwarded to initialize_rope.
    rope_theta: float
    # `eps` passed to every RMSNorm.
    layer_norm_epsilon: float
    # Number of key/value heads.
    num_key_value_heads: int
    # Per-head width; falls back to hidden_size // num_attention_heads.
    head_dim: Optional[int] = None
    # Forwarded to initialize_rope.
    max_position_embeddings: Optional[int] = None
    # Forwarded to initialize_rope.
    rope_traditional: bool = False
    # Forwarded to initialize_rope.
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    # When True the output projection reuses the embedding matrix.
    tie_word_embeddings: bool = True
    # Whether the q/k/v/out projections carry a bias term.
    attention_bias: bool = False
    # Whether the MLP linear layers carry a bias term.
    mlp_bias: bool = False


class AttentionModule(nn.Module):
    """Multi-head attention with rotary embeddings and separate KV head count."""

    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        # An explicit head_dim overrides the hidden_size / n_heads default,
        # so n_heads * head_dim is not guaranteed to equal hidden_size.
        self.head_dim = head_dim = args.head_dim or (dim // n_heads)

        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
        self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)

        self.rope = initialize_rope(
            self.head_dim,
            args.rope_theta,
            args.rope_traditional,
            args.rope_scaling,
            args.max_position_embeddings,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Apply attention to ``x``.

        Args:
            x: Rank-3 input; unpacked as ``(B, L, D)``.
            mask: Optional mask forwarded to ``scaled_dot_product_attention``.
            cache: Optional cache object exposing ``offset`` and
                ``update_and_fetch(keys, values)``.

        Returns:
            The result of ``out_proj`` applied to the merged head outputs.
        """
        B, L, _ = x.shape

        # Split the projections into heads: (B, L, H * d) -> (B, H, L, d).
        q = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        k = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        v = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            # Rotate at the cache's current offset, then let the cache merge
            # the new keys/values with whatever it already holds.
            q = self.rope(q, offset=cache.offset)
            k = self.rope(k, offset=cache.offset)
            k, v = cache.update_and_fetch(k, v)
        else:
            q = self.rope(q)
            k = self.rope(k)

        out = scaled_dot_product_attention(
            q, k, v, cache=cache, scale=self.scale, mask=mask
        )

        # Merge heads back together. The last axis is n_heads * head_dim (the
        # input width of out_proj), which differs from the hidden size when
        # head_dim is configured explicitly, so infer it instead of reusing D.
        out = out.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.out_proj(out)


class Attention(nn.Module):
    """Thin container that holds the attention under the ``attention`` attribute."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.attention = AttentionModule(args)


class MLP(nn.Module):
    """SiLU-gated feed-forward layer: ``c_proj(silu(c_fc_0(x)) * c_fc_1(x))``."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        dim = args.hidden_size
        hidden_dim = args.intermediate_size
        self.c_fc_0 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias)
        self.c_fc_1 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias)
        self.c_proj = nn.Linear(hidden_dim, dim, bias=args.mlp_bias)

    def __call__(self, x: mx.array) -> mx.array:
        """Return the gated projection of ``x``."""
        gate = nn.silu(self.c_fc_0(x))
        return self.c_proj(gate * self.c_fc_1(x))


class TransformerBlock(nn.Module):
    """One layer: normalized attention and normalized MLP, each with a residual."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
        self.attn = Attention(args)
        self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
        self.mlp = MLP(args)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Run the block on ``x``.

        Args:
            x: Input hidden states.
            mask: Optional attention mask passed through to the attention.
            cache: Optional per-layer cache passed through to the attention.

        Returns:
            ``h + mlp(ln_2(h))`` where ``h = x + attention(ln_1(x))``.
        """
        h = x + self.attn.attention(self.ln_1(x), mask, cache)
        out = h + self.mlp(self.ln_2(h))
        return out


class ExaoneModel(nn.Module):
    """Token embedding, the stack of transformer blocks, and a final RMSNorm."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
        self.h = [TransformerBlock(args) for _ in range(args.num_layers)]
        self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

    def __call__(
        self,
        inputs: mx.array,
        mask: Optional[mx.array] = None,
        cache=None,
    ):
        """Embed ``inputs`` and run them through every layer.

        Args:
            inputs: Token ids looked up in ``wte``.
            mask: Optional attention mask; when omitted one is obtained from
                ``create_attention_mask(h, cache)``.
            cache: Optional sequence with one cache entry per layer.

        Returns:
            The output of the final norm ``ln_f``.
        """
        h = self.wte(inputs)

        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            # One placeholder per layer so the zip below still visits them all.
            cache = [None] * len(self.h)

        for layer, c in zip(self.h, cache):
            h = layer(h, mask, cache=c)

        return self.ln_f(h)


class Model(nn.Module):
    """Top-level model: the transformer plus the projection to vocabulary size."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.transformer = ExaoneModel(args)
        # A separate head exists only when the embeddings are not tied.
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        mask: Optional[mx.array] = None,
        cache=None,
    ):
        """Run the transformer and project its output to the vocabulary.

        Args:
            inputs: Token ids.
            mask: Optional attention mask forwarded to the transformer.
            cache: Optional per-layer cache sequence forwarded to the transformer.

        Returns:
            The projected output, computed with ``wte.as_linear`` when the
            embeddings are tied and with ``lm_head`` otherwise.
        """
        out = self.transformer(inputs, mask, cache)
        if self.args.tie_word_embeddings:
            out = self.transformer.wte.as_linear(out)
        else:
            out = self.lm_head(out)
        return out

    @property
    def layers(self):
        """The list of transformer blocks held by the wrapped ExaoneModel."""
        return self.transformer.h