mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-07-19 17:41:11 +08:00
Huge speed improvement in validation mode.
This commit is contained in:
parent
2f20107d9b
commit
6086137131
@ -30,7 +30,7 @@ class GRPODataset:
|
||||
prompt_tokens = tokenizer.apply_chat_template(
|
||||
[
|
||||
{'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""},
|
||||
{'role': 'user', 'content': prompt_str}
|
||||
{'role': 'user', 'content': prompt_str}
|
||||
],
|
||||
add_generation_prompt=True
|
||||
)
|
||||
@ -44,11 +44,9 @@ class GRPODataset:
|
||||
self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
|
||||
|
||||
def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
|
||||
"""Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
|
||||
return self._data[idx]
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Returns the number of examples in the dataset."""
|
||||
return len(self._data)
|
||||
|
||||
|
||||
|
@ -12,7 +12,7 @@ import mlx.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
|
||||
|
||||
from ..utils import generate_step
|
||||
|
||||
@dataclass
|
||||
class GRPOTrainingArgs(TrainingArgs):
|
||||
@ -61,7 +61,6 @@ def r1_extract_xml_answer(text: str) -> str:
|
||||
print("r1_extract_xml_answer returned empty string")
|
||||
return ""
|
||||
|
||||
|
||||
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
if not completions:
|
||||
return [0.0] * len(prompts)
|
||||
@ -111,92 +110,47 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
|
||||
count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
|
||||
scores.append(max(0.0, count))
|
||||
return scores
|
||||
|
||||
|
||||
|
||||
def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, temperature, group_size):
|
||||
def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size):
|
||||
if len(prompts.shape) == 1:
|
||||
prompts = prompts[None, :]
|
||||
if prompts.shape[1] == 0:
|
||||
return None
|
||||
|
||||
start_time = time.perf_counter()
|
||||
tokens_generated = 0
|
||||
model.eval()
|
||||
batch_size = prompts.shape[0] * group_size
|
||||
|
||||
# Repeat each prompt group_size times
|
||||
expanded_prompts = mx.repeat(prompts, group_size, axis=0)
|
||||
|
||||
end_sequence = mx.array(tokenizer.encode("</answer>"))
|
||||
end_len = len(end_sequence)
|
||||
initial_length = prompts.shape[1]
|
||||
|
||||
# Initialize output tensor for all sequences
|
||||
output = mx.zeros((batch_size, initial_length + max_tokens), dtype=mx.int32)
|
||||
output = mx.concatenate([expanded_prompts, mx.zeros((batch_size, max_tokens), dtype=mx.int32)], axis=1)
|
||||
current_lengths = mx.array([initial_length] * batch_size)
|
||||
|
||||
temp_factor = 1/temperature if temperature > 0 else float('inf')
|
||||
|
||||
results = []
|
||||
tokens_generated = 0
|
||||
start_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
not_finished = mx.ones((batch_size,), dtype=mx.bool_)
|
||||
|
||||
for _ in range(max_tokens):
|
||||
# Check if all sequences are finished
|
||||
if not mx.sum(not_finished).item():
|
||||
break
|
||||
|
||||
# Get model outputs for all sequences
|
||||
max_len = mx.max(current_lengths).item()
|
||||
batch_inputs = output[:, :max_len]
|
||||
logits = model(batch_inputs)[:, -1]
|
||||
for idx in range(batch_size):
|
||||
current_tokens = []
|
||||
generator = generate_step(
|
||||
expanded_prompts[idx],
|
||||
model,
|
||||
max_tokens=max_tokens,
|
||||
sampler=lambda x: mx.argmax(x, axis=-1)
|
||||
)
|
||||
|
||||
# Apply mask to logits
|
||||
logits = logits * mx.expand_dims(not_finished, -1)
|
||||
# Collect all tokens first
|
||||
for tokens, _ in generator:
|
||||
current_tokens.append(tokens)
|
||||
tokens_generated += 1
|
||||
if tokens == tokenizer.eos_token_id:
|
||||
break
|
||||
|
||||
# Sample next tokens
|
||||
logits *= temp_factor
|
||||
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
|
||||
next_tokens = mx.random.categorical(logprobs)
|
||||
|
||||
# Update outputs for active sequences
|
||||
for idx in range(batch_size):
|
||||
if not_finished[idx].item():
|
||||
curr_len = current_lengths[idx].item()
|
||||
token_value = next_tokens[idx].item()
|
||||
|
||||
# Create new arrays with updates
|
||||
output = mx.array(output.tolist()) # Make a copy
|
||||
output[idx, curr_len] = token_value
|
||||
current_lengths = mx.array([
|
||||
l + 1 if i == idx else l
|
||||
for i, l in enumerate(current_lengths.tolist())
|
||||
])
|
||||
tokens_generated += 1
|
||||
|
||||
# Check end conditions
|
||||
if token_value == tokenizer.eos_token_id:
|
||||
not_finished = mx.array([
|
||||
False if i == idx else nf
|
||||
for i, nf in enumerate(not_finished.tolist())
|
||||
])
|
||||
elif curr_len >= end_len:
|
||||
last_tokens = output[idx, curr_len-end_len+1:curr_len+1]
|
||||
if mx.array_equal(last_tokens, end_sequence):
|
||||
not_finished = mx.array([
|
||||
False if i == idx else nf
|
||||
for i, nf in enumerate(not_finished.tolist())
|
||||
])
|
||||
|
||||
if _ % 32 == 0:
|
||||
mx.eval(output, current_lengths, not_finished)
|
||||
# Convert to array after collection
|
||||
results.append(mx.array(current_tokens))
|
||||
mx.metal.clear_cache()
|
||||
|
||||
end_time = time.perf_counter()
|
||||
generation_time = end_time - start_time
|
||||
tokens_per_second = tokens_generated / generation_time
|
||||
print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_per_second:.2f} tokens/s)")
|
||||
|
||||
# Return only the valid part of each sequence
|
||||
results = [output[i, :current_lengths[i].item()] for i in range(batch_size)]
|
||||
# Final evaluation of all results
|
||||
mx.eval(results)
|
||||
generation_time = time.perf_counter() - start_time
|
||||
print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
@ -214,13 +168,14 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
|
||||
seq_logits = logits[i, :seq_len]
|
||||
seq_targets = targets[i, :seq_len]
|
||||
log_probs = nn.log_softmax(seq_logits, axis=-1)
|
||||
|
||||
token_log_probs = mx.take_along_axis(
|
||||
log_probs,
|
||||
seq_targets.reshape(seq_len, 1),
|
||||
axis=-1
|
||||
seq_targets.reshape(seq_len, 1), axis=-1
|
||||
).squeeze(-1)
|
||||
|
||||
per_token_logps.append(token_log_probs)
|
||||
mx.eval(logits)
|
||||
mx.eval(logits)
|
||||
return per_token_logps
|
||||
|
||||
|
||||
@ -253,8 +208,7 @@ def grpo_loss(
|
||||
model,
|
||||
prompt_tensor,
|
||||
max_tokens,
|
||||
tokenizer,
|
||||
temperature,
|
||||
tokenizer,
|
||||
group_size
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user