Merge branch 'ml-explore:main' into adding-GRPO-training

Gökdeniz Gülmez 2025-01-29 15:07:52 +01:00 committed by GitHub
commit b1e573d6e8
9 changed files with 274 additions and 21 deletions

View File

@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Markus Enzweiler: Added the `cvae` examples.
 - Prince Canuma: Helped add support for `Starcoder2` models.
 - Shiyu Li: Added the `Segment Anything Model`.
-- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba` and support for `full-fine-tuning`.
+- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1` and support for `full-fine-tuning`.

View File

@@ -2,6 +2,7 @@

 import math
 from dataclasses import dataclass
+from functools import partial
 from typing import Any, Dict, Optional, Tuple

 import mlx.core as mx
@@ -125,6 +126,12 @@ class DeepseekV3YarnRotaryEmbedding(nn.Module):
         )


+# A clipped silu to prevent fp16 from overflowing
+@partial(mx.compile, shapeless=True)
+def clipped_silu(x):
+    return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100)
+
+
 class DeepseekV3Attention(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
@@ -312,7 +319,10 @@ class DeepseekV3MoE(nn.Module):
         self.config = config
         self.num_experts_per_tok = config.num_experts_per_tok
         self.switch_mlp = SwitchGLU(
-            config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
+            config.hidden_size,
+            config.moe_intermediate_size,
+            config.n_routed_experts,
+            activation=clipped_silu,
         )

         self.gate = MoEGate(config)
@@ -359,11 +369,7 @@ class DeepseekV3DecoderLayer(nn.Module):
         r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
         r = self.mlp(self.post_attention_layernorm(h))
-        out = h + r
-        # Protect against overflow for fp16
-        if out.dtype == mx.float16:
-            out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000)
-        return out
+        return h + r


 class DeepseekV3Model(nn.Module):
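
For reference, the sketch below shows what the new activation computes: a standard SiLU whose output is clipped to [-100, 100], which bounds the MoE expert activations instead of clipping the residual stream after the fact. The helper name and input values are illustrative, not part of the change.

```python
import mlx.core as mx


def clipped_silu_reference(x):
    # SiLU (x * sigmoid(x)) with the result clipped to [-100, 100].
    return mx.clip(x * mx.sigmoid(x), a_min=-100, a_max=100)


x = mx.array([-8.0, -1.0, 0.0, 1.0, 1000.0], dtype=mx.float16)
print(clipped_silu_reference(x))
# Moderate inputs pass through essentially unchanged; the huge positive input
# saturates at 100 rather than pushing a large value into the fp16 residual stream.
```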

View File

@@ -0,0 +1,183 @@
+from dataclasses import dataclass
+from typing import Any, Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    hidden_size: int
+    num_hidden_layers: int
+    intermediate_size: int
+    num_attention_heads: int
+    num_key_value_heads: int
+    rms_norm_eps: float
+    vocab_size: int
+    attention_bias: bool
+    head_dim: int
+    max_position_embeddings: int
+    mlp_bias: bool
+    model_type: str
+    rope_theta: float
+    tie_word_embeddings: bool
+
+
+class HeliumAttention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        dim = args.hidden_size
+        self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
+        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+
+        head_dim = args.hidden_size // n_heads
+        self.scale = head_dim**-0.5
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+        self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+        # Prepare the queries, keys and values for the attention computation
+        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
+        )
+
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
+
+
+class HeliumMLP(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        self.hidden_size = args.hidden_size
+        self.intermediate_size = args.intermediate_size
+
+        self.gate_proj = nn.Linear(
+            self.hidden_size, self.intermediate_size, bias=args.mlp_bias
+        )
+        self.up_proj = nn.Linear(
+            self.hidden_size, self.intermediate_size, bias=args.mlp_bias
+        )
+        self.down_proj = nn.Linear(
+            self.intermediate_size, self.hidden_size, bias=args.mlp_bias
+        )
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class HeliumDecoderLayer(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.hidden_size = args.hidden_size
+
+        self.self_attn = HeliumAttention(args)
+        self.mlp = HeliumMLP(args)
+
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
+        h = x + r
+        r = self.mlp(self.post_attention_layernorm(h))
+        out = h + r
+        return out
+
+
+class HeliumModel(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.num_hidden_layers = args.num_hidden_layers
+        self.vocab_size = args.vocab_size
+        assert self.vocab_size > 0
+
+        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.layers = [HeliumDecoderLayer(args) for _ in range(args.num_hidden_layers)]
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        mask: mx.array = None,
+        cache=None,
+    ) -> mx.array:
+        h = self.embed_tokens(inputs)
+
+        if mask is None:
+            mask = create_attention_mask(h, cache)
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, c)
+
+        return self.norm(h)
+
+
+class Model(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.model_type = args.model_type
+        self.model = HeliumModel(args)
+        self.vocab_size = args.vocab_size
+
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        mask: mx.array = None,
+        cache=None,
+    ) -> mx.array:
+        out = self.model(inputs, mask, cache)
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return out
+
+    @property
+    def layers(self):
+        return self.model.layers
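
To show how the new Helium module fits together, here is a minimal usage sketch. Every hyperparameter below is made up for demonstration (real checkpoints ship their own config), and the import path assumes the file lands alongside the other model definitions as `mlx_lm/models/helium.py`.

```python
import mlx.core as mx

from mlx_lm.models.helium import Model, ModelArgs  # assumed module location

# Illustrative, tiny configuration; not the real Helium hyperparameters.
args = ModelArgs(
    model_type="helium",
    hidden_size=256,
    num_hidden_layers=2,
    intermediate_size=512,
    num_attention_heads=8,
    num_key_value_heads=2,
    rms_norm_eps=1e-5,
    vocab_size=1000,
    attention_bias=False,
    head_dim=32,
    max_position_embeddings=2048,
    mlp_bias=False,
    rope_theta=10000.0,
    tie_word_embeddings=True,
)

model = Model(args)
tokens = mx.array([[1, 2, 3, 4]])  # (batch, sequence)
logits = model(tokens)             # (batch, sequence, vocab_size)
print(logits.shape)
```

With `tie_word_embeddings=True` the output projection reuses `embed_tokens.as_linear`, so no separate `lm_head` weights are created.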

View File

@@ -147,11 +147,11 @@ def min_p_sampling(
     logprobs = logprobs * (1 / temperature)

     # Indices sorted in decreasing order
-    sorted_indices = mx.argsort(-logprobs).squeeze(0)
-    sorted_logprobs = logprobs[..., sorted_indices]
+    sorted_indices = mx.argsort(-logprobs, axis=-1)
+    sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)

     # Top probability
-    top_logprobs = logprobs[..., sorted_indices[0]]
+    top_logprobs = sorted_logprobs[:, 0:1]

     # Calculate the min_p threshold
     scaled_min_p = top_logprobs + math.log(min_p)
@@ -163,9 +163,9 @@ def min_p_sampling(
     # Create pool of tokens with probability less than scaled min_p
     selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)

-    # Return sampled token
-    sorted_token = mx.random.categorical(selected_logprobs)
-    return sorted_indices[sorted_token]
+    # Return sampled tokens
+    sorted_tokens = mx.random.categorical(selected_logprobs, axis=-1)[:, None]
+    return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)


 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
@@ -185,7 +185,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
     # sort probs in ascending order
     sorted_indices = mx.argsort(probs, axis=-1)
-    sorted_probs = probs[..., sorted_indices.squeeze(0)]
+    sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)

     cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
@@ -196,10 +196,8 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
         0,
     )

-    sorted_token = mx.random.categorical(mx.log(top_probs))
-    token = sorted_indices.squeeze(0)[sorted_token]
-
-    return token
+    sorted_tokens = mx.random.categorical(mx.log(top_probs), axis=-1)[:, None]
+    return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)


 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
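
The rewritten samplers all rely on the same batched gather pattern: sort each row, sample a position per row, then map positions back to vocabulary ids with `mx.take_along_axis` instead of fancy indexing that only handled a single row. A minimal sketch of that pattern with illustrative probabilities:

```python
import mlx.core as mx

# Two rows of log-probabilities (batch of 2, vocabulary of 4); values are illustrative.
logprobs = mx.log(mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.2]]))

# Sort each row independently in decreasing order.
sorted_indices = mx.argsort(-logprobs, axis=-1)
sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)

# Sample one position per row from the sorted distributions ...
sorted_tokens = mx.random.categorical(sorted_logprobs, axis=-1)[:, None]

# ... and map those positions back to vocabulary ids, one token per row.
tokens = mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
print(tokens)  # shape (2,)
```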

View File

@@ -114,6 +114,33 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
     return prompt.rstrip()


+def process_message_content(messages):
+    """
+    Convert message content to a format suitable for `apply_chat_template`.
+
+    The function operates on messages in place. It converts the 'content' field
+    to a string instead of a list of text fragments.
+
+    Args:
+        messages (list): A list of dictionaries, where each dictionary may
+            have a 'content' key containing a list of dictionaries with 'type'
+            and 'text' keys.
+
+    Raises:
+        ValueError: If the 'content' type is not supported or if 'text' is missing.
+
+    """
+    for message in messages:
+        content = message["content"]
+        if isinstance(content, list):
+            text_fragments = [
+                fragment["text"] for fragment in content if fragment["type"] == "text"
+            ]
+            if len(text_fragments) != len(content):
+                raise ValueError("Only 'text' content type is supported.")
+            message["content"] = "".join(text_fragments)
+
+
 @dataclass
 class PromptCache:
     cache: List[Any] = field(default_factory=list)
@@ -591,8 +618,10 @@ class APIHandler(BaseHTTPRequestHandler):
         self.request_id = f"chatcmpl-{uuid.uuid4()}"
         self.object_type = "chat.completion.chunk" if self.stream else "chat.completion"
         if self.tokenizer.chat_template:
+            messages = body["messages"]
+            process_message_content(messages)
             prompt = self.tokenizer.apply_chat_template(
-                body["messages"],
+                messages,
                 body.get("tools", None),
                 add_generation_prompt=True,
             )
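
A rough usage sketch of the new helper; the import assumes `process_message_content` is picked up from the server module, and the message payload is illustrative.

```python
from mlx_lm.server import process_message_content  # assumed import path

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Hello"},
            {"type": "text", "text": ", world!"},
        ],
    }
]

process_message_content(messages)   # flattens the fragments in place
print(messages[0]["content"])       # "Hello, world!"
# A non-text fragment, e.g. {"type": "image_url", ...}, raises ValueError.
```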

View File

@@ -94,6 +94,7 @@ def linear_to_lora_layers(
         "phimoe",
         "gemma",
         "gemma2",
+        "helium",
         "starcoder2",
         "cohere",
         "cohere2",

View File

@@ -398,8 +398,9 @@ def speculative_generate_step(
         quantize_cache_fn(cache)

         logprobs = logits - mx.logsumexp(logits, keepdims=True)
-        y = sampler(logprobs).squeeze(0)
-        return y, logprobs.squeeze(0)
+        logprobs = logprobs.squeeze(0)
+        y = sampler(logprobs)
+        return y, logprobs

     def _prefill(model, cache, y):
         while y.size > prefill_step_size:
View File

@@ -28,6 +28,12 @@ class TestSampleUtils(unittest.TestCase):
         token = top_p_sampling(logits, 0.95, temperature).item()
         self.assertTrue(token in (1, 2, 3))

+        # Batch mode works
+        probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
+        logits = mx.log(probs)
+        tokens = top_p_sampling(logits, 0.5, temperature)
+        self.assertEqual(tokens.tolist(), [0, 1])
+
     def test_min_p_sampling(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
@@ -42,6 +48,12 @@ class TestSampleUtils(unittest.TestCase):
         token = min_p_sampling(logits, 0.05)
         self.assertTrue(token in (0, 3))

+        # Batch mode works
+        probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
+        logits = mx.log(probs)
+        tokens = min_p_sampling(logits, 0.7)
+        self.assertEqual(tokens.tolist(), [0, 1])
+
     def test_top_k_sampling(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)

View File

@@ -80,6 +80,29 @@ class TestServer(unittest.TestCase):
         self.assertIn("id", response_body)
         self.assertIn("choices", response_body)

+    def test_handle_chat_completions_with_content_fragments(self):
+        url = f"http://localhost:{self.port}/v1/chat/completions"
+        chat_post_data = {
+            "model": "chat_model",
+            "max_tokens": 10,
+            "temperature": 0.7,
+            "top_p": 0.85,
+            "repetition_penalty": 1.2,
+            "messages": [
+                {
+                    "role": "system",
+                    "content": [
+                        {"type": "text", "text": "You are a helpful assistant."}
+                    ],
+                },
+                {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
+            ],
+        }
+        response = requests.post(url, json=chat_post_data)
+        response_body = response.text
+        self.assertIn("id", response_body)
+        self.assertIn("choices", response_body)
+
     def test_handle_models(self):
         url = f"http://localhost:{self.port}/v1/models"
         response = requests.get(url)