diff --git a/llms/mlx_lm/models/olmoe.py b/llms/mlx_lm/models/olmoe.py new file mode 100644 index 00000000..6352ab93 --- /dev/null +++ b/llms/mlx_lm/models/olmoe.py @@ -0,0 +1,236 @@ +# Copyright © 2023-2024 Apple Inc. + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .rope_utils import initialize_rope + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + vocab_size: int + num_experts: int + num_experts_per_tok: int + norm_topk_prob: bool = False + head_dim: Optional[int] = None + max_position_embeddings: Optional[int] = None + num_key_value_heads: Optional[int] = None + attention_bias: bool = False + mlp_bias: bool = False + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = True + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads + + self.scale = head_dim**-0.5 + if hasattr(args, "attention_bias"): + attention_bias = args.attention_bias + else: + attention_bias = False + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) + + self.rope = initialize_rope( + self.head_dim, + args.rope_theta, + args.rope_traditional, + args.rope_scaling, + args.max_position_embeddings, + ) + + self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps) + self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + queries = self.q_norm(queries) + keys = self.k_norm(keys) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.gate_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=args.mlp_bias) + self.down_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.mlp_bias) + self.up_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=args.mlp_bias) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class OlmoeSparseMoeBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_experts = args.num_experts + self.top_k = args.num_experts_per_tok + self.norm_topk_prob = args.norm_topk_prob + self.gate = nn.Linear(args.hidden_size, self.num_experts, bias=False) + self.experts = [MLP(args) for _ in range(self.num_experts)] + + def __call__(self, x: mx.array) -> mx.array: + batch_size, sequence_length, hidden_dim = x.shape + x = x.reshape(-1, hidden_dim) + + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(x) + + # Compute routing weights with softmax + routing_weights = mx.softmax(router_logits, axis=1, precise=True) + + # Initialize output tensor + final_hidden_states = mx.zeros_like(x) + + # Process each token through all experts, weighted by routing weights + for expert_idx in range(self.num_experts): + # Get the weight for this expert for all tokens + expert_weights = routing_weights[:, expert_idx:expert_idx+1] + + # Only process if any weight is significant + if mx.max(expert_weights) > 1e-5: + # Apply expert to all tokens + expert_output = self.experts[expert_idx](x) + + # Weight the output and add to final result + final_hidden_states += expert_output * expert_weights + + return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.self_attn = Attention(args) + self.mlp = OlmoeSparseMoeBlock(args) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + x = x + self.self_attn(self.input_layernorm(x), mask, cache) + x = x + self.mlp(self.post_attention_layernorm(x)) + return x + + +class OlmoeModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + mask=None, + ): + h = self.embed_tokens(inputs) + + if mask is None: + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, cache=c) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = OlmoeModel(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + mask=None, + ): + out = self.model(inputs, cache, mask) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out + + def sanitize(self, weights): + # Remove unused precomputed rotary freqs + return { + k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k + } + + @property + def layers(self): + return self.model.layers