diff --git a/llms/mlx_lm/models/deepseek_v3.py b/llms/mlx_lm/models/deepseek_v3.py index 46ee6ab3..96ce4f85 100644 --- a/llms/mlx_lm/models/deepseek_v3.py +++ b/llms/mlx_lm/models/deepseek_v3.py @@ -2,6 +2,7 @@ import math from dataclasses import dataclass +from functools import partial from typing import Any, Dict, Optional, Tuple import mlx.core as mx @@ -125,6 +126,12 @@ class DeepseekV3YarnRotaryEmbedding(nn.Module): ) +# A clipped silu to prevent fp16 from overflowing +@partial(mx.compile, shapeless=True) +def clipped_silu(x): + return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100) + + class DeepseekV3Attention(nn.Module): def __init__(self, config: ModelArgs): super().__init__() @@ -312,7 +319,10 @@ class DeepseekV3MoE(nn.Module): self.config = config self.num_experts_per_tok = config.num_experts_per_tok self.switch_mlp = SwitchGLU( - config.hidden_size, config.moe_intermediate_size, config.n_routed_experts + config.hidden_size, + config.moe_intermediate_size, + config.n_routed_experts, + activation=clipped_silu, ) self.gate = MoEGate(config) @@ -359,11 +369,7 @@ class DeepseekV3DecoderLayer(nn.Module): r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - # Protect against overflow for fp16 - if out.dtype == mx.float16: - out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000) - return out + return h + r class DeepseekV3Model(nn.Module): diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py index c48a32cf..23e08d97 100644 --- a/llms/mlx_lm/sample_utils.py +++ b/llms/mlx_lm/sample_utils.py @@ -147,11 +147,11 @@ def min_p_sampling( logprobs = logprobs * (1 / temperature) # Indices sorted in decreasing order - sorted_indices = mx.argsort(-logprobs).squeeze(0) - sorted_logprobs = logprobs[..., sorted_indices] + sorted_indices = mx.argsort(-logprobs, axis=-1) + sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1) # Top probability - top_logprobs = logprobs[..., sorted_indices[0]] + top_logprobs = sorted_logprobs[:, 0:1] # Calculate the min_p threshold scaled_min_p = top_logprobs + math.log(min_p) @@ -163,9 +163,9 @@ def min_p_sampling( # Create pool of tokens with probability less than scaled min_p selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs) - # Return sampled token - sorted_token = mx.random.categorical(selected_logprobs) - return sorted_indices[sorted_token] + # Return sampled tokens + sorted_tokens = mx.random.categorical(selected_logprobs, axis=-1)[:, None] + return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1) @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) @@ -185,7 +185,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr # sort probs in ascending order sorted_indices = mx.argsort(probs, axis=-1) - sorted_probs = probs[..., sorted_indices.squeeze(0)] + sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1) cumulative_probs = mx.cumsum(sorted_probs, axis=-1) @@ -196,10 +196,8 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr 0, ) - sorted_token = mx.random.categorical(mx.log(top_probs)) - token = sorted_indices.squeeze(0)[sorted_token] - - return token + sorted_tokens = mx.random.categorical(mx.log(top_probs), axis=-1)[:, None] + return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1) @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index 4523e3ae..de02704d 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -114,6 +114,33 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None): return prompt.rstrip() +def process_message_content(messages): + """ + Convert message content to a format suitable for `apply_chat_template`. + + The function operates on messages in place. It converts the 'content' field + to a string instead of a list of text fragments. + + Args: + message_list (list): A list of dictionaries, where each dictionary may + have a 'content' key containing a list of dictionaries with 'type' and + 'text' keys. + + Raises: + ValueError: If the 'content' type is not supported or if 'text' is missing. + + """ + for message in messages: + content = message["content"] + if isinstance(content, list): + text_fragments = [ + fragment["text"] for fragment in content if fragment["type"] == "text" + ] + if len(text_fragments) != len(content): + raise ValueError("Only 'text' content type is supported.") + message["content"] = "".join(text_fragments) + + @dataclass class PromptCache: cache: List[Any] = field(default_factory=list) @@ -591,8 +618,10 @@ class APIHandler(BaseHTTPRequestHandler): self.request_id = f"chatcmpl-{uuid.uuid4()}" self.object_type = "chat.completion.chunk" if self.stream else "chat.completion" if self.tokenizer.chat_template: + messages = body["messages"] + process_message_content(messages) prompt = self.tokenizer.apply_chat_template( - body["messages"], + messages, body.get("tools", None), add_generation_prompt=True, ) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index b9037295..0150f1b7 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -398,8 +398,9 @@ def speculative_generate_step( quantize_cache_fn(cache) logprobs = logits - mx.logsumexp(logits, keepdims=True) - y = sampler(logprobs).squeeze(0) - return y, logprobs.squeeze(0) + logprobs = logprobs.squeeze(0) + y = sampler(logprobs) + return y, logprobs def _prefill(model, cache, y): while y.size > prefill_step_size: diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py index c45fa443..f12abbf4 100644 --- a/llms/tests/test_sample_utils.py +++ b/llms/tests/test_sample_utils.py @@ -28,6 +28,12 @@ class TestSampleUtils(unittest.TestCase): token = top_p_sampling(logits, 0.95, temperature).item() self.assertTrue(token in (1, 2, 3)) + # Batch mode works + probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]]) + logits = mx.log(probs) + tokens = top_p_sampling(logits, 0.5, temperature) + self.assertEqual(tokens.tolist(), [0, 1]) + def test_min_p_sampling(self): probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] logits = mx.log(probs) @@ -42,6 +48,12 @@ class TestSampleUtils(unittest.TestCase): token = min_p_sampling(logits, 0.05) self.assertTrue(token in (0, 3)) + # Batch mode works + probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]]) + logits = mx.log(probs) + tokens = min_p_sampling(logits, 0.7) + self.assertEqual(tokens.tolist(), [0, 1]) + def test_top_k_sampling(self): probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] logits = mx.log(probs) diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py index ad17554d..ecf95f78 100644 --- a/llms/tests/test_server.py +++ b/llms/tests/test_server.py @@ -80,6 +80,29 @@ class TestServer(unittest.TestCase): self.assertIn("id", response_body) self.assertIn("choices", response_body) + def test_handle_chat_completions_with_content_fragments(self): + url = f"http://localhost:{self.port}/v1/chat/completions" + chat_post_data = { + "model": "chat_model", + "max_tokens": 10, + "temperature": 0.7, + "top_p": 0.85, + "repetition_penalty": 1.2, + "messages": [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant."} + ], + }, + {"role": "user", "content": [{"type": "text", "text": "Hello!"}]}, + ], + } + response = requests.post(url, json=chat_post_data) + response_body = response.text + self.assertIn("id", response_body) + self.assertIn("choices", response_body) + def test_handle_models(self): url = f"http://localhost:{self.port}/v1/models" response = requests.get(url)