diff --git a/llms/mlx_lm/models/deepseek_v3.py b/llms/mlx_lm/models/deepseek_v3.py
index 46ee6ab3..96ce4f85 100644
--- a/llms/mlx_lm/models/deepseek_v3.py
+++ b/llms/mlx_lm/models/deepseek_v3.py
@@ -2,6 +2,7 @@
 
 import math
 from dataclasses import dataclass
+from functools import partial
 from typing import Any, Dict, Optional, Tuple
 
 import mlx.core as mx
@@ -125,6 +126,12 @@ class DeepseekV3YarnRotaryEmbedding(nn.Module):
         )
 
 
+# A clipped silu to prevent fp16 from overflowing
+@partial(mx.compile, shapeless=True)
+def clipped_silu(x):
+    return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100)
+
+
 class DeepseekV3Attention(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
@@ -312,7 +319,10 @@ class DeepseekV3MoE(nn.Module):
         self.config = config
         self.num_experts_per_tok = config.num_experts_per_tok
         self.switch_mlp = SwitchGLU(
-            config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
+            config.hidden_size,
+            config.moe_intermediate_size,
+            config.n_routed_experts,
+            activation=clipped_silu,
         )
 
         self.gate = MoEGate(config)
@@ -359,11 +369,7 @@ class DeepseekV3DecoderLayer(nn.Module):
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
         r = self.mlp(self.post_attention_layernorm(h))
-        out = h + r
-        # Protect against overflow for fp16
-        if out.dtype == mx.float16:
-            out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000)
-        return out
+        return h + r
 
 
 class DeepseekV3Model(nn.Module):
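
For context on what the patch changes: instead of clipping the decoder layer's residual output when it is already near the float16 maximum (~65504), the clamp moves into the expert activation itself, bounding SiLU outputs to [-100, 100] before the MoE down-projection can accumulate them into inf. The standalone sketch below (illustrative only, not part of the patch) reproduces the clipped_silu definition from the diff and shows its effect on values that are harmless in isolation but can feed fp16 overflow downstream; it assumes only mlx.core.

# Illustrative sketch, not part of the patch above.
from functools import partial

import mlx.core as mx


@partial(mx.compile, shapeless=True)
def clipped_silu(x):
    # Same definition as in the diff: SiLU with its output clamped to
    # [-100, 100] so later fp16 matmuls cannot accumulate to inf.
    return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100)


x = mx.array([300.0, -300.0, 2.0], dtype=mx.float16)
print(x * mx.sigmoid(x))  # [300, -0, ~1.76] -- large positives pass through
print(clipped_silu(x))    # [100, -0, ~1.76] -- bounded, typical values untouched

Because clipped_silu is compiled with shapeless=True, the compiled kernel is reused across input shapes, so the clamp adds essentially no overhead on the MoE hot path.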