Huge speed improvement in validation mode.

Author: Goekdeniz-Guelmez
Date: 2025-02-21 22:08:49 +01:00
Parent: 2f20107d9b
Commit: 6086137131
2 changed files with 35 additions and 83 deletions
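
For orientation, the hunks below swap a hand-rolled batched decoding loop for per-sequence calls to mlx_lm's generate_step. A minimal standalone sketch of that pattern, assuming an mlx_lm model and tokenizer (the model path and prompt are illustrative only, not part of this commit):

# Rough sketch: greedy, per-sequence decoding with generate_step, clearing the
# Metal cache afterwards, mirroring the new generate_grpo below.
import mlx.core as mx
from mlx_lm import load
from mlx_lm.utils import generate_step

model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")  # illustrative model
prompt = mx.array(tokenizer.encode("What is 2 + 2?"))

completion = []
for token, _ in generate_step(prompt, model, max_tokens=64,
                              sampler=lambda x: mx.argmax(x, axis=-1)):
    tok = int(token)
    completion.append(tok)
    if tok == tokenizer.eos_token_id:
        break
mx.metal.clear_cache()
print(tokenizer.decode(completion))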

View File

@@ -30,7 +30,7 @@ class GRPODataset:
             prompt_tokens = tokenizer.apply_chat_template(
                 [
                     {'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""},
                     {'role': 'user', 'content': prompt_str}
                 ],
                 add_generation_prompt=True
             )
@@ -44,11 +44,9 @@ class GRPODataset:
             self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
 
     def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
-        """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
         return self._data[idx]
 
     def __len__(self) -> int:
-        """Returns the number of examples in the dataset."""
         return len(self._data)
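
Since this hunk removes the docstrings, a brief hypothetical usage note: each dataset item still unpacks into tokenized prompt/answer plus the raw strings (the `dataset` variable below is assumed).

# Hypothetical usage of an already-constructed GRPODataset instance.
prompt_tokens, answer_tokens, prompt_str, answer_str = dataset[0]
print(len(dataset), len(prompt_tokens), len(answer_tokens))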

View File

@@ -12,7 +12,7 @@ import mlx.nn as nn
 import numpy as np
 from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
+from ..utils import generate_step
 
 @dataclass
 class GRPOTrainingArgs(TrainingArgs):
@@ -61,7 +61,6 @@ def r1_extract_xml_answer(text: str) -> str:
         print("r1_extract_xml_answer returned empty string")
         return ""
 
 def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
@@ -113,90 +112,45 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     return scores
 
-def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, temperature, group_size):
+def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size):
     if len(prompts.shape) == 1:
         prompts = prompts[None, :]
     if prompts.shape[1] == 0:
         return None
-    start_time = time.perf_counter()
-    tokens_generated = 0
+    model.eval()
     batch_size = prompts.shape[0] * group_size
-    # Repeat each prompt group_size times
     expanded_prompts = mx.repeat(prompts, group_size, axis=0)
-    end_sequence = mx.array(tokenizer.encode("</answer>"))
-    end_len = len(end_sequence)
-    initial_length = prompts.shape[1]
-    # Initialize output tensor for all sequences
-    output = mx.zeros((batch_size, initial_length + max_tokens), dtype=mx.int32)
-    output = mx.concatenate([expanded_prompts, mx.zeros((batch_size, max_tokens), dtype=mx.int32)], axis=1)
-    current_lengths = mx.array([initial_length] * batch_size)
-    temp_factor = 1/temperature if temperature > 0 else float('inf')
+    results = []
+    tokens_generated = 0
+    start_time = time.perf_counter()
     try:
-        not_finished = mx.ones((batch_size,), dtype=mx.bool_)
-        for _ in range(max_tokens):
-            # Check if all sequences are finished
-            if not mx.sum(not_finished).item():
-                break
-            # Get model outputs for all sequences
-            max_len = mx.max(current_lengths).item()
-            batch_inputs = output[:, :max_len]
-            logits = model(batch_inputs)[:, -1]
-            # Apply mask to logits
-            logits = logits * mx.expand_dims(not_finished, -1)
-            # Sample next tokens
-            logits *= temp_factor
-            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
-            next_tokens = mx.random.categorical(logprobs)
-            # Update outputs for active sequences
-            for idx in range(batch_size):
-                if not_finished[idx].item():
-                    curr_len = current_lengths[idx].item()
-                    token_value = next_tokens[idx].item()
-                    # Create new arrays with updates
-                    output = mx.array(output.tolist())  # Make a copy
-                    output[idx, curr_len] = token_value
-                    current_lengths = mx.array([
-                        l + 1 if i == idx else l
-                        for i, l in enumerate(current_lengths.tolist())
-                    ])
-                    tokens_generated += 1
-                    # Check end conditions
-                    if token_value == tokenizer.eos_token_id:
-                        not_finished = mx.array([
-                            False if i == idx else nf
-                            for i, nf in enumerate(not_finished.tolist())
-                        ])
-                    elif curr_len >= end_len:
-                        last_tokens = output[idx, curr_len-end_len+1:curr_len+1]
-                        if mx.array_equal(last_tokens, end_sequence):
-                            not_finished = mx.array([
-                                False if i == idx else nf
-                                for i, nf in enumerate(not_finished.tolist())
-                            ])
-            if _ % 32 == 0:
-                mx.eval(output, current_lengths, not_finished)
-        end_time = time.perf_counter()
-        generation_time = end_time - start_time
-        tokens_per_second = tokens_generated / generation_time
-        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_per_second:.2f} tokens/s)")
-        # Return only the valid part of each sequence
-        results = [output[i, :current_lengths[i].item()] for i in range(batch_size)]
+        for idx in range(batch_size):
+            current_tokens = []
+            generator = generate_step(
+                expanded_prompts[idx],
+                model,
+                max_tokens=max_tokens,
+                sampler=lambda x: mx.argmax(x, axis=-1)
+            )
+            # Collect all tokens first
+            for tokens, _ in generator:
+                current_tokens.append(tokens)
+                tokens_generated += 1
+                if tokens == tokenizer.eos_token_id:
+                    break
+            # Convert to array after collection
+            results.append(mx.array(current_tokens))
+            mx.metal.clear_cache()
+        # Final evaluation of all results
+        mx.eval(results)
+        generation_time = time.perf_counter() - start_time
+        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
         return results
     except Exception as e:
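
A hedged usage sketch for the refactored generate_grpo above; the variable names, prompt, and group size are assumptions for illustration, not code from this commit:

# Sketch: produce group_size greedy completions for one tokenized prompt.
# `model`, `tokenizer`, and `prompt_str` are assumed to exist (e.g. via mlx_lm.load).
prompt_ids = mx.array(tokenizer.encode(prompt_str))
completions = generate_grpo(model, prompt_ids, max_tokens=512, tokenizer=tokenizer, group_size=4)
if completions is not None:
    for ids in completions:
        print(tokenizer.decode(ids.tolist()))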
@@ -214,13 +168,14 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
         seq_logits = logits[i, :seq_len]
         seq_targets = targets[i, :seq_len]
         log_probs = nn.log_softmax(seq_logits, axis=-1)
         token_log_probs = mx.take_along_axis(
             log_probs,
-            seq_targets.reshape(seq_len, 1),
-            axis=-1
+            seq_targets.reshape(seq_len, 1), axis=-1
         ).squeeze(-1)
         per_token_logps.append(token_log_probs)
     mx.eval(logits)
     return per_token_logps
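
For clarity, a tiny self-contained illustration of the gather performed in get_per_token_logps; the shapes and values below are made up:

# Toy example: pick each target token's log-probability from a
# [seq_len, vocab_size] matrix, as done per sequence above.
import mlx.core as mx
import mlx.nn as nn

seq_len, vocab_size = 4, 10
seq_logits = mx.random.normal((seq_len, vocab_size))
seq_targets = mx.array([1, 7, 3, 9])  # next-token ids

log_probs = nn.log_softmax(seq_logits, axis=-1)
token_log_probs = mx.take_along_axis(
    log_probs, seq_targets.reshape(seq_len, 1), axis=-1
).squeeze(-1)
print(token_log_probs.shape)  # one log-probability per position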
@@ -254,7 +209,6 @@ def grpo_loss(
         prompt_tensor,
         max_tokens,
         tokenizer,
-        temperature,
         group_size
     )