mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
better overflow correction (#1229)
This commit is contained in:
parent
7a83077cd7
commit
e8afb59de4
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
@ -125,6 +126,12 @@ class DeepseekV3YarnRotaryEmbedding(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# A clipped silu to prevent fp16 from overflowing
|
||||||
|
@partial(mx.compile, shapeless=True)
|
||||||
|
def clipped_silu(x):
|
||||||
|
return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100)
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV3Attention(nn.Module):
|
class DeepseekV3Attention(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: ModelArgs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -312,7 +319,10 @@ class DeepseekV3MoE(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.num_experts_per_tok = config.num_experts_per_tok
|
self.num_experts_per_tok = config.num_experts_per_tok
|
||||||
self.switch_mlp = SwitchGLU(
|
self.switch_mlp = SwitchGLU(
|
||||||
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
|
config.hidden_size,
|
||||||
|
config.moe_intermediate_size,
|
||||||
|
config.n_routed_experts,
|
||||||
|
activation=clipped_silu,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.gate = MoEGate(config)
|
self.gate = MoEGate(config)
|
||||||
@ -359,11 +369,7 @@ class DeepseekV3DecoderLayer(nn.Module):
|
|||||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||||
h = x + r
|
h = x + r
|
||||||
r = self.mlp(self.post_attention_layernorm(h))
|
r = self.mlp(self.post_attention_layernorm(h))
|
||||||
out = h + r
|
return h + r
|
||||||
# Protect against overflow for fp16
|
|
||||||
if out.dtype == mx.float16:
|
|
||||||
out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV3Model(nn.Module):
|
class DeepseekV3Model(nn.Module):
|
||||||
|
Loading…
Reference in New Issue
Block a user