Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-07-23 12:41:17 +08:00)

Commit 6086137131 (parent 2f20107d9b): Huge speed improvement in validation mode.
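The diff below swaps the hand-rolled batched sampling loop in generate_grpo for per-sequence decoding through mlx_lm's generate_step with a greedy (argmax) sampler, drops the now-unused temperature argument, and imports generate_step from ..utils. A rough usage sketch of the rewritten function follows; the model id and the way generate_grpo is brought into scope are assumptions for illustration, not part of the commit:

    import mlx.core as mx
    from mlx_lm import load

    # generate_grpo is the function patched in this commit; its import path
    # depends on where the GRPO trainer lives in the PR (assumed here).
    model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")  # assumed model id
    prompt_tokens = tokenizer.apply_chat_template(
        [{"role": "user", "content": "What is 2 + 2?"}],
        add_generation_prompt=True,
    )
    prompts = mx.array([prompt_tokens])  # shape (1, prompt_len)
    completions = generate_grpo(model, prompts, max_tokens=256, tokenizer=tokenizer, group_size=4)
    # completions: a list of group_size token-id arrays per prompt, decoded greedily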
@@ -30,7 +30,7 @@ class GRPODataset:
             prompt_tokens = tokenizer.apply_chat_template(
                 [
                     {'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""},
                     {'role': 'user', 'content': prompt_str}
                 ],
                 add_generation_prompt=True
             )
@@ -44,11 +44,9 @@ class GRPODataset:
             self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
 
     def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
-        """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
         return self._data[idx]
 
     def __len__(self) -> int:
-        """Returns the number of examples in the dataset."""
         return len(self._data)
 
 
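The docstrings removed above are cosmetic; __getitem__ and __len__ behave the same. For context, each dataset entry is built with the R1-style system prompt shown in the first hunk. A minimal sketch of that tokenization step outside the class (the model id is an assumption for illustration, and the system prompt is abbreviated):

    from mlx_lm import load

    _, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")  # assumed model id
    prompt_str = "What is 7 * 6?"
    prompt_tokens = tokenizer.apply_chat_template(
        [
            {"role": "system", "content": "A conversation between User and Assistant. ..."},  # full system prompt abbreviated
            {"role": "user", "content": prompt_str},
        ],
        add_generation_prompt=True,
    )
    print(len(prompt_tokens))  # token ids stored alongside the raw strings in GRPODataset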
@@ -12,7 +12,7 @@ import mlx.nn as nn
 import numpy as np
 
 from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
-
+from ..utils import generate_step
 
 @dataclass
 class GRPOTrainingArgs(TrainingArgs):
@@ -61,7 +61,6 @@ def r1_extract_xml_answer(text: str) -> str:
         print("r1_extract_xml_answer returned empty string")
         return ""
 
-
 def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
@@ -111,92 +110,47 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
         count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
         scores.append(max(0.0, count))
     return scores
 
 
-def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, temperature, group_size):
+def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size):
     if len(prompts.shape) == 1:
         prompts = prompts[None, :]
     if prompts.shape[1] == 0:
         return None
 
-    start_time = time.perf_counter()
-    tokens_generated = 0
+    model.eval()
     batch_size = prompts.shape[0] * group_size
 
-    # Repeat each prompt group_size times
     expanded_prompts = mx.repeat(prompts, group_size, axis=0)
 
-    end_sequence = mx.array(tokenizer.encode("</answer>"))
-    end_len = len(end_sequence)
-    initial_length = prompts.shape[1]
-
-    # Initialize output tensor for all sequences
-    output = mx.zeros((batch_size, initial_length + max_tokens), dtype=mx.int32)
-    output = mx.concatenate([expanded_prompts, mx.zeros((batch_size, max_tokens), dtype=mx.int32)], axis=1)
-    current_lengths = mx.array([initial_length] * batch_size)
-
-    temp_factor = 1/temperature if temperature > 0 else float('inf')
-
+    results = []
+    tokens_generated = 0
+    start_time = time.perf_counter()
+
     try:
-        not_finished = mx.ones((batch_size,), dtype=mx.bool_)
-
-        for _ in range(max_tokens):
-            # Check if all sequences are finished
-            if not mx.sum(not_finished).item():
-                break
-
-            # Get model outputs for all sequences
-            max_len = mx.max(current_lengths).item()
-            batch_inputs = output[:, :max_len]
-            logits = model(batch_inputs)[:, -1]
-
-            # Apply mask to logits
-            logits = logits * mx.expand_dims(not_finished, -1)
-
-            # Sample next tokens
-            logits *= temp_factor
-            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
-            next_tokens = mx.random.categorical(logprobs)
-
-            # Update outputs for active sequences
-            for idx in range(batch_size):
-                if not_finished[idx].item():
-                    curr_len = current_lengths[idx].item()
-                    token_value = next_tokens[idx].item()
-
-                    # Create new arrays with updates
-                    output = mx.array(output.tolist()) # Make a copy
-                    output[idx, curr_len] = token_value
-                    current_lengths = mx.array([
-                        l + 1 if i == idx else l
-                        for i, l in enumerate(current_lengths.tolist())
-                    ])
-                    tokens_generated += 1
-
-                    # Check end conditions
-                    if token_value == tokenizer.eos_token_id:
-                        not_finished = mx.array([
-                            False if i == idx else nf
-                            for i, nf in enumerate(not_finished.tolist())
-                        ])
-                    elif curr_len >= end_len:
-                        last_tokens = output[idx, curr_len-end_len+1:curr_len+1]
-                        if mx.array_equal(last_tokens, end_sequence):
-                            not_finished = mx.array([
-                                False if i == idx else nf
-                                for i, nf in enumerate(not_finished.tolist())
-                            ])
-
-            if _ % 32 == 0:
-                mx.eval(output, current_lengths, not_finished)
-
-        end_time = time.perf_counter()
-        generation_time = end_time - start_time
-        tokens_per_second = tokens_generated / generation_time
-        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_per_second:.2f} tokens/s)")
-
-        # Return only the valid part of each sequence
-        results = [output[i, :current_lengths[i].item()] for i in range(batch_size)]
+        for idx in range(batch_size):
+            current_tokens = []
+            generator = generate_step(
+                expanded_prompts[idx],
+                model,
+                max_tokens=max_tokens,
+                sampler=lambda x: mx.argmax(x, axis=-1)
+            )
+
+            # Collect all tokens first
+            for tokens, _ in generator:
+                current_tokens.append(tokens)
+                tokens_generated += 1
+                if tokens == tokenizer.eos_token_id:
+                    break
+
+            # Convert to array after collection
+            results.append(mx.array(current_tokens))
+            mx.metal.clear_cache()
+
+        # Final evaluation of all results
+        mx.eval(results)
+        generation_time = time.perf_counter() - start_time
+        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
+
         return results
 
     except Exception as e:
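Most of the speed-up in this hunk comes from what the removed lines did per token: for every generated token of every active sequence, the old loop copied the whole output buffer through Python (output = mx.array(output.tolist())) and rebuilt current_lengths and not_finished as Python lists, so the per-token cost grew with batch size and sequence length. The replacement streams tokens from generate_step and only materializes MLX arrays once per completed sequence. A small illustrative contrast of the two update patterns (the helpers and timings below are assumptions for illustration, not code from the commit):

    import time
    import mlx.core as mx

    def copy_per_token(steps: int) -> float:
        # Old pattern: round-trip the whole buffer through Python on every token.
        output = mx.zeros((4, 512), dtype=mx.int32)
        start = time.perf_counter()
        for t in range(steps):
            output = mx.array(output.tolist())  # device -> Python lists -> device
            output[0, t] = t
        mx.eval(output)
        return time.perf_counter() - start

    def append_then_materialize(steps: int) -> float:
        # New pattern: collect plain ints, build one array at the end.
        start = time.perf_counter()
        tokens = [t for t in range(steps)]
        out = mx.array(tokens)
        mx.eval(out)
        return time.perf_counter() - start

    print(f"copy per token:     {copy_per_token(256):.3f}s")
    print(f"append, then array: {append_then_materialize(256):.3f}s")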
@@ -214,13 +168,14 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
         seq_logits = logits[i, :seq_len]
         seq_targets = targets[i, :seq_len]
         log_probs = nn.log_softmax(seq_logits, axis=-1)
 
         token_log_probs = mx.take_along_axis(
             log_probs,
-            seq_targets.reshape(seq_len, 1),
-            axis=-1
+            seq_targets.reshape(seq_len, 1), axis=-1
         ).squeeze(-1)
 
         per_token_logps.append(token_log_probs)
         mx.eval(logits)
     return per_token_logps
 
 
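The change above only folds axis=-1 onto the same line as the indices; the gather itself is unchanged. For reference, a self-contained sketch of what mx.take_along_axis computes here (shapes and dummy values are made up for illustration):

    import mlx.core as mx
    import mlx.nn as nn

    seq_len, vocab = 5, 100
    seq_logits = mx.random.normal((seq_len, vocab))   # per-position logits
    seq_targets = mx.array([3, 17, 42, 7, 99])        # next-token ids, shape (seq_len,)

    log_probs = nn.log_softmax(seq_logits, axis=-1)   # (seq_len, vocab)
    token_log_probs = mx.take_along_axis(
        log_probs,
        seq_targets.reshape(seq_len, 1), axis=-1
    ).squeeze(-1)                                     # (seq_len,): log p(target_t | prefix)
    print(token_log_probs.shape)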
@@ -253,8 +208,7 @@ def grpo_loss(
         model,
         prompt_tensor,
         max_tokens,
         tokenizer,
-        temperature,
         group_size
     )
 