Huge speed improvement in validation mode.

Goekdeniz-Guelmez 2025-02-21 22:08:49 +01:00
parent 2f20107d9b
commit 6086137131
2 changed files with 35 additions and 83 deletions


@@ -30,7 +30,7 @@ class GRPODataset:
prompt_tokens = tokenizer.apply_chat_template(
[
{'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""},
{'role': 'user', 'content': prompt_str}
],
add_generation_prompt=True
)
@@ -44,11 +44,9 @@ class GRPODataset:
self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
"""Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
return self._data[idx]
def __len__(self) -> int:
"""Returns the number of examples in the dataset."""
return len(self._data)
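
For context, GRPODataset stores each example as a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple, with the prompt encoded through the chat template shown above. A minimal sketch of that encoding outside the class, assuming a Hugging Face style tokenizer; the helper name encode_example and the use of tokenizer.encode for the answer are illustrative assumptions, not part of the diff:

# Sketch: build the tuple that GRPODataset.__getitem__ returns.
R1_SYSTEM = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant "
    "solves it. The assistant first thinks about the reasoning process in the mind and then "
    "provides the user with the answer. The reasoning process and answer are enclosed within "
    "<think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>."
)

def encode_example(tokenizer, prompt_str: str, answer_str: str):
    prompt_tokens = tokenizer.apply_chat_template(
        [
            {"role": "system", "content": R1_SYSTEM},
            {"role": "user", "content": prompt_str},
        ],
        add_generation_prompt=True,  # leave the assistant turn open for generation
    )
    answer_tokens = tokenizer.encode(answer_str)  # assumption: answers are plain-encoded
    return prompt_tokens, answer_tokens, prompt_str, answer_str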


@@ -12,7 +12,7 @@ import mlx.nn as nn
import numpy as np
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
from ..utils import generate_step
@dataclass
class GRPOTrainingArgs(TrainingArgs):
@@ -61,7 +61,6 @@ def r1_extract_xml_answer(text: str) -> str:
print("r1_extract_xml_answer returned empty string")
return ""
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
if not completions:
return [0.0] * len(prompts)
@@ -111,92 +110,47 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
scores.append(max(0.0, count))
return scores
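
All of the r1_* reward functions share the signature (prompts, completions, answer, **kwargs) -> list[float], so extra rewards can be dropped in alongside them. A minimal sketch of an exact-match reward in that style; the function name is hypothetical, and it assumes r1_extract_xml_answer returns the text between <answer> tags (empty string on failure), as shown above:

def r1_exact_match_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    # Guard against empty generations, mirroring r1_int_reward_func above.
    if not completions:
        return [0.0] * len(prompts)
    extracted = [r1_extract_xml_answer(c) for c in completions]
    # 1.0 when the extracted <answer> text matches the reference exactly, else 0.0.
    return [1.0 if e == str(a).strip() else 0.0 for e, a in zip(extracted, answer)]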
def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, temperature, group_size):
def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size):
if len(prompts.shape) == 1:
prompts = prompts[None, :]
if prompts.shape[1] == 0:
return None
start_time = time.perf_counter()
tokens_generated = 0
model.eval()
batch_size = prompts.shape[0] * group_size
# Repeat each prompt group_size times
expanded_prompts = mx.repeat(prompts, group_size, axis=0)
end_sequence = mx.array(tokenizer.encode("</answer>"))
end_len = len(end_sequence)
initial_length = prompts.shape[1]
# Initialize output tensor for all sequences
output = mx.zeros((batch_size, initial_length + max_tokens), dtype=mx.int32)
output = mx.concatenate([expanded_prompts, mx.zeros((batch_size, max_tokens), dtype=mx.int32)], axis=1)
current_lengths = mx.array([initial_length] * batch_size)
temp_factor = 1/temperature if temperature > 0 else float('inf')
results = []
tokens_generated = 0
start_time = time.perf_counter()
try:
not_finished = mx.ones((batch_size,), dtype=mx.bool_)
for _ in range(max_tokens):
# Check if all sequences are finished
if not mx.sum(not_finished).item():
break
# Get model outputs for all sequences
max_len = mx.max(current_lengths).item()
batch_inputs = output[:, :max_len]
logits = model(batch_inputs)[:, -1]
for idx in range(batch_size):
current_tokens = []
generator = generate_step(
expanded_prompts[idx],
model,
max_tokens=max_tokens,
sampler=lambda x: mx.argmax(x, axis=-1)
)
# Apply mask to logits
logits = logits * mx.expand_dims(not_finished, -1)
# Collect all tokens first
for tokens, _ in generator:
current_tokens.append(tokens)
tokens_generated += 1
if tokens == tokenizer.eos_token_id:
break
# Sample next tokens
logits *= temp_factor
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
next_tokens = mx.random.categorical(logprobs)
# Update outputs for active sequences
for idx in range(batch_size):
if not_finished[idx].item():
curr_len = current_lengths[idx].item()
token_value = next_tokens[idx].item()
# Create new arrays with updates
output = mx.array(output.tolist()) # Make a copy
output[idx, curr_len] = token_value
current_lengths = mx.array([
l + 1 if i == idx else l
for i, l in enumerate(current_lengths.tolist())
])
tokens_generated += 1
# Check end conditions
if token_value == tokenizer.eos_token_id:
not_finished = mx.array([
False if i == idx else nf
for i, nf in enumerate(not_finished.tolist())
])
elif curr_len >= end_len:
last_tokens = output[idx, curr_len-end_len+1:curr_len+1]
if mx.array_equal(last_tokens, end_sequence):
not_finished = mx.array([
False if i == idx else nf
for i, nf in enumerate(not_finished.tolist())
])
if _ % 32 == 0:
mx.eval(output, current_lengths, not_finished)
# Convert to array after collection
results.append(mx.array(current_tokens))
mx.metal.clear_cache()
end_time = time.perf_counter()
generation_time = end_time - start_time
tokens_per_second = tokens_generated / generation_time
print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_per_second:.2f} tokens/s)")
# Return only the valid part of each sequence
results = [output[i, :current_lengths[i].item()] for i in range(batch_size)]
# Final evaluation of all results
mx.eval(results)
generation_time = time.perf_counter() - start_time
print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
return results
except Exception as e:
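
As I read the interleaved old and new lines above, the speed-up comes from dropping the hand-rolled batched sampling loop (manual not_finished masking, categorical sampling, per-token output copies) in favour of one generate_step call per expanded prompt with a greedy argmax sampler. A sketch of that new path under that reading; the function name is mine, and it assumes generate_step yields (token, logprobs) pairs, as the added lines use it:

def generate_grpo_sketch(model, prompts, max_tokens, tokenizer, group_size):
    # Repeat each prompt group_size times and decode each copy independently.
    if len(prompts.shape) == 1:
        prompts = prompts[None, :]
    model.eval()
    batch_size = prompts.shape[0] * group_size
    expanded_prompts = mx.repeat(prompts, group_size, axis=0)
    results = []
    for idx in range(batch_size):
        current_tokens = []
        # generate_step streams one token at a time; an argmax sampler makes it greedy.
        for token, _ in generate_step(
            expanded_prompts[idx],
            model,
            max_tokens=max_tokens,
            sampler=lambda x: mx.argmax(x, axis=-1),
        ):
            current_tokens.append(token)
            if token == tokenizer.eos_token_id:
                break
        results.append(mx.array(current_tokens))
    mx.eval(results)
    return results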
@@ -214,13 +168,14 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
seq_logits = logits[i, :seq_len]
seq_targets = targets[i, :seq_len]
log_probs = nn.log_softmax(seq_logits, axis=-1)
token_log_probs = mx.take_along_axis(
log_probs,
seq_targets.reshape(seq_len, 1),
axis=-1
seq_targets.reshape(seq_len, 1), axis=-1
).squeeze(-1)
per_token_logps.append(token_log_probs)
mx.eval(logits)
return per_token_logps
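
Aside from moving mx.eval and reflowing the take_along_axis call, the per-token log-probability gather itself is unchanged: each target token's log-probability is picked out of the log-softmax row for its position. A self-contained toy sketch of that one step (shapes and values are made up for illustration):

import mlx.core as mx
import mlx.nn as nn

seq_len, vocab_size = 5, 11
seq_logits = mx.random.normal((seq_len, vocab_size))   # logits for one sequence
seq_targets = mx.array([3, 7, 1, 0, 9])                # next-token ids to score

log_probs = nn.log_softmax(seq_logits, axis=-1)        # (seq_len, vocab_size)
token_log_probs = mx.take_along_axis(
    log_probs, seq_targets.reshape(seq_len, 1), axis=-1
).squeeze(-1)                                          # (seq_len,)
mx.eval(token_log_probs)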
@@ -253,8 +208,7 @@ def grpo_loss(
model,
prompt_tensor,
max_tokens,
tokenizer,
temperature,
tokenizer,
group_size
)
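
The caller-side change in grpo_loss just drops the temperature argument to match the new generate_grpo signature; untangling the removed and added lines, the updated call reads roughly as below (the name of the variable receiving the result is an assumption, since the diff does not show it):

completions = generate_grpo(
    model,
    prompt_tensor,
    max_tokens,
    tokenizer,
    group_size
)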