mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-30 05:11:12 +08:00
updates
This commit is contained in:
parent
c817743333
commit
3dfb21267b
@ -1,6 +1,5 @@
|
||||
from typing import List, Optional, Callable
|
||||
import re
|
||||
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
|
||||
|
||||
@ -14,19 +13,30 @@ def r1_extract_xml_answer(text: str) -> str:
|
||||
print("r1_extract_xml_answer returned empty string")
|
||||
return ""
|
||||
|
||||
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
|
||||
def r1_int_reward_func(
|
||||
prompts: list, completions: list, answer: list, **kwargs
|
||||
) -> list[float]:
|
||||
if not completions:
|
||||
return [0.0] * len(prompts)
|
||||
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
|
||||
return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
|
||||
|
||||
def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
|
||||
def r1_accuracy_reward_func(
|
||||
prompts: list, completions: list, answer: list, **kwargs
|
||||
) -> list[float]:
|
||||
if not completions or not answer:
|
||||
return [0.0] * len(prompts)
|
||||
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
|
||||
return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]
|
||||
return [
|
||||
2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)
|
||||
]
|
||||
|
||||
def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
|
||||
def r1_soft_format_reward_func(
|
||||
prompts: list, completions: list, answer: list, **kwargs
|
||||
) -> list[float]:
|
||||
if not completions:
|
||||
return [0.0] * len(prompts)
|
||||
|
||||
@ -41,25 +51,35 @@ def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, *
|
||||
answer_start = completion.find("<answer>")
|
||||
answer_end = completion.find("</answer>")
|
||||
|
||||
if (reason_start != -1 and reason_end != -1 and
|
||||
answer_start != -1 and answer_end != -1 and
|
||||
reason_start < reason_end < answer_start < answer_end):
|
||||
reason_content = completion[reason_start+13:reason_end].strip()
|
||||
answer_content = completion[answer_start+8:answer_end].strip()
|
||||
if (
|
||||
reason_start != -1
|
||||
and reason_end != -1
|
||||
and answer_start != -1
|
||||
and answer_end != -1
|
||||
and reason_start < reason_end < answer_start < answer_end
|
||||
):
|
||||
reason_content = completion[reason_start + 13 : reason_end].strip()
|
||||
answer_content = completion[answer_start + 8 : answer_end].strip()
|
||||
if reason_content and answer_content:
|
||||
scores.append(0.5)
|
||||
continue
|
||||
scores.append(0.0)
|
||||
return scores
|
||||
|
||||
def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
|
||||
def r1_strict_format_reward_func(
|
||||
prompts: list, completions: list, answer: list, **kwargs
|
||||
) -> list[float]:
|
||||
if not completions:
|
||||
return [0.0] * len(prompts)
|
||||
pattern = r"<think> .*? </think><answer> .*? </answer>"
|
||||
matches = [bool(re.search(pattern, r)) if r else False for r in completions]
|
||||
return [0.5 if match else 0.0 for match in matches]
|
||||
|
||||
def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
|
||||
|
||||
def r1_count_xml(
|
||||
prompts: list, completions: list, answer: list, **kwargs
|
||||
) -> list[float]:
|
||||
if not completions:
|
||||
return [0.0] * len(prompts)
|
||||
scores = []
|
||||
|
@ -1,19 +1,28 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from typing import List, Optional, Tuple, Generator, Callable, Any
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
import time
|
||||
from typing import Any, Callable, Generator, List, Optional, Tuple
|
||||
|
||||
from mlx.utils import tree_flatten
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
from .grpo_reward_functions import r1_accuracy_reward_func, r1_int_reward_func, r1_strict_format_reward_func, r1_soft_format_reward_func, r1_count_xml,r1_extract_xml_answer, RewardFunctions
|
||||
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
|
||||
from ..utils import generate_step
|
||||
from ..models import cache
|
||||
from ..utils import generation_stream
|
||||
from .grpo_reward_functions import (
|
||||
RewardFunctions,
|
||||
r1_accuracy_reward_func,
|
||||
r1_count_xml,
|
||||
r1_extract_xml_answer,
|
||||
r1_int_reward_func,
|
||||
r1_soft_format_reward_func,
|
||||
r1_strict_format_reward_func,
|
||||
)
|
||||
from .trainer import TrainingArgs, TrainingCallback, average_gradients, grad_checkpoint
|
||||
|
||||
|
||||
@dataclass
|
||||
class GRPOTrainingArgs(TrainingArgs):
|
||||
@ -21,9 +30,7 @@ class GRPOTrainingArgs(TrainingArgs):
|
||||
default=4,
|
||||
metadata={"help": "Number of responses per prompt."},
|
||||
)
|
||||
beta: float = field(
|
||||
default=0.1, metadata={"help": "KL penalty coefficient."}
|
||||
)
|
||||
beta: float = field(default=0.1, metadata={"help": "KL penalty coefficient."})
|
||||
epsilon: float = field(
|
||||
default=1e-4, metadata={"help": "The Epsilon for numerical stability."}
|
||||
)
|
||||
@ -34,40 +41,142 @@ class GRPOTrainingArgs(TrainingArgs):
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Path to reference model weights. If None, uses the same model."
|
||||
}
|
||||
},
|
||||
)
|
||||
temperature: float = field(
|
||||
default=1.0,
|
||||
metadata={
|
||||
"help": "Temperature for sampling. The higher the temperature, the more random the completions."
|
||||
}
|
||||
},
|
||||
)
|
||||
reward_weights: Optional[List[float]] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are weighted equally with weight `1.0`."
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def generate_step(
|
||||
prompt: mx.array,
|
||||
model: nn.Module,
|
||||
*,
|
||||
max_tokens: int = 256,
|
||||
sampler: Optional[Callable[mx.array, mx.array]] = None,
|
||||
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||
max_kv_size: Optional[int] = None,
|
||||
prompt_cache: Optional[Any] = None,
|
||||
prefill_step_size: int = 512,
|
||||
prompt_progress_callback: Optional[Callable[int, int]] = None,
|
||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||
"""
|
||||
A generator producing token ids based on the given prompt from the model.
|
||||
|
||||
Args:
|
||||
prompt (mx.array): The input prompt.
|
||||
model (nn.Module): The model to use for generation.
|
||||
max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite
|
||||
generator. Default: ``256``.
|
||||
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
|
||||
token from a vector of log probabilities. Default: ``None``.
|
||||
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||
A list of functions that take tokens and logits and return the processed
|
||||
logits. Default: ``None``.
|
||||
max_kv_size (int, optional): Maximum size of the key-value cache. Old
|
||||
entries (except the first 4 tokens) will be overwritten.
|
||||
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
||||
provided, the cache will be updated in place.
|
||||
prefill_step_size (int): Step size for processing the prompt.
|
||||
kv_bits (int, optional): Number of bits to use for KV cache quantization.
|
||||
None implies no cache quantization. Default: ``None``.
|
||||
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
|
||||
quantized_kv_start (int): Step to begin using a quantized KV cache.
|
||||
when ``kv_bits`` is non-None. Default: ``0``.
|
||||
prompt_prorgress_callback (Callable[int, int]): A call-back which takes the
|
||||
prompt tokens processed so far and the total number of prompt tokens.
|
||||
|
||||
Yields:
|
||||
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
|
||||
"""
|
||||
|
||||
y = prompt
|
||||
tokens = None
|
||||
|
||||
# Create the KV cache for generation
|
||||
if prompt_cache is None:
|
||||
prompt_cache = cache.make_prompt_cache(
|
||||
model,
|
||||
max_kv_size=max_kv_size,
|
||||
)
|
||||
elif len(prompt_cache) != len(model.layers):
|
||||
raise ValueError("Wrong number of layers in the prompt cache.")
|
||||
|
||||
prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
|
||||
|
||||
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
|
||||
|
||||
def _step(y):
|
||||
with mx.stream(generation_stream):
|
||||
logits = model(y[None], cache=prompt_cache)
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
if logits_processors:
|
||||
nonlocal tokens
|
||||
tokens = mx.concat([tokens, y]) if tokens is not None else y
|
||||
|
||||
for processor in logits_processors:
|
||||
logits = processor(tokens, logits)
|
||||
|
||||
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||
y = sampler(logprobs)
|
||||
return y, logprobs.squeeze(0)
|
||||
|
||||
with mx.stream(generation_stream):
|
||||
total_prompt_tokens = y.size
|
||||
prompt_processed_tokens = 0
|
||||
while y.size > prefill_step_size:
|
||||
model(y[:prefill_step_size][None], cache=prompt_cache)
|
||||
mx.eval([c.state for c in prompt_cache])
|
||||
prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
|
||||
prompt_processed_tokens += prefill_step_size
|
||||
y = y[prefill_step_size:]
|
||||
mx.metal.clear_cache()
|
||||
|
||||
y, logprobs = _step(y)
|
||||
|
||||
mx.eval(y, logprobs)
|
||||
n = 0
|
||||
while True:
|
||||
if n != max_tokens:
|
||||
next_y, next_logprobs = _step(y)
|
||||
mx.eval(next_y, next_logprobs)
|
||||
if n == 0:
|
||||
mx.eval(y)
|
||||
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
|
||||
if n == max_tokens:
|
||||
break
|
||||
yield y.item(), logprobs
|
||||
if n % 256 == 0:
|
||||
mx.metal.clear_cache()
|
||||
y, logprobs = next_y, next_logprobs
|
||||
n += 1
|
||||
|
||||
|
||||
def generate_grpo(
|
||||
model: nn.Module,
|
||||
prompts,
|
||||
max_tokens,
|
||||
tokenizer,
|
||||
group_size,
|
||||
is_training=False,
|
||||
end_token: str = "</answer>",
|
||||
temperature: float = 0.8,
|
||||
batch_size: int = 1
|
||||
batch_size: int = 1,
|
||||
):
|
||||
# Store original training state
|
||||
was_training = model.training
|
||||
|
||||
# Set model to eval mode for generation
|
||||
model.eval()
|
||||
|
||||
try:
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
if len(prompts.shape) == 1:
|
||||
prompts = prompts[None, :]
|
||||
if prompts.shape[1] == 0:
|
||||
@ -79,113 +188,84 @@ def generate_grpo(
|
||||
results = []
|
||||
mx.eval(expanded_prompts)
|
||||
|
||||
print(f"Setup time: {time.time() - start_time:.2f}s")
|
||||
print(f"Generating {total_samples} samples with max_tokens={max_tokens}")
|
||||
|
||||
total_tokens_generated = 0
|
||||
generation_start_time = time.time()
|
||||
|
||||
# Process in batches
|
||||
for batch_start in range(0, total_samples, batch_size):
|
||||
batch_end = min(batch_start + batch_size, total_samples)
|
||||
batch_time = time.time()
|
||||
print(
|
||||
f"Starting batch {batch_start//batch_size + 1}/{(total_samples + batch_size - 1)//batch_size}: samples {batch_start}-{batch_end-1}"
|
||||
)
|
||||
|
||||
if is_training:
|
||||
# Training-specific generation logic
|
||||
batch_inputs = expanded_prompts[batch_start:batch_end]
|
||||
batch_tokens = [[] for _ in range(batch_end - batch_start)]
|
||||
prompt_caches = [cache.make_prompt_cache(model) for _ in range(batch_end - batch_start)]
|
||||
# Custom sampler function that handles temperature
|
||||
def temp_sampler(logits):
|
||||
return mx.random.categorical(logits / temperature)
|
||||
|
||||
# Initial forward pass
|
||||
for i, prompt in enumerate(batch_inputs):
|
||||
logits = model(prompt[None], cache=prompt_caches[i])[:, -1]
|
||||
logits_temp = logits / temperature
|
||||
next_token = mx.random.categorical(logits_temp)
|
||||
token = next_token.item()
|
||||
batch_tokens[i].append(token)
|
||||
del logits, logits_temp, next_token
|
||||
# Batched processing
|
||||
for idx in range(batch_start, batch_end):
|
||||
sample_start_time = time.time()
|
||||
current_tokens = []
|
||||
prompt_cache = cache.make_prompt_cache(model)
|
||||
|
||||
mx.eval([tokens[-1] for tokens in batch_tokens])
|
||||
mx.metal.clear_cache()
|
||||
|
||||
active_indices = [i for i in range(len(batch_tokens)) if batch_tokens[i][-1] != tokenizer.eos_token_id]
|
||||
|
||||
# Generate remaining tokens
|
||||
for _ in range(max_tokens - 1):
|
||||
if not active_indices:
|
||||
break
|
||||
|
||||
next_active = []
|
||||
for idx in active_indices:
|
||||
current_input = mx.array([batch_tokens[idx][-1]])
|
||||
logits = model(current_input[None], cache=prompt_caches[idx])[:, -1]
|
||||
logits_temp = logits / temperature
|
||||
next_token = mx.random.categorical(logits_temp)
|
||||
token = next_token.item()
|
||||
batch_tokens[idx].append(token)
|
||||
|
||||
# Check for end conditions
|
||||
is_end = False
|
||||
if len(batch_tokens[idx]) >= len(end_sequence):
|
||||
test_sequence = batch_tokens[idx][-len(end_sequence):]
|
||||
is_end = mx.array_equal(mx.array(test_sequence), end_sequence)
|
||||
|
||||
if not (is_end or token == tokenizer.eos_token_id):
|
||||
next_active.append(idx)
|
||||
|
||||
del logits, logits_temp, next_token, current_input
|
||||
|
||||
mx.eval([tokens[-1] for tokens in batch_tokens])
|
||||
mx.metal.clear_cache()
|
||||
active_indices = next_active
|
||||
|
||||
# Clean up caches
|
||||
for pc in prompt_caches:
|
||||
del pc
|
||||
|
||||
# Process results
|
||||
for tokens in batch_tokens:
|
||||
if tokens:
|
||||
# Truncate at end token if present
|
||||
for i in range(len(tokens) - len(end_sequence) + 1):
|
||||
if mx.array_equal(
|
||||
mx.array(tokens[i:i+len(end_sequence)]),
|
||||
end_sequence
|
||||
):
|
||||
tokens = tokens[:i+len(end_sequence)]
|
||||
break
|
||||
|
||||
if tokens and tokens[-1] == tokenizer.eos_token_id:
|
||||
tokens = tokens[:-1]
|
||||
|
||||
if tokens:
|
||||
results.append(mx.array(tokens))
|
||||
|
||||
del batch_inputs, batch_tokens, prompt_caches
|
||||
mx.metal.clear_cache()
|
||||
else:
|
||||
# Non-training mode with batched processing
|
||||
for idx in range(batch_start, batch_end):
|
||||
current_tokens = []
|
||||
generator = generate_step(
|
||||
# The generate_step function yields one token at a time
|
||||
# We'll collect tokens until we hit max_tokens or a stopping condition
|
||||
for i, (token, _) in enumerate(
|
||||
generate_step(
|
||||
expanded_prompts[idx],
|
||||
model,
|
||||
max_tokens=max_tokens,
|
||||
sampler=lambda x: mx.random.categorical(x / temperature)
|
||||
max_tokens=max_tokens, # This is the maximum number of steps
|
||||
sampler=temp_sampler,
|
||||
prompt_cache=prompt_cache,
|
||||
)
|
||||
):
|
||||
print(token)
|
||||
current_tokens.append(token)
|
||||
|
||||
for token, _ in generator:
|
||||
test_sequence = current_tokens + [token]
|
||||
if (len(test_sequence) >= len(end_sequence) and
|
||||
mx.array_equal(
|
||||
mx.array(test_sequence[-len(end_sequence):]),
|
||||
end_sequence
|
||||
)):
|
||||
current_tokens.append(token)
|
||||
break
|
||||
# Check for end token
|
||||
if len(current_tokens) >= len(end_sequence) and mx.array_equal(
|
||||
mx.array(current_tokens[-len(end_sequence) :]), end_sequence
|
||||
):
|
||||
break
|
||||
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
current_tokens.append(token)
|
||||
# Check for EOS token
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
|
||||
if current_tokens:
|
||||
results.append(mx.array(current_tokens))
|
||||
# Check if we've reached the maximum number of tokens
|
||||
if i >= max_tokens - 1:
|
||||
break
|
||||
|
||||
if current_tokens:
|
||||
results.append(mx.array(current_tokens))
|
||||
total_tokens_generated += len(current_tokens)
|
||||
|
||||
sample_time = time.time() - sample_start_time
|
||||
tokens_per_second = (
|
||||
len(current_tokens) / sample_time if sample_time > 0 else 0
|
||||
)
|
||||
print(
|
||||
f" Sample {idx}: Generated {len(current_tokens)} tokens in {sample_time:.2f}s ({tokens_per_second:.2f} tokens/sec)"
|
||||
)
|
||||
|
||||
batch_time = time.time() - batch_time
|
||||
print(f"Batch completed in {batch_time:.2f}s")
|
||||
mx.metal.clear_cache()
|
||||
|
||||
generation_time = time.time() - generation_start_time
|
||||
avg_tokens_per_second = (
|
||||
total_tokens_generated / generation_time if generation_time > 0 else 0
|
||||
)
|
||||
|
||||
print(
|
||||
f"Generation complete: {total_tokens_generated} tokens in {generation_time:.2f}s"
|
||||
)
|
||||
print(f"Average generation speed: {avg_tokens_per_second:.2f} tokens/sec")
|
||||
|
||||
mx.eval(results)
|
||||
return results
|
||||
|
||||
@ -193,10 +273,6 @@ def generate_grpo(
|
||||
print(f"Generation error: {str(e)}")
|
||||
return None
|
||||
|
||||
finally:
|
||||
# Don't restore training mode - let the caller handle it
|
||||
pass
|
||||
|
||||
|
||||
def get_per_token_logps(model: nn.Module, inputs, lengths):
|
||||
logits = model(inputs).astype(mx.float16)
|
||||
@ -209,8 +285,7 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
|
||||
seq_targets = targets[i, :seq_len]
|
||||
log_probs = nn.log_softmax(seq_logits, axis=-1)
|
||||
token_log_probs = mx.take_along_axis(
|
||||
log_probs,
|
||||
seq_targets.reshape(seq_len, 1), axis=-1
|
||||
log_probs, seq_targets.reshape(seq_len, 1), axis=-1
|
||||
).squeeze(-1)
|
||||
per_token_logps.append(token_log_probs)
|
||||
mx.eval(logits)
|
||||
@ -223,14 +298,14 @@ def grpo_loss(
|
||||
tokenizer,
|
||||
batch,
|
||||
reward_funcs: Optional[List[RewardFunctions]] = None,
|
||||
beta: float =0.1,
|
||||
beta: float = 0.1,
|
||||
group_size: int = 4,
|
||||
epsilon: float = 1e-4,
|
||||
max_tokens: int = 64,
|
||||
temperature: float = 0.8,
|
||||
reward_weights: Optional[List[float]] = None,
|
||||
is_validation: bool = False,
|
||||
batch_size: int = 1
|
||||
batch_size: int = 1,
|
||||
):
|
||||
prompt_tokens, _, prompt_text, answer_text = batch
|
||||
total_samples = len(prompt_tokens)
|
||||
@ -239,11 +314,17 @@ def grpo_loss(
|
||||
all_completion_texts = []
|
||||
batch_indices = [] # Keep track of which batch each completion belongs to
|
||||
|
||||
# Store original training state
|
||||
was_training = model.training
|
||||
|
||||
# Set model to eval mode for generation
|
||||
model.eval()
|
||||
|
||||
# Process in smaller batches
|
||||
for i in range(0, total_samples, batch_size):
|
||||
# Get actual batch size for this iteration (might be smaller for the last batch)
|
||||
current_batch_size = min(batch_size, total_samples - i)
|
||||
batch_prompts = prompt_tokens[i:i+current_batch_size]
|
||||
batch_prompts = prompt_tokens[i : i + current_batch_size]
|
||||
|
||||
# Pad sequences to the same length
|
||||
max_prompt_len = max(len(p) for p in batch_prompts)
|
||||
@ -257,34 +338,23 @@ def grpo_loss(
|
||||
prompt_tensor = mx.array(padded_prompts)
|
||||
|
||||
try:
|
||||
if is_validation:
|
||||
completions = generate_grpo(
|
||||
model,
|
||||
prompt_tensor,
|
||||
max_tokens,
|
||||
tokenizer,
|
||||
group_size,
|
||||
temperature=temperature,
|
||||
batch_size=current_batch_size
|
||||
)
|
||||
model.train()
|
||||
else:
|
||||
completions = generate_grpo(
|
||||
model,
|
||||
prompt_tensor,
|
||||
max_tokens,
|
||||
tokenizer,
|
||||
group_size,
|
||||
is_training=True,
|
||||
temperature=temperature,
|
||||
batch_size=current_batch_size
|
||||
)
|
||||
completions = generate_grpo(
|
||||
model,
|
||||
prompt_tensor,
|
||||
max_tokens,
|
||||
tokenizer,
|
||||
group_size,
|
||||
temperature=temperature,
|
||||
batch_size=current_batch_size,
|
||||
)
|
||||
|
||||
if completions is not None:
|
||||
for j, completion_ids in enumerate(completions):
|
||||
# Calculate which prompt this completion belongs to
|
||||
prompt_idx = i + (j // group_size)
|
||||
if prompt_idx < total_samples: # Make sure we don't go out of bounds
|
||||
if (
|
||||
prompt_idx < total_samples
|
||||
): # Make sure we don't go out of bounds
|
||||
batch_indices.append(prompt_idx)
|
||||
completion_text = tokenizer.decode(completion_ids.tolist())
|
||||
all_completions.append(completion_ids)
|
||||
@ -294,12 +364,19 @@ def grpo_loss(
|
||||
print(f"Generation error: {e}")
|
||||
continue
|
||||
|
||||
# Restore original training state if we're not in validation mode
|
||||
if not is_validation and was_training:
|
||||
model.train()
|
||||
|
||||
mx.metal.clear_cache()
|
||||
|
||||
# If we didn't generate any completions, return early
|
||||
if not all_completions:
|
||||
raise ValueError("No completions were generated. Please check your model and inputs.")
|
||||
raise ValueError(
|
||||
"No completions were generated. Please check your model and inputs."
|
||||
)
|
||||
|
||||
# The rest of the function remains the same
|
||||
# Create expanded prompts and answers based on actual generated completions
|
||||
expanded_answers = []
|
||||
expanded_prompts = []
|
||||
@ -341,7 +418,9 @@ def grpo_loss(
|
||||
if padding_length > 0:
|
||||
padding = mx.zeros((padding_length,), dtype=completion_ids.dtype)
|
||||
padded_ids = mx.concatenate([completion_ids, padding])
|
||||
mask = mx.concatenate([mx.ones_like(completion_ids), mx.zeros_like(padding)])
|
||||
mask = mx.concatenate(
|
||||
[mx.ones_like(completion_ids), mx.zeros_like(padding)]
|
||||
)
|
||||
else:
|
||||
padded_ids = completion_ids
|
||||
mask = mx.ones_like(completion_ids)
|
||||
@ -381,11 +460,13 @@ def grpo_loss(
|
||||
|
||||
# Collect rewards from each function separately
|
||||
for reward_func in reward_funcs:
|
||||
func_rewards = mx.array(reward_func(
|
||||
prompts=expanded_prompts,
|
||||
completions=all_completion_texts,
|
||||
answer=expanded_answers
|
||||
))
|
||||
func_rewards = mx.array(
|
||||
reward_func(
|
||||
prompts=expanded_prompts,
|
||||
completions=all_completion_texts,
|
||||
answer=expanded_answers,
|
||||
)
|
||||
)
|
||||
all_func_rewards.append(func_rewards)
|
||||
|
||||
# Stack rewards to shape (num_samples, num_funcs)
|
||||
@ -422,25 +503,39 @@ def grpo_loss(
|
||||
std_reward = mx.std(prompt_rewards)
|
||||
|
||||
# Find indices for this prompt
|
||||
indices = [j for j, idx in enumerate(batch_indices) if idx == unique_prompt_indices[i]]
|
||||
indices = [
|
||||
j
|
||||
for j, idx in enumerate(batch_indices)
|
||||
if idx == unique_prompt_indices[i]
|
||||
]
|
||||
for j, idx in enumerate(indices):
|
||||
advantages[idx] = (prompt_rewards[j] - mean_reward) / (std_reward + epsilon)
|
||||
advantages[idx] = (prompt_rewards[j] - mean_reward) / (
|
||||
std_reward + epsilon
|
||||
)
|
||||
else:
|
||||
# If only one sample, advantage is 0
|
||||
idx = batch_indices.index(unique_prompt_indices[i])
|
||||
advantages[idx] = 0.0
|
||||
|
||||
# Compute KL divergence using Schulman's approximator
|
||||
kl_div = mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1
|
||||
kl_div = (
|
||||
mx.exp(ref_token_log_probs - token_log_probs)
|
||||
- (ref_token_log_probs - token_log_probs)
|
||||
- 1
|
||||
)
|
||||
|
||||
# Create mask for valid tokens
|
||||
length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
|
||||
|
||||
# Compute policy ratio
|
||||
policy_ratio = mx.exp(mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs)))
|
||||
policy_ratio = mx.exp(
|
||||
mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs))
|
||||
)
|
||||
|
||||
# Compute per-token loss
|
||||
per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
|
||||
per_token_loss = -(
|
||||
(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
|
||||
)
|
||||
|
||||
# Average over tokens
|
||||
sequence_sums = per_token_loss.sum(axis=1)
|
||||
@ -454,25 +549,33 @@ def grpo_loss(
|
||||
reward_metrics = {}
|
||||
for i, reward_func in enumerate(reward_funcs):
|
||||
func_name = reward_func.__name__
|
||||
func_rewards = mx.array(reward_func(
|
||||
prompts=expanded_prompts,
|
||||
completions=all_completion_texts,
|
||||
answer=expanded_answers
|
||||
))
|
||||
reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards)
|
||||
reward_metrics[f'{func_name}_std'] = mx.std(func_rewards)
|
||||
func_rewards = mx.array(
|
||||
reward_func(
|
||||
prompts=expanded_prompts,
|
||||
completions=all_completion_texts,
|
||||
answer=expanded_answers,
|
||||
)
|
||||
)
|
||||
reward_metrics[f"{func_name}_mean"] = mx.mean(func_rewards)
|
||||
reward_metrics[f"{func_name}_std"] = mx.std(func_rewards)
|
||||
|
||||
|
||||
grouped_rewards_mean = mx.array([mx.mean(mx.array(rewards)) for rewards in rewards_by_prompt])
|
||||
grouped_rewards_std = mx.array([mx.std(mx.array(rewards)) if len(rewards) > 1 else mx.zeros(1) for rewards in rewards_by_prompt])
|
||||
grouped_rewards_mean = mx.array(
|
||||
[mx.mean(mx.array(rewards)) for rewards in rewards_by_prompt]
|
||||
)
|
||||
grouped_rewards_std = mx.array(
|
||||
[
|
||||
mx.std(mx.array(rewards)) if len(rewards) > 1 else mx.zeros(1)
|
||||
for rewards in rewards_by_prompt
|
||||
]
|
||||
)
|
||||
|
||||
metrics = {
|
||||
'total_rewards_mean': mx.mean(rewards),
|
||||
'total_rewards_std': mx.std(rewards),
|
||||
'grouped_rewards_mean': mx.mean(grouped_rewards_mean),
|
||||
'grouped_rewards_std': mx.mean(grouped_rewards_std),
|
||||
'kl': mean_kl,
|
||||
**reward_metrics
|
||||
"total_rewards_mean": mx.mean(rewards),
|
||||
"total_rewards_std": mx.std(rewards),
|
||||
"grouped_rewards_mean": mx.mean(grouped_rewards_mean),
|
||||
"grouped_rewards_std": mx.mean(grouped_rewards_std),
|
||||
"kl": mean_kl,
|
||||
**reward_metrics,
|
||||
}
|
||||
|
||||
if is_validation and all_completion_texts:
|
||||
@ -483,26 +586,28 @@ def grpo_loss(
|
||||
|
||||
if last_prompt_idx < len(prompt_text):
|
||||
print(f"\n📋 Raw Prompt:\n{prompt_text[last_prompt_idx]}")
|
||||
print("\n" + "="*10 + "\n")
|
||||
print("\n" + "=" * 10 + "\n")
|
||||
|
||||
# Get the actual tokenized prompt that was fed to the model
|
||||
if last_prompt_idx < len(prompt_tokens):
|
||||
actual_prompt = tokenizer.decode(prompt_tokens[last_prompt_idx])
|
||||
print(f"\n🔄 Model Input:\n{actual_prompt}")
|
||||
print("\n" + "="*10 + "\n")
|
||||
print("\n" + "=" * 10 + "\n")
|
||||
|
||||
print(f"\n📝 Generation:\n{all_completion_texts[-1]}")
|
||||
print("\n" + "="*10 + "\n")
|
||||
print("\n" + "=" * 10 + "\n")
|
||||
|
||||
# Make sure we have a valid index for answer_text
|
||||
if last_prompt_idx < len(answer_text):
|
||||
print(f"\n✅ Answer:\n{answer_text[last_prompt_idx]}")
|
||||
print("\n" + "="*10 + "\n")
|
||||
print("\n" + "=" * 10 + "\n")
|
||||
|
||||
# Only try to extract if r1_extract_xml_answer is defined
|
||||
if 'r1_extract_xml_answer' in globals():
|
||||
print(f"\n🔍 Extracted Answer:\n{r1_extract_xml_answer(all_completion_texts[-1])}")
|
||||
print("\n" + "="*35 + "\n")
|
||||
if "r1_extract_xml_answer" in globals():
|
||||
print(
|
||||
f"\n🔍 Extracted Answer:\n{r1_extract_xml_answer(all_completion_texts[-1])}"
|
||||
)
|
||||
print("\n" + "=" * 35 + "\n")
|
||||
|
||||
mx.metal.clear_cache()
|
||||
|
||||
@ -511,7 +616,9 @@ def grpo_loss(
|
||||
|
||||
def iterate_grpo_batches(dataset, batch_size, max_seq_length, train=False):
|
||||
if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
|
||||
raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
|
||||
raise ValueError(
|
||||
"Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples"
|
||||
)
|
||||
|
||||
def length_key(i):
|
||||
return len(dataset[i][0]) + len(dataset[i][1])
|
||||
@ -534,7 +641,8 @@ def iterate_grpo_batches(dataset, batch_size, max_seq_length, train=False):
|
||||
|
||||
while True:
|
||||
indices = (
|
||||
np.random.permutation(list(batch_index_generator())) if train
|
||||
np.random.permutation(list(batch_index_generator()))
|
||||
if train
|
||||
else batch_index_generator()
|
||||
)
|
||||
|
||||
@ -576,10 +684,10 @@ def evaluate_grpo(
|
||||
r1_int_reward_func,
|
||||
r1_strict_format_reward_func,
|
||||
r1_soft_format_reward_func,
|
||||
r1_count_xml
|
||||
r1_count_xml,
|
||||
],
|
||||
loss_fn: callable = grpo_loss,
|
||||
iterate_batches: callable = iterate_grpo_batches
|
||||
iterate_batches: callable = iterate_grpo_batches,
|
||||
):
|
||||
all_losses = 0
|
||||
ntokens = 0
|
||||
@ -606,7 +714,7 @@ def evaluate_grpo(
|
||||
ref_model=ref_model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
is_validation=True
|
||||
is_validation=True,
|
||||
)
|
||||
|
||||
all_losses += losses * toks
|
||||
@ -642,14 +750,16 @@ def train_grpo(
|
||||
r1_int_reward_func,
|
||||
r1_strict_format_reward_func,
|
||||
r1_soft_format_reward_func,
|
||||
r1_count_xml
|
||||
r1_count_xml,
|
||||
],
|
||||
args: GRPOTrainingArgs = GRPOTrainingArgs(),
|
||||
loss_fn: callable = grpo_loss,
|
||||
iterate_batches: callable = iterate_grpo_batches,
|
||||
training_callback: TrainingCallback = None,
|
||||
):
|
||||
print(f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}")
|
||||
print(
|
||||
f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}"
|
||||
)
|
||||
world = mx.distributed.init()
|
||||
world_size = world.size()
|
||||
rank = world.rank()
|
||||
@ -672,7 +782,7 @@ def train_grpo(
|
||||
epsilon=args.epsilon,
|
||||
ref_model=ref_model,
|
||||
max_tokens=args.max_completion_length,
|
||||
temperature=args.temperature
|
||||
temperature=args.temperature,
|
||||
)
|
||||
|
||||
grad = average_gradients(grad)
|
||||
@ -688,16 +798,16 @@ def train_grpo(
|
||||
steps = 0
|
||||
trained_tokens = 0
|
||||
accumulated_metrics = {
|
||||
'total_rewards_mean': 0,
|
||||
'total_rewards_std': 0,
|
||||
'grouped_rewards_mean': 0,
|
||||
'grouped_rewards_std': 0,
|
||||
'kl': 0
|
||||
"total_rewards_mean": 0,
|
||||
"total_rewards_std": 0,
|
||||
"grouped_rewards_mean": 0,
|
||||
"grouped_rewards_std": 0,
|
||||
"kl": 0,
|
||||
}
|
||||
for reward_func in reward_funcs:
|
||||
func_name = reward_func.__name__
|
||||
accumulated_metrics[f'{func_name}_mean'] = 0
|
||||
accumulated_metrics[f'{func_name}_std'] = 0
|
||||
accumulated_metrics[f"{func_name}_mean"] = 0
|
||||
accumulated_metrics[f"{func_name}_std"] = 0
|
||||
|
||||
start = time.perf_counter()
|
||||
for it, batch in zip(
|
||||
@ -746,18 +856,19 @@ def train_grpo(
|
||||
)
|
||||
|
||||
print(
|
||||
f"Iter {it}: {val_metrics_str}, "
|
||||
f"Val took {val_time:.3f}s",
|
||||
f"Iter {it}: {val_metrics_str}, " f"Val took {val_time:.3f}s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
if training_callback is not None:
|
||||
training_callback.on_val_loss_report({
|
||||
"iteration": it,
|
||||
"val_loss": val_loss,
|
||||
**{f"val_{k}": v for k, v in val_metrics.items()},
|
||||
"val_time": val_time,
|
||||
})
|
||||
training_callback.on_val_loss_report(
|
||||
{
|
||||
"iteration": it,
|
||||
"val_loss": val_loss,
|
||||
**{f"val_{k}": v for k, v in val_metrics.items()},
|
||||
"val_time": val_time,
|
||||
}
|
||||
)
|
||||
|
||||
start = time.perf_counter()
|
||||
|
||||
@ -776,7 +887,9 @@ def train_grpo(
|
||||
|
||||
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
|
||||
train_loss /= steps * mx.distributed.init().size()
|
||||
avg_metrics = {k: v / (steps * world_size) for k, v in accumulated_metrics.items()}
|
||||
avg_metrics = {
|
||||
k: v / (steps * world_size) for k, v in accumulated_metrics.items()
|
||||
}
|
||||
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
|
||||
learning_rate = optimizer.learning_rate.item()
|
||||
it_sec = args.steps_per_report / (stop - start)
|
||||
@ -811,16 +924,18 @@ def train_grpo(
|
||||
)
|
||||
|
||||
if training_callback is not None:
|
||||
training_callback.on_train_loss_report({
|
||||
"iteration": it,
|
||||
"train_loss": train_loss,
|
||||
**{f"train_{k}": v for k, v in avg_metrics.items()},
|
||||
"learning_rate": learning_rate,
|
||||
"iterations_per_second": it_sec,
|
||||
"tokens_per_second": tokens_sec,
|
||||
"trained_tokens": trained_tokens,
|
||||
"peak_memory": peak_mem,
|
||||
})
|
||||
training_callback.on_train_loss_report(
|
||||
{
|
||||
"iteration": it,
|
||||
"train_loss": train_loss,
|
||||
**{f"train_{k}": v for k, v in avg_metrics.items()},
|
||||
"learning_rate": learning_rate,
|
||||
"iterations_per_second": it_sec,
|
||||
"tokens_per_second": tokens_sec,
|
||||
"trained_tokens": trained_tokens,
|
||||
"peak_memory": peak_mem,
|
||||
}
|
||||
)
|
||||
|
||||
losses = 0
|
||||
n_tokens = 0
|
||||
|
Loading…
Reference in New Issue
Block a user