Huge speed improvement in validation mode.

Goekdeniz-Guelmez 2025-02-21 22:08:49 +01:00
parent 2f20107d9b
commit 6086137131
2 changed files with 35 additions and 83 deletions

View File

@@ -44,11 +44,9 @@ class GRPODataset:
            self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
    def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
        """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
        return self._data[idx]
    def __len__(self) -> int:
        """Returns the number of examples in the dataset."""
        return len(self._data)

View File

@@ -12,7 +12,7 @@ import mlx.nn as nn
import numpy as np
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
from ..utils import generate_step
@dataclass
class GRPOTrainingArgs(TrainingArgs):
@@ -61,7 +61,6 @@ def r1_extract_xml_answer(text: str) -> str:
        print("r1_extract_xml_answer returned empty string")
        return ""
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    if not completions:
        return [0.0] * len(prompts)
@@ -113,90 +112,45 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    return scores
-def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, temperature, group_size):
+def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size):
    if len(prompts.shape) == 1:
        prompts = prompts[None, :]
    if prompts.shape[1] == 0:
        return None
-    start_time = time.perf_counter()
-    tokens_generated = 0
    model.eval()
    batch_size = prompts.shape[0] * group_size
    # Repeat each prompt group_size times
    expanded_prompts = mx.repeat(prompts, group_size, axis=0)
-    end_sequence = mx.array(tokenizer.encode("</answer>"))
-    end_len = len(end_sequence)
-    initial_length = prompts.shape[1]
-    # Initialize output tensor for all sequences
-    output = mx.zeros((batch_size, initial_length + max_tokens), dtype=mx.int32)
-    output = mx.concatenate([expanded_prompts, mx.zeros((batch_size, max_tokens), dtype=mx.int32)], axis=1)
-    current_lengths = mx.array([initial_length] * batch_size)
-    temp_factor = 1/temperature if temperature > 0 else float('inf')
+    results = []
+    tokens_generated = 0
+    start_time = time.perf_counter()
    try:
-        not_finished = mx.ones((batch_size,), dtype=mx.bool_)
-        for _ in range(max_tokens):
-            # Check if all sequences are finished
-            if not mx.sum(not_finished).item():
-                break
-            # Get model outputs for all sequences
-            max_len = mx.max(current_lengths).item()
-            batch_inputs = output[:, :max_len]
-            logits = model(batch_inputs)[:, -1]
-            # Apply mask to logits
-            logits = logits * mx.expand_dims(not_finished, -1)
-            # Sample next tokens
-            logits *= temp_factor
-            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
-            next_tokens = mx.random.categorical(logprobs)
-            # Update outputs for active sequences
-            for idx in range(batch_size):
-                if not_finished[idx].item():
-                    curr_len = current_lengths[idx].item()
-                    token_value = next_tokens[idx].item()
-                    # Create new arrays with updates
-                    output = mx.array(output.tolist())  # Make a copy
-                    output[idx, curr_len] = token_value
-                    current_lengths = mx.array([
-                        l + 1 if i == idx else l
-                        for i, l in enumerate(current_lengths.tolist())
-                    ])
-                    tokens_generated += 1
-                    # Check end conditions
-                    if token_value == tokenizer.eos_token_id:
-                        not_finished = mx.array([
-                            False if i == idx else nf
-                            for i, nf in enumerate(not_finished.tolist())
-                        ])
-                    elif curr_len >= end_len:
-                        last_tokens = output[idx, curr_len-end_len+1:curr_len+1]
-                        if mx.array_equal(last_tokens, end_sequence):
-                            not_finished = mx.array([
-                                False if i == idx else nf
-                                for i, nf in enumerate(not_finished.tolist())
-                            ])
-            if _ % 32 == 0:
-                mx.eval(output, current_lengths, not_finished)
-        end_time = time.perf_counter()
-        generation_time = end_time - start_time
-        tokens_per_second = tokens_generated / generation_time
-        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_per_second:.2f} tokens/s)")
-        # Return only the valid part of each sequence
-        results = [output[i, :current_lengths[i].item()] for i in range(batch_size)]
-        # Final evaluation of all results
-        mx.eval(results)
+        for idx in range(batch_size):
+            current_tokens = []
+            generator = generate_step(
+                expanded_prompts[idx],
+                model,
+                max_tokens=max_tokens,
+                sampler=lambda x: mx.argmax(x, axis=-1)
+            )
+            # Collect all tokens first
+            for tokens, _ in generator:
+                current_tokens.append(tokens)
+                tokens_generated += 1
+                if tokens == tokenizer.eos_token_id:
+                    break
+            # Convert to array after collection
+            results.append(mx.array(current_tokens))
+            mx.metal.clear_cache()
+        generation_time = time.perf_counter() - start_time
+        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
        return results
    except Exception as e:
@@ -214,11 +168,12 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
        seq_logits = logits[i, :seq_len]
        seq_targets = targets[i, :seq_len]
        log_probs = nn.log_softmax(seq_logits, axis=-1)
        token_log_probs = mx.take_along_axis(
            log_probs,
-            seq_targets.reshape(seq_len, 1),
-            axis=-1
+            seq_targets.reshape(seq_len, 1), axis=-1
        ).squeeze(-1)
        per_token_logps.append(token_log_probs)
        mx.eval(logits)
    return per_token_logps
@@ -254,7 +209,6 @@ def grpo_loss(
        prompt_tensor,
        max_tokens,
        tokenizer,
-        temperature,
        group_size
    )
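For context, the call site above now drops the temperature argument, since the reworked generate_grpo decodes each expanded prompt greedily through generate_step. Below is a minimal sketch of driving the new signature directly during validation; it assumes a model and tokenizer loaded elsewhere (e.g. via mlx_lm's load helper), and the prompt text, max_tokens value, and group_size are purely illustrative, not part of this diff:

import mlx.core as mx

# Hypothetical driver code: tokenize one prompt and batch it as (1, prompt_len).
# `model` and `tokenizer` are assumed to have been loaded elsewhere.
prompt_tensor = mx.array(tokenizer.encode("What is 2 + 2?"))[None, :]

# New signature: no temperature; generate_grpo repeats the prompt group_size
# times and decodes each copy one at a time with generate_step (greedy argmax).
completions = generate_grpo(
    model,
    prompt_tensor,
    max_tokens=512,
    tokenizer=tokenizer,
    group_size=4,
)

# generate_grpo returns one token-id array per expanded prompt.
for tokens in completions:
    print(tokenizer.decode(tokens.tolist()))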