better overflow correction (#1229)

This commit is contained in:
Awni Hannun 2025-01-28 14:37:30 -08:00 committed by GitHub
parent 7a83077cd7
commit e8afb59de4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,6 +2,7 @@
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
import mlx.core as mx import mlx.core as mx
@ -125,6 +126,12 @@ class DeepseekV3YarnRotaryEmbedding(nn.Module):
) )
# SiLU with its output clamped to [-100, 100]; keeps fp16 activations
# from overflowing inside the MoE experts.
@partial(mx.compile, shapeless=True)
def clipped_silu(x):
    activated = x * mx.sigmoid(x)
    return mx.clip(activated, a_min=-100, a_max=100)
class DeepseekV3Attention(nn.Module): class DeepseekV3Attention(nn.Module):
def __init__(self, config: ModelArgs): def __init__(self, config: ModelArgs):
super().__init__() super().__init__()
@ -312,7 +319,10 @@ class DeepseekV3MoE(nn.Module):
self.config = config self.config = config
self.num_experts_per_tok = config.num_experts_per_tok self.num_experts_per_tok = config.num_experts_per_tok
self.switch_mlp = SwitchGLU( self.switch_mlp = SwitchGLU(
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts config.hidden_size,
config.moe_intermediate_size,
config.n_routed_experts,
activation=clipped_silu,
) )
self.gate = MoEGate(config) self.gate = MoEGate(config)
@ -359,11 +369,7 @@ class DeepseekV3DecoderLayer(nn.Module):
r = self.self_attn(self.input_layernorm(x), mask, cache) r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r h = x + r
r = self.mlp(self.post_attention_layernorm(h)) r = self.mlp(self.post_attention_layernorm(h))
out = h + r return h + r
# Protect against overflow for fp16
if out.dtype == mx.float16:
out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000)
return out
class DeepseekV3Model(nn.Module): class DeepseekV3Model(nn.Module):