mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-07-22 03:21:16 +08:00
Huge speed improvement in validation mode.
This commit is contained in:
parent 2f20107d9b
commit 6086137131
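The speed-up appears to come from how completions are generated: instead of stepping the whole padded prompt batch token by token in Python (rebuilding the output buffer and the not_finished mask on every step), the new generate_grpo decodes each expanded prompt with generate_step and a greedy argmax sampler. A minimal sketch of that new path, paraphrased from the generate_grpo hunk further down; generate_for_group is a hypothetical wrapper name, and the import path is assumed (the diff itself uses the in-package form "from ..utils import generate_step"):

import mlx.core as mx
from mlx_lm.utils import generate_step  # assumption; the diff imports it as `from ..utils import generate_step`

def generate_for_group(model, expanded_prompts, max_tokens, tokenizer):
    # Decode each prompt of the expanded batch independently, greedily.
    results = []
    for idx in range(expanded_prompts.shape[0]):
        current_tokens = []
        generator = generate_step(
            expanded_prompts[idx],
            model,
            max_tokens=max_tokens,
            sampler=lambda x: mx.argmax(x, axis=-1),  # greedy decoding
        )
        for token, _ in generator:
            current_tokens.append(token)
            if token == tokenizer.eos_token_id:
                break
        results.append(mx.array(current_tokens))
        mx.metal.clear_cache()  # release Metal buffers between sequences
    mx.eval(results)
    return results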
@@ -44,11 +44,9 @@ class GRPODataset:
        self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))

    def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
        """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
        return self._data[idx]

    def __len__(self) -> int:
        """Returns the number of examples in the dataset."""
        return len(self._data)
@@ -12,7 +12,7 @@ import mlx.nn as nn
import numpy as np

from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
from ..utils import generate_step

@dataclass
class GRPOTrainingArgs(TrainingArgs):
@@ -61,7 +61,6 @@ def r1_extract_xml_answer(text: str) -> str:
        print("r1_extract_xml_answer returned empty string")
        return ""

-
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    if not completions:
        return [0.0] * len(prompts)
@@ -113,90 +112,45 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
    return scores


-def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, temperature, group_size):
+def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size):
    if len(prompts.shape) == 1:
        prompts = prompts[None, :]
    if prompts.shape[1] == 0:
        return None

+    start_time = time.perf_counter()
+    tokens_generated = 0
    model.eval()
    batch_size = prompts.shape[0] * group_size

    # Repeat each prompt group_size times
    expanded_prompts = mx.repeat(prompts, group_size, axis=0)

-    end_sequence = mx.array(tokenizer.encode("</answer>"))
-    end_len = len(end_sequence)
-    initial_length = prompts.shape[1]
-
-    # Initialize output tensor for all sequences
-    output = mx.zeros((batch_size, initial_length + max_tokens), dtype=mx.int32)
+    output = mx.concatenate([expanded_prompts, mx.zeros((batch_size, max_tokens), dtype=mx.int32)], axis=1)
-    current_lengths = mx.array([initial_length] * batch_size)
-
-    temp_factor = 1/temperature if temperature > 0 else float('inf')
+    results = []
-    tokens_generated = 0
-    start_time = time.perf_counter()

    try:
-        not_finished = mx.ones((batch_size,), dtype=mx.bool_)
+        for idx in range(batch_size):
+            current_tokens = []
+            generator = generate_step(
+                expanded_prompts[idx],
+                model,
+                max_tokens=max_tokens,
+                sampler=lambda x: mx.argmax(x, axis=-1)
+            )

-        for _ in range(max_tokens):
-            # Check if all sequences are finished
-            if not mx.sum(not_finished).item():
+            # Collect all tokens first
+            for tokens, _ in generator:
+                current_tokens.append(tokens)
+                tokens_generated += 1
+                if tokens == tokenizer.eos_token_id:
+                    break

-            # Get model outputs for all sequences
-            max_len = mx.max(current_lengths).item()
-            batch_inputs = output[:, :max_len]
-            logits = model(batch_inputs)[:, -1]
+            # Convert to array after collection
+            results.append(mx.array(current_tokens))
+            mx.metal.clear_cache()

-            # Apply mask to logits
-            logits = logits * mx.expand_dims(not_finished, -1)

-            # Sample next tokens
-            logits *= temp_factor
-            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
-            next_tokens = mx.random.categorical(logprobs)

-            # Update outputs for active sequences
-            for idx in range(batch_size):
-                if not_finished[idx].item():
-                    curr_len = current_lengths[idx].item()
-                    token_value = next_tokens[idx].item()

-                    # Create new arrays with updates
-                    output = mx.array(output.tolist()) # Make a copy
-                    output[idx, curr_len] = token_value
-                    current_lengths = mx.array([
-                        l + 1 if i == idx else l
-                        for i, l in enumerate(current_lengths.tolist())
-                    ])
-                    tokens_generated += 1

-                    # Check end conditions
-                    if token_value == tokenizer.eos_token_id:
-                        not_finished = mx.array([
-                            False if i == idx else nf
-                            for i, nf in enumerate(not_finished.tolist())
-                        ])
-                    elif curr_len >= end_len:
-                        last_tokens = output[idx, curr_len-end_len+1:curr_len+1]
-                        if mx.array_equal(last_tokens, end_sequence):
-                            not_finished = mx.array([
-                                False if i == idx else nf
-                                for i, nf in enumerate(not_finished.tolist())
-                            ])

-            if _ % 32 == 0:
-                mx.eval(output, current_lengths, not_finished)

-        end_time = time.perf_counter()
-        generation_time = end_time - start_time
-        tokens_per_second = tokens_generated / generation_time
-        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_per_second:.2f} tokens/s)")

-        # Return only the valid part of each sequence
-        results = [output[i, :current_lengths[i].item()] for i in range(batch_size)]
+        # Final evaluation of all results
        mx.eval(results)
+        generation_time = time.perf_counter() - start_time
+        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
        return results

    except Exception as e:
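For contrast, the removed path above re-ran the model over the whole padded batch once per generated token and then updated each row through Python lists (output = mx.array(output.tolist())), which is presumably where most of the validation time went. A rough, self-contained sketch of the sampling step it performed on each iteration, reduced from the removed lines (old_sampling_step is just a label for this sketch, and temperature > 0 is assumed so the float('inf') branch is dropped):

import mlx.core as mx

def old_sampling_step(last_logits, not_finished, temperature):
    # last_logits: (batch, vocab) logits at the last position;
    # not_finished: (batch,) mask of sequences that are still generating.
    logits = last_logits * mx.expand_dims(not_finished, -1)  # zero out finished rows
    logits = logits * (1 / temperature)                       # temperature scaling
    logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
    return mx.random.categorical(logprobs)                    # one sampled token per row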
@@ -214,11 +168,12 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
        seq_logits = logits[i, :seq_len]
        seq_targets = targets[i, :seq_len]
        log_probs = nn.log_softmax(seq_logits, axis=-1)

        token_log_probs = mx.take_along_axis(
            log_probs,
-            seq_targets.reshape(seq_len, 1),
-            axis=-1
+            seq_targets.reshape(seq_len, 1), axis=-1
        ).squeeze(-1)

        per_token_logps.append(token_log_probs)
        mx.eval(logits)
    return per_token_logps
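The change in this hunk is only a reflow of the mx.take_along_axis call; the computation is unchanged: for each position it gathers the log-probability assigned to the token that actually followed. A small self-contained illustration with made-up shapes and values:

import mlx.core as mx
import mlx.nn as nn

# Toy example: 4 positions over a 5-token vocabulary.
seq_logits = mx.random.normal((4, 5))
seq_targets = mx.array([2, 0, 4, 1])             # the next tokens actually observed
log_probs = nn.log_softmax(seq_logits, axis=-1)
token_log_probs = mx.take_along_axis(
    log_probs, seq_targets.reshape(4, 1), axis=-1
).squeeze(-1)                                    # shape (4,): one log-prob per position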
@@ -254,7 +209,6 @@ def grpo_loss(
        prompt_tensor,
        max_tokens,
        tokenizer,
-        temperature,
        group_size
    )
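With temperature gone from generate_grpo's signature, the call inside grpo_loss now passes only prompt_tensor, max_tokens, tokenizer, and group_size (plus whatever leading arguments fall outside this hunk). Since the new sampler is a plain argmax, completions generated through this path are effectively greedy regardless of any configured sampling temperature; that is an inference from this diff, not something stated in the commit message.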