# Copyright © 2023-2024 Apple Inc.

from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
import numpy as np

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    dim_model_base: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: int
    scale_depth: float
    scale_emb: float
    rope_theta: float = 1000000.0
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[str, float]]] = None
    tie_word_embeddings: bool = False


class MLP(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.gate_proj = nn.Linear(
            args.hidden_size, args.intermediate_size, bias=False
        )
        self.up_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
        self.down_proj = nn.Linear(
            args.intermediate_size, args.hidden_size, bias=False
        )

    def __call__(self, x):
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.hidden_size = args.hidden_size
        self.num_heads = n_heads = args.num_attention_heads
        self.rope_theta = args.rope_theta

        self.head_dim = head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5

        self.num_key_value_heads = args.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )

        rope_scale = (
            1 / args.rope_scaling["factor"]
            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
            else 1
        )

        self.rope = nn.RoPE(
            dims=self.head_dim,
            traditional=args.rope_traditional,
            base=self.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ):
        B, L, _ = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        attn_output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )

        attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(attn_output)


class DecoderLayer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.hidden_size = args.hidden_size
        self.num_hidden_layers = args.num_hidden_layers

        self.self_attn = Attention(args)
        self.mlp = MLP(args)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )

        self.scale_depth = args.scale_depth

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))

        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
        return out


class MiniCPMModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        assert self.vocab_size > 0

        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [DecoderLayer(args) for _ in range(args.num_hidden_layers)]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs) * self.args.scale_emb

        mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        return self.norm(h)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = MiniCPMModel(args)

        if not self.args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)

        if not self.args.tie_word_embeddings:
            out = self.lm_head(
                out / (self.args.hidden_size / self.args.dim_model_base)
            )
        else:
            out = out @ self.model.embed_tokens.weight.T

        return out

    def sanitize(self, weights):
        if "lm_head.weight" not in weights:
            weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
        return weights

    @property
    def layers(self):
        return self.model.layers
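

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the upstream module): a minimal
# smoke test showing how ModelArgs and Model fit together. The configuration
# values below are small made-up placeholders, not a published MiniCPM
# configuration, and the weights are randomly initialized, so the output is
# only useful for checking shapes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    args = ModelArgs(
        model_type="minicpm",
        hidden_size=64,
        dim_model_base=32,
        num_hidden_layers=2,
        intermediate_size=128,
        num_attention_heads=4,
        rms_norm_eps=1e-5,
        vocab_size=1000,
        num_key_value_heads=2,
        scale_depth=1.4,
        scale_emb=12.0,
    )
    model = Model(args)

    # A batch of one sequence of four token ids; logits come back with shape
    # (batch, sequence_length, vocab_size).
    tokens = mx.array([[1, 2, 3, 4]])
    logits = model(tokens)
    print(logits.shape)  # (1, 4, 1000)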