mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
support kimi + more options in chat mode (#1312)
This commit is contained in:
@@ -181,30 +181,37 @@ class DeepseekV3Attention(nn.Module):
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
|
||||
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if mscale_all_dim:
|
||||
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
|
||||
self.scale = self.scale * mscale * mscale
|
||||
if self.config.rope_scaling is not None:
|
||||
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if mscale_all_dim:
|
||||
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
|
||||
self.scale = self.scale * mscale * mscale
|
||||
|
||||
rope_kwargs = {
|
||||
key: self.config.rope_scaling[key]
|
||||
for key in [
|
||||
"original_max_position_embeddings",
|
||||
"beta_fast",
|
||||
"beta_slow",
|
||||
"mscale",
|
||||
"mscale_all_dim",
|
||||
]
|
||||
if key in self.config.rope_scaling
|
||||
}
|
||||
self.rope = DeepseekV3YarnRotaryEmbedding(
|
||||
dim=self.qk_rope_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
base=self.rope_theta,
|
||||
**rope_kwargs,
|
||||
)
|
||||
rope_kwargs = {
|
||||
key: self.config.rope_scaling[key]
|
||||
for key in [
|
||||
"original_max_position_embeddings",
|
||||
"beta_fast",
|
||||
"beta_slow",
|
||||
"mscale",
|
||||
"mscale_all_dim",
|
||||
]
|
||||
if key in self.config.rope_scaling
|
||||
}
|
||||
self.rope = DeepseekV3YarnRotaryEmbedding(
|
||||
dim=self.qk_rope_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
base=self.rope_theta,
|
||||
**rope_kwargs,
|
||||
)
|
||||
else:
|
||||
self.rope = nn.RoPE(
|
||||
dims=self.qk_rope_head_dim,
|
||||
base=self.rope_theta,
|
||||
traditional=True,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -487,8 +494,12 @@ class Model(nn.Module):
|
||||
]
|
||||
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
|
||||
|
||||
# Remove multi-token prediction layer
|
||||
return {k: v for k, v in weights.items() if not k.startswith("model.layers.61")}
|
||||
# Remove multi-token prediction layer and any unused precomputed rotary freqs
|
||||
return {
|
||||
k: v
|
||||
for k, v in weights.items()
|
||||
if not k.startswith("model.layers.61") and "rotary_emb.inv_freq" not in k
|
||||
}
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
|
||||
Reference in New Issue
Block a user