mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-19 03:18:06 +08:00
Merge branch 'ml-explore:main' into adding-dpo-training
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
@@ -125,6 +126,12 @@ class DeepseekV3YarnRotaryEmbedding(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
# A clipped silu to prevent fp16 from overflowing
|
||||
@partial(mx.compile, shapeless=True)
|
||||
def clipped_silu(x):
|
||||
return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100)
|
||||
|
||||
|
||||
class DeepseekV3Attention(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
@@ -312,7 +319,10 @@ class DeepseekV3MoE(nn.Module):
|
||||
self.config = config
|
||||
self.num_experts_per_tok = config.num_experts_per_tok
|
||||
self.switch_mlp = SwitchGLU(
|
||||
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
|
||||
config.hidden_size,
|
||||
config.moe_intermediate_size,
|
||||
config.n_routed_experts,
|
||||
activation=clipped_silu,
|
||||
)
|
||||
|
||||
self.gate = MoEGate(config)
|
||||
@@ -359,11 +369,7 @@ class DeepseekV3DecoderLayer(nn.Module):
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
out = h + r
|
||||
# Protect against overflow for fp16
|
||||
if out.dtype == mx.float16:
|
||||
out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000)
|
||||
return out
|
||||
return h + r
|
||||
|
||||
|
||||
class DeepseekV3Model(nn.Module):
|
||||
|
@@ -147,11 +147,11 @@ def min_p_sampling(
|
||||
logprobs = logprobs * (1 / temperature)
|
||||
|
||||
# Indices sorted in decreasing order
|
||||
sorted_indices = mx.argsort(-logprobs).squeeze(0)
|
||||
sorted_logprobs = logprobs[..., sorted_indices]
|
||||
sorted_indices = mx.argsort(-logprobs, axis=-1)
|
||||
sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)
|
||||
|
||||
# Top probability
|
||||
top_logprobs = logprobs[..., sorted_indices[0]]
|
||||
top_logprobs = sorted_logprobs[:, 0:1]
|
||||
|
||||
# Calculate the min_p threshold
|
||||
scaled_min_p = top_logprobs + math.log(min_p)
|
||||
@@ -163,9 +163,9 @@ def min_p_sampling(
|
||||
# Create pool of tokens with probability less than scaled min_p
|
||||
selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
|
||||
|
||||
# Return sampled token
|
||||
sorted_token = mx.random.categorical(selected_logprobs)
|
||||
return sorted_indices[sorted_token]
|
||||
# Return sampled tokens
|
||||
sorted_tokens = mx.random.categorical(selected_logprobs, axis=-1)[:, None]
|
||||
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
@@ -185,7 +185,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
||||
|
||||
# sort probs in ascending order
|
||||
sorted_indices = mx.argsort(probs, axis=-1)
|
||||
sorted_probs = probs[..., sorted_indices.squeeze(0)]
|
||||
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)
|
||||
|
||||
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
|
||||
|
||||
@@ -196,10 +196,8 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
||||
0,
|
||||
)
|
||||
|
||||
sorted_token = mx.random.categorical(mx.log(top_probs))
|
||||
token = sorted_indices.squeeze(0)[sorted_token]
|
||||
|
||||
return token
|
||||
sorted_tokens = mx.random.categorical(mx.log(top_probs), axis=-1)[:, None]
|
||||
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
|
@@ -114,6 +114,33 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
|
||||
return prompt.rstrip()
|
||||
|
||||
|
||||
def process_message_content(messages):
|
||||
"""
|
||||
Convert message content to a format suitable for `apply_chat_template`.
|
||||
|
||||
The function operates on messages in place. It converts the 'content' field
|
||||
to a string instead of a list of text fragments.
|
||||
|
||||
Args:
|
||||
message_list (list): A list of dictionaries, where each dictionary may
|
||||
have a 'content' key containing a list of dictionaries with 'type' and
|
||||
'text' keys.
|
||||
|
||||
Raises:
|
||||
ValueError: If the 'content' type is not supported or if 'text' is missing.
|
||||
|
||||
"""
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
if isinstance(content, list):
|
||||
text_fragments = [
|
||||
fragment["text"] for fragment in content if fragment["type"] == "text"
|
||||
]
|
||||
if len(text_fragments) != len(content):
|
||||
raise ValueError("Only 'text' content type is supported.")
|
||||
message["content"] = "".join(text_fragments)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptCache:
|
||||
cache: List[Any] = field(default_factory=list)
|
||||
@@ -591,8 +618,10 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
self.request_id = f"chatcmpl-{uuid.uuid4()}"
|
||||
self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
|
||||
if self.tokenizer.chat_template:
|
||||
messages = body["messages"]
|
||||
process_message_content(messages)
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
body["messages"],
|
||||
messages,
|
||||
body.get("tools", None),
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
@@ -398,8 +398,9 @@ def speculative_generate_step(
|
||||
quantize_cache_fn(cache)
|
||||
|
||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||
y = sampler(logprobs).squeeze(0)
|
||||
return y, logprobs.squeeze(0)
|
||||
logprobs = logprobs.squeeze(0)
|
||||
y = sampler(logprobs)
|
||||
return y, logprobs
|
||||
|
||||
def _prefill(model, cache, y):
|
||||
while y.size > prefill_step_size:
|
||||
|
Reference in New Issue
Block a user