Merge branch 'ml-explore:main' into adding-support-for-mamba2

Gökdeniz Gülmez 2025-01-29 15:07:11 +01:00 committed by GitHub
commit 57e10446b0
6 changed files with 89 additions and 20 deletions

View File

@@ -2,6 +2,7 @@
 import math
 from dataclasses import dataclass
+from functools import partial
 from typing import Any, Dict, Optional, Tuple

 import mlx.core as mx
@@ -125,6 +126,12 @@ class DeepseekV3YarnRotaryEmbedding(nn.Module):
         )


+# A clipped silu to prevent fp16 from overflowing
+@partial(mx.compile, shapeless=True)
+def clipped_silu(x):
+    return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100)
+
+
 class DeepseekV3Attention(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
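As a side note, a minimal sketch of how the new activation behaves (the input values below are illustrative, not from the commit): it matches SiLU for ordinary inputs and saturates at 100 for very large pre-activations, which keeps downstream fp16 accumulations bounded.

    import mlx.core as mx

    def clipped_silu(x):
        # Same definition as above: SiLU clipped to [-100, 100].
        return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100)

    x = mx.array([-2.0, 0.0, 3.0, 500.0], dtype=mx.float16)
    print(clipped_silu(x))  # approx [-0.238, 0.0, 2.857, 100.0]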
@@ -312,7 +319,10 @@ class DeepseekV3MoE(nn.Module):
         self.config = config
         self.num_experts_per_tok = config.num_experts_per_tok
         self.switch_mlp = SwitchGLU(
-            config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
+            config.hidden_size,
+            config.moe_intermediate_size,
+            config.n_routed_experts,
+            activation=clipped_silu,
         )
         self.gate = MoEGate(config)
@@ -359,11 +369,7 @@ class DeepseekV3DecoderLayer(nn.Module):
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
         r = self.mlp(self.post_attention_layernorm(h))
-        out = h + r
-        # Protect against overflow for fp16
-        if out.dtype == mx.float16:
-            out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000)
-        return out
+        return h + r


 class DeepseekV3Model(nn.Module):

View File

@@ -147,11 +147,11 @@ def min_p_sampling(
     logprobs = logprobs * (1 / temperature)

     # Indices sorted in decreasing order
-    sorted_indices = mx.argsort(-logprobs).squeeze(0)
-    sorted_logprobs = logprobs[..., sorted_indices]
+    sorted_indices = mx.argsort(-logprobs, axis=-1)
+    sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)

     # Top probability
-    top_logprobs = logprobs[..., sorted_indices[0]]
+    top_logprobs = sorted_logprobs[:, 0:1]

     # Calculate the min_p threshold
     scaled_min_p = top_logprobs + math.log(min_p)
@@ -163,9 +163,9 @@ def min_p_sampling(
     # Create pool of tokens with probability less than scaled min_p
     selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)

-    # Return sampled token
-    sorted_token = mx.random.categorical(selected_logprobs)
-    return sorted_indices[sorted_token]
+    # Return sampled tokens
+    sorted_tokens = mx.random.categorical(selected_logprobs, axis=-1)[:, None]
+    return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)


 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
@@ -185,7 +185,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
     # sort probs in ascending order
     sorted_indices = mx.argsort(probs, axis=-1)
-    sorted_probs = probs[..., sorted_indices.squeeze(0)]
+    sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)

     cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
@@ -196,10 +196,8 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
         0,
     )

-    sorted_token = mx.random.categorical(mx.log(top_probs))
-    token = sorted_indices.squeeze(0)[sorted_token]
-
-    return token
+    sorted_tokens = mx.random.categorical(mx.log(top_probs), axis=-1)[:, None]
+    return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)


 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
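For reference, an illustrative sketch (not part of the commit; the tensors here are made up) of the take_along_axis pattern that makes these samplers batch-aware:

    import mlx.core as mx

    probs = mx.array([[0.1, 0.7, 0.2],
                      [0.6, 0.1, 0.3]])

    # Per-row descending order of token indices.
    order = mx.argsort(-probs, axis=-1)                       # [[1, 2, 0], [0, 2, 1]]
    sorted_probs = mx.take_along_axis(probs, order, axis=-1)  # [[0.7, 0.2, 0.1], [0.6, 0.3, 0.1]]

    # Map a per-row choice in sorted space back to original token ids,
    # which plain fancy indexing cannot do row-wise.
    choice = mx.array([[0], [1]])
    tokens = mx.take_along_axis(order, choice, axis=-1).squeeze(1)
    print(tokens)  # [1, 2]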

View File

@@ -114,6 +114,33 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     return prompt.rstrip()


+def process_message_content(messages):
+    """
+    Convert message content to a format suitable for `apply_chat_template`.
+
+    The function operates on messages in place. It converts the 'content' field
+    to a string instead of a list of text fragments.
+
+    Args:
+        messages (list): A list of dictionaries, where each dictionary may
+            have a 'content' key containing a list of dictionaries with 'type'
+            and 'text' keys.
+
+    Raises:
+        ValueError: If the 'content' type is not supported or if 'text' is missing.
+    """
+    for message in messages:
+        content = message["content"]
+        if isinstance(content, list):
+            text_fragments = [
+                fragment["text"] for fragment in content if fragment["type"] == "text"
+            ]
+            if len(text_fragments) != len(content):
+                raise ValueError("Only 'text' content type is supported.")
+            message["content"] = "".join(text_fragments)
+
+
 @dataclass
 class PromptCache:
     cache: List[Any] = field(default_factory=list)
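A small usage sketch of the new helper (the messages below are example data only, not from the commit):

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Hello, "},
                {"type": "text", "text": "world!"},
            ],
        },
        {"role": "assistant", "content": "Hi!"},  # plain strings pass through untouched
    ]
    process_message_content(messages)
    print(messages[0]["content"])  # "Hello, world!"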
@@ -591,8 +618,10 @@ class APIHandler(BaseHTTPRequestHandler):
         self.request_id = f"chatcmpl-{uuid.uuid4()}"
         self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
         if self.tokenizer.chat_template:
+            messages = body["messages"]
+            process_message_content(messages)
             prompt = self.tokenizer.apply_chat_template(
-                body["messages"],
+                messages,
                 body.get("tools", None),
                 add_generation_prompt=True,
             )

View File

@@ -398,8 +398,9 @@ def speculative_generate_step(
         quantize_cache_fn(cache)

         logprobs = logits - mx.logsumexp(logits, keepdims=True)
-        y = sampler(logprobs).squeeze(0)
-        return y, logprobs.squeeze(0)
+        logprobs = logprobs.squeeze(0)
+        y = sampler(logprobs)
+        return y, logprobs

     def _prefill(model, cache, y):
         while y.size > prefill_step_size:

View File

@@ -28,6 +28,12 @@ class TestSampleUtils(unittest.TestCase):
         token = top_p_sampling(logits, 0.95, temperature).item()
         self.assertTrue(token in (1, 2, 3))

+        # Batch mode works
+        probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
+        logits = mx.log(probs)
+        tokens = top_p_sampling(logits, 0.5, temperature)
+        self.assertEqual(tokens.tolist(), [0, 1])
+
     def test_min_p_sampling(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
@@ -42,6 +48,12 @@ class TestSampleUtils(unittest.TestCase):
         token = min_p_sampling(logits, 0.05)
         self.assertTrue(token in (0, 3))

+        # Batch mode works
+        probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
+        logits = mx.log(probs)
+        tokens = min_p_sampling(logits, 0.7)
+        self.assertEqual(tokens.tolist(), [0, 1])
+
     def test_top_k_sampling(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)

View File

@@ -80,6 +80,29 @@ class TestServer(unittest.TestCase):
         self.assertIn("id", response_body)
         self.assertIn("choices", response_body)

+    def test_handle_chat_completions_with_content_fragments(self):
+        url = f"http://localhost:{self.port}/v1/chat/completions"
+        chat_post_data = {
+            "model": "chat_model",
+            "max_tokens": 10,
+            "temperature": 0.7,
+            "top_p": 0.85,
+            "repetition_penalty": 1.2,
+            "messages": [
+                {
+                    "role": "system",
+                    "content": [
+                        {"type": "text", "text": "You are a helpful assistant."}
+                    ],
+                },
+                {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
+            ],
+        }
+        response = requests.post(url, json=chat_post_data)
+        response_body = response.text
+        self.assertIn("id", response_body)
+        self.assertIn("choices", response_body)
+
     def test_handle_models(self):
         url = f"http://localhost:{self.port}/v1/models"
         response = requests.get(url)