mirror of
				https://github.com/ml-explore/mlx-examples.git
				synced 2025-11-01 03:28:08 +08:00 
			
		
		
		
	Add Phi-3.5-MoE (#946)
* add phimoe * add phimoe to tunner * add switch_mlp * fix SuScaled args * nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		
							
								
								
									
										218
									
								
								llms/mlx_lm/models/phimoe.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										218
									
								
								llms/mlx_lm/models/phimoe.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,218 @@ | ||||
| # Copyright © 2024 Apple Inc. | ||||
| import math | ||||
| from dataclasses import dataclass | ||||
| from typing import Dict, List, Optional, Union | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
|  | ||||
| from .base import BaseModelArgs, create_attention_mask | ||||
| from .su_rope import SuScaledRotaryEmbedding | ||||
| from .switch_layers import SwitchGLU | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class ModelArgs(BaseModelArgs): | ||||
|     model_type: str = "phimoe" | ||||
|     vocab_size: int = 32064 | ||||
|     hidden_size: int = 4096 | ||||
|     intermediate_size: int = 6400 | ||||
|     num_hidden_layers: int = 32 | ||||
|     num_attention_heads: int = 32 | ||||
|     num_key_value_heads: int = 8 | ||||
|     max_position_embeddings: int = 131072 | ||||
|     original_max_position_embeddings: int = 4096 | ||||
|     rms_norm_eps: float = 1e-6 | ||||
|     rope_scaling: Dict[str, Union[float, List[float]]] = None | ||||
|     num_local_experts: int = 16 | ||||
|     num_experts_per_tok: int = 2 | ||||
|     rope_theta: float = 10000.0 | ||||
|  | ||||
|  | ||||
| class Attention(nn.Module): | ||||
|     def __init__(self, args: ModelArgs): | ||||
|         super().__init__() | ||||
|  | ||||
|         dim = args.hidden_size | ||||
|         self.n_heads = n_heads = args.num_attention_heads | ||||
|         self.n_kv_heads = n_kv_heads = args.num_key_value_heads | ||||
|  | ||||
|         head_dim = args.hidden_size // n_heads | ||||
|         self.scale = head_dim**-0.5 | ||||
|  | ||||
|         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True) | ||||
|         self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) | ||||
|         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True) | ||||
|         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True) | ||||
|  | ||||
|         self.rope = SuScaledRotaryEmbedding( | ||||
|             head_dim, | ||||
|             base=args.rope_theta, | ||||
|             max_position_embeddings=args.max_position_embeddings, | ||||
|             original_max_position_embeddings=args.original_max_position_embeddings, | ||||
|             short_factor=args.rope_scaling["short_factor"], | ||||
|             long_factor=args.rope_scaling["long_factor"], | ||||
|             short_mscale=args.rope_scaling["short_mscale"], | ||||
|             long_mscale=args.rope_scaling["long_mscale"], | ||||
|         ) | ||||
|  | ||||
|     def __call__( | ||||
|         self, | ||||
|         x: mx.array, | ||||
|         mask: Optional[mx.array] = None, | ||||
|         cache=None, | ||||
|     ) -> mx.array: | ||||
|         B, L, D = x.shape | ||||
|  | ||||
|         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) | ||||
|  | ||||
|         # Prepare the queries, keys and values for the attention computation | ||||
|         queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) | ||||
|         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) | ||||
|         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) | ||||
|  | ||||
|         if cache is not None: | ||||
|             queries = self.rope(queries, offset=cache.offset) | ||||
|             keys = self.rope(keys, offset=cache.offset) | ||||
|             keys, values = cache.update_and_fetch(keys, values) | ||||
|         else: | ||||
|             queries = self.rope(queries) | ||||
|             keys = self.rope(keys) | ||||
|  | ||||
|         output = mx.fast.scaled_dot_product_attention( | ||||
|             queries, keys, values, scale=self.scale, mask=mask | ||||
|         ) | ||||
|         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) | ||||
|         return self.o_proj(output) | ||||
|  | ||||
|  | ||||
| class PhiMoESparseMoeBlock(nn.Module): | ||||
|     def __init__(self, args: ModelArgs): | ||||
|         super().__init__() | ||||
|         self.hidden_dim = args.hidden_size | ||||
|         self.ffn_dim = args.intermediate_size | ||||
|         self.num_experts = args.num_local_experts | ||||
|         self.top_k = args.num_experts_per_tok | ||||
|  | ||||
|         self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) | ||||
|         self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts) | ||||
|  | ||||
|     def __call__(self, x: mx.array) -> mx.array: | ||||
|         gates = self.gate(x) | ||||
|  | ||||
|         k = self.top_k | ||||
|         inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]) | ||||
|         scores = mx.take_along_axis(gates, inds, axis=-1) | ||||
|         scores = mx.softmax(scores, axis=-1, precise=True) | ||||
|  | ||||
|         y = self.switch_mlp(x, inds) | ||||
|         y = (y * scores[..., None]).sum(axis=-2) | ||||
|  | ||||
|         return y | ||||
|  | ||||
|  | ||||
| class PhiMoEDecoderLayer(nn.Module): | ||||
|     def __init__(self, args: ModelArgs): | ||||
|         super().__init__() | ||||
|         self.hidden_size = args.hidden_size | ||||
|  | ||||
|         self.self_attn = Attention(args) | ||||
|         self.block_sparse_moe = PhiMoESparseMoeBlock(args) | ||||
|         self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.rms_norm_eps) | ||||
|         self.post_attention_layernorm = nn.LayerNorm( | ||||
|             args.hidden_size, eps=args.rms_norm_eps | ||||
|         ) | ||||
|  | ||||
|     def __call__( | ||||
|         self, | ||||
|         x: mx.array, | ||||
|         mask: Optional[mx.array] = None, | ||||
|         cache=None, | ||||
|     ) -> mx.array: | ||||
|         residual = x | ||||
|         hidden_states = self.input_layernorm(x) | ||||
|         hidden_states = self.self_attn(hidden_states, mask=mask, cache=cache) | ||||
|         hidden_states = residual + hidden_states | ||||
|  | ||||
|         residual = hidden_states | ||||
|         hidden_states = self.post_attention_layernorm(hidden_states) | ||||
|         hidden_states = self.block_sparse_moe(hidden_states) | ||||
|         hidden_states = residual + hidden_states | ||||
|  | ||||
|         return hidden_states | ||||
|  | ||||
|  | ||||
| class PhiMoEModel(nn.Module): | ||||
|     def __init__(self, args: ModelArgs): | ||||
|         super().__init__() | ||||
|         self.args = args | ||||
|         self.vocab_size = args.vocab_size | ||||
|  | ||||
|         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) | ||||
|         self.layers = [PhiMoEDecoderLayer(args) for _ in range(args.num_hidden_layers)] | ||||
|         self.norm = nn.LayerNorm(args.hidden_size, eps=args.rms_norm_eps) | ||||
|  | ||||
|     def __call__( | ||||
|         self, | ||||
|         inputs: mx.array, | ||||
|         cache=None, | ||||
|     ) -> mx.array: | ||||
|         h = self.embed_tokens(inputs) | ||||
|  | ||||
|         mask = create_attention_mask(h, cache) | ||||
|  | ||||
|         if cache is None: | ||||
|             cache = [None] * len(self.layers) | ||||
|  | ||||
|         for layer, c in zip(self.layers, cache): | ||||
|             h = layer(h, mask, c) | ||||
|  | ||||
|         return self.norm(h) | ||||
|  | ||||
|  | ||||
| class Model(nn.Module): | ||||
|     def __init__(self, args: ModelArgs): | ||||
|         super().__init__() | ||||
|         self.args = args | ||||
|         self.model = PhiMoEModel(args) | ||||
|         self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True) | ||||
|  | ||||
|     def __call__( | ||||
|         self, | ||||
|         inputs: mx.array, | ||||
|         cache=None, | ||||
|     ): | ||||
|         out = self.model(inputs, cache) | ||||
|         return self.lm_head(out) | ||||
|  | ||||
|     def sanitize(self, weights): | ||||
|         if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights: | ||||
|             return weights | ||||
|         for l in range(self.args.num_hidden_layers): | ||||
|             prefix = f"model.layers.{l}" | ||||
|             for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]: | ||||
|                 for k in ["weight", "scales", "biases"]: | ||||
|                     if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights: | ||||
|                         to_join = [ | ||||
|                             weights.pop( | ||||
|                                 f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}" | ||||
|                             ) | ||||
|                             for e in range(self.args.num_local_experts) | ||||
|                         ] | ||||
|                         weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = ( | ||||
|                             mx.stack(to_join) | ||||
|                         ) | ||||
|  | ||||
|         return weights | ||||
|  | ||||
|     @property | ||||
|     def layers(self): | ||||
|         return self.model.layers | ||||
|  | ||||
|     @property | ||||
|     def head_dim(self): | ||||
|         return self.args.hidden_size // self.args.num_attention_heads | ||||
|  | ||||
|     @property | ||||
|     def n_kv_heads(self): | ||||
|         return self.args.num_key_value_heads | ||||
| @@ -16,6 +16,8 @@ class SuScaledRotaryEmbedding(nn.Module): | ||||
|         original_max_position_embeddings: int = 4096, | ||||
|         short_factor: Union[List[float], float] = 1.0, | ||||
|         long_factor: Union[List[float], float] = 1.0, | ||||
|         short_mscale: float = None, | ||||
|         long_mscale: float = None, | ||||
|     ): | ||||
|         """ | ||||
|         Phi3Su Scaled Rotary Embedding layer for Phi-3 models. | ||||
| @@ -37,12 +39,14 @@ class SuScaledRotaryEmbedding(nn.Module): | ||||
|             long_factor (float or list[float], optional): List of scaling | ||||
|               factors for sequences of length greater than | ||||
|               ``original_max_position_embeddings``.  Default: ``1.0``. | ||||
|             short_mscale (float, optional): Scale the input prior to embedding. | ||||
|             long_mscale (float, optional): Scale the input prior to embedding. | ||||
|         """ | ||||
|         super().__init__() | ||||
|         freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) | ||||
|         self._freqs = mx.array(long_factor, dtype=mx.float32) * freqs | ||||
|         self.original_max_position_embeddings = original_max_position_embeddings | ||||
|         self.scale = math.sqrt( | ||||
|         self.scale = long_mscale or math.sqrt( | ||||
|             1 | ||||
|             + math.log(max_position_embeddings / original_max_position_embeddings) | ||||
|             / math.log(original_max_position_embeddings) | ||||
|   | ||||
| @@ -96,6 +96,7 @@ def linear_to_lora_layers( | ||||
|         "stablelm", | ||||
|         "qwen2", | ||||
|         "qwen2_moe", | ||||
|         "phimoe", | ||||
|         "gemma", | ||||
|         "gemma2", | ||||
|         "starcoder2", | ||||
| @@ -104,7 +105,7 @@ def linear_to_lora_layers( | ||||
|         "deepseek", | ||||
|     ]: | ||||
|         keys = set(["self_attn.q_proj", "self_attn.v_proj"]) | ||||
|         if model.model_type == "mixtral": | ||||
|         if model.model_type in ["mixtral", "phimoe"]: | ||||
|             keys.add("block_sparse_moe.gate") | ||||
|         if model.model_type == "qwen2_moe": | ||||
|             keys.add("mlp.gate") | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Prince Canuma
					Prince Canuma