better overflow correction (#1229)

This commit is contained in:
Awni Hannun 2025-01-28 14:37:30 -08:00 committed by GitHub
parent 7a83077cd7
commit e8afb59de4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,6 +2,7 @@
import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, Optional, Tuple
import mlx.core as mx
@ -125,6 +126,12 @@ class DeepseekV3YarnRotaryEmbedding(nn.Module):
)
# A clipped silu to prevent fp16 from overflowing
@partial(mx.compile, shapeless=True)
def clipped_silu(x):
return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100)
class DeepseekV3Attention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
@ -312,7 +319,10 @@ class DeepseekV3MoE(nn.Module):
self.config = config
self.num_experts_per_tok = config.num_experts_per_tok
self.switch_mlp = SwitchGLU(
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
config.hidden_size,
config.moe_intermediate_size,
config.n_routed_experts,
activation=clipped_silu,
)
self.gate = MoEGate(config)
@ -359,11 +369,7 @@ class DeepseekV3DecoderLayer(nn.Module):
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
# Protect against overflow for fp16
if out.dtype == mx.float16:
out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000)
return out
return h + r
class DeepseekV3Model(nn.Module):