mirror of https://github.com/ml-explore/mlx-examples.git
Huge speed improvement in validation mode.
parent 2f20107d9b
commit 6086137131
@@ -44,11 +44,9 @@ class GRPODataset:
             self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
 
     def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
-        """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
        return self._data[idx]
 
     def __len__(self) -> int:
-        """Returns the number of examples in the dataset."""
         return len(self._data)
 
 
@@ -12,7 +12,7 @@ import mlx.nn as nn
 import numpy as np
 
 from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
-
+from ..utils import generate_step
 
 @dataclass
 class GRPOTrainingArgs(TrainingArgs):
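Note: the rewrite relies on generate_step from mlx_lm.utils, a generator that yields one (token, logprobs) pair per decoding step. Below is a minimal sketch of that interface as this diff uses it; the model path and surrounding variable names are placeholders, not part of the change:

import mlx.core as mx
from mlx_lm import load
from mlx_lm.utils import generate_step

# Placeholder model path; any mlx-lm compatible checkpoint works here.
model, tokenizer = load("path/to/mlx-model")
prompt = mx.array(tokenizer.encode("What is 2 + 2?"))

tokens = []
# Same greedy sampler the new generate_grpo passes in.
for token, _ in generate_step(prompt, model, max_tokens=64,
                              sampler=lambda x: mx.argmax(x, axis=-1)):
    tokens.append(token)
    if token == tokenizer.eos_token_id:
        break

print(tokenizer.decode(mx.array(tokens).tolist()))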
@@ -61,7 +61,6 @@ def r1_extract_xml_answer(text: str) -> str:
         print("r1_extract_xml_answer returned empty string")
         return ""
-
 
 def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
@@ -113,90 +112,45 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     return scores
 
 
-def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, temperature, group_size):
+def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size):
     if len(prompts.shape) == 1:
         prompts = prompts[None, :]
     if prompts.shape[1] == 0:
         return None
 
-    start_time = time.perf_counter()
-    tokens_generated = 0
+    model.eval()
     batch_size = prompts.shape[0] * group_size
 
-    # Repeat each prompt group_size times
     expanded_prompts = mx.repeat(prompts, group_size, axis=0)
 
-    end_sequence = mx.array(tokenizer.encode("</answer>"))
-    end_len = len(end_sequence)
-    initial_length = prompts.shape[1]
-
-    # Initialize output tensor for all sequences
-    output = mx.zeros((batch_size, initial_length + max_tokens), dtype=mx.int32)
-    output = mx.concatenate([expanded_prompts, mx.zeros((batch_size, max_tokens), dtype=mx.int32)], axis=1)
-    current_lengths = mx.array([initial_length] * batch_size)
-
-    temp_factor = 1/temperature if temperature > 0 else float('inf')
+    results = []
+    tokens_generated = 0
+    start_time = time.perf_counter()
 
     try:
-        not_finished = mx.ones((batch_size,), dtype=mx.bool_)
-
-        for _ in range(max_tokens):
-            # Check if all sequences are finished
-            if not mx.sum(not_finished).item():
-                break
-
-            # Get model outputs for all sequences
-            max_len = mx.max(current_lengths).item()
-            batch_inputs = output[:, :max_len]
-            logits = model(batch_inputs)[:, -1]
-
-            # Apply mask to logits
-            logits = logits * mx.expand_dims(not_finished, -1)
-
-            # Sample next tokens
-            logits *= temp_factor
-            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
-            next_tokens = mx.random.categorical(logprobs)
-
-            # Update outputs for active sequences
-            for idx in range(batch_size):
-                if not_finished[idx].item():
-                    curr_len = current_lengths[idx].item()
-                    token_value = next_tokens[idx].item()
-
-                    # Create new arrays with updates
-                    output = mx.array(output.tolist())  # Make a copy
-                    output[idx, curr_len] = token_value
-                    current_lengths = mx.array([
-                        l + 1 if i == idx else l
-                        for i, l in enumerate(current_lengths.tolist())
-                    ])
-                    tokens_generated += 1
-
-                    # Check end conditions
-                    if token_value == tokenizer.eos_token_id:
-                        not_finished = mx.array([
-                            False if i == idx else nf
-                            for i, nf in enumerate(not_finished.tolist())
-                        ])
-                    elif curr_len >= end_len:
-                        last_tokens = output[idx, curr_len-end_len+1:curr_len+1]
-                        if mx.array_equal(last_tokens, end_sequence):
-                            not_finished = mx.array([
-                                False if i == idx else nf
-                                for i, nf in enumerate(not_finished.tolist())
-                            ])
-
-            if _ % 32 == 0:
-                mx.eval(output, current_lengths, not_finished)
-
-        end_time = time.perf_counter()
-        generation_time = end_time - start_time
-        tokens_per_second = tokens_generated / generation_time
-        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_per_second:.2f} tokens/s)")
-
-        # Return only the valid part of each sequence
-        results = [output[i, :current_lengths[i].item()] for i in range(batch_size)]
+        for idx in range(batch_size):
+            current_tokens = []
+            generator = generate_step(
+                expanded_prompts[idx],
+                model,
+                max_tokens=max_tokens,
+                sampler=lambda x: mx.argmax(x, axis=-1)
+            )
+
+            # Collect all tokens first
+            for tokens, _ in generator:
+                current_tokens.append(tokens)
+                tokens_generated += 1
+                if tokens == tokenizer.eos_token_id:
+                    break
+
+            # Convert to array after collection
+            results.append(mx.array(current_tokens))
+            mx.metal.clear_cache()
+
+        # Final evaluation of all results
+        mx.eval(results)
+        generation_time = time.perf_counter() - start_time
+        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
+
         return results
 
     except Exception as e:
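The rewritten generate_grpo decodes each of the prompts.shape[0] * group_size sequences one at a time with generate_step, clearing the Metal cache between sequences, and returns a plain list of per-sequence token arrays instead of a padded batch. A rough usage sketch; the model path and variable names below are illustrative, and generate_grpo is assumed to be the function defined in this file:

import mlx.core as mx
from mlx_lm import load

# Placeholder model path (assumption, not from this diff).
model, tokenizer = load("path/to/mlx-model")

prompt_tokens = mx.array(tokenizer.encode("Prove that 1 + 1 = 2."))[None, :]
completions = generate_grpo(model, prompt_tokens, 256, tokenizer, group_size=4)

# One completion per prompt x group member, each an mx.array of token ids.
for i, completion in enumerate(completions):
    print(f"sample {i}: {tokenizer.decode(completion.tolist())[:80]}")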
@@ -214,11 +168,12 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
         seq_logits = logits[i, :seq_len]
         seq_targets = targets[i, :seq_len]
         log_probs = nn.log_softmax(seq_logits, axis=-1)
+
         token_log_probs = mx.take_along_axis(
             log_probs,
-            seq_targets.reshape(seq_len, 1),
-            axis=-1
+            seq_targets.reshape(seq_len, 1), axis=-1
         ).squeeze(-1)
+
         per_token_logps.append(token_log_probs)
     mx.eval(logits)
     return per_token_logps
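The change above only reflows the mx.take_along_axis call; the gather it performs is unchanged. For reference, a small self-contained sketch of that gather on toy shapes (the numbers are illustrative, not taken from the diff):

import mlx.core as mx
import mlx.nn as nn

seq_len, vocab_size = 4, 6
seq_logits = mx.random.normal((seq_len, vocab_size))   # logits for one sequence
seq_targets = mx.array([2, 5, 0, 3])                    # next-token ids

log_probs = nn.log_softmax(seq_logits, axis=-1)         # [seq_len, vocab_size]
token_log_probs = mx.take_along_axis(
    log_probs, seq_targets.reshape(seq_len, 1), axis=-1
).squeeze(-1)                                           # [seq_len]

print(token_log_probs.shape)  # (4,)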
@@ -254,7 +209,6 @@ def grpo_loss(
         prompt_tensor,
         max_tokens,
         tokenizer,
-        temperature,
         group_size
     )
 
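Removing temperature from this call matches the new generate_grpo signature: token selection now happens inside generate_step with a greedy argmax sampler, so a temperature argument would have no effect. The updated call in grpo_loss therefore looks roughly like this; the assignment target and the leading model argument are assumptions, since the hunk only shows the trailing arguments:

# Sketch of the updated call site in grpo_loss (names outside the shown
# argument list are assumed, not taken from this hunk).
all_completions = generate_grpo(
    model,
    prompt_tensor,
    max_tokens,
    tokenizer,
    group_size
)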