2025-01-31 06:55:34 +08:00
|
|
|
# Copyright © 2024 Apple Inc.
|
|
|
|
|
2025-02-25 03:49:11 +08:00
|
|
|
from typing import List, Optional, Tuple, Generator, Callable, Any
|
2025-01-31 06:55:34 +08:00
|
|
|
from dataclasses import dataclass, field
|
|
|
|
from pathlib import Path
|
2025-02-12 18:07:53 +08:00
|
|
|
import time
|
2025-01-31 06:55:34 +08:00
|
|
|
|
2025-02-12 18:07:53 +08:00
|
|
|
from mlx.utils import tree_flatten
|
2025-01-31 06:55:34 +08:00
|
|
|
import mlx.core as mx
|
|
|
|
import mlx.nn as nn
|
|
|
|
import numpy as np
|
|
|
|
|
2025-02-25 05:20:07 +08:00
|
|
|
from .grpo_reward_functions import r1_accuracy_reward_func, r1_int_reward_func, r1_strict_format_reward_func, r1_soft_format_reward_func, r1_count_xml,r1_extract_xml_answer, RewardFunctions
|
2025-02-12 18:07:53 +08:00
|
|
|
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
|
2025-02-22 05:08:49 +08:00
|
|
|
from ..utils import generate_step
|
2025-02-22 09:03:01 +08:00
|
|
|
from ..models import cache
|
2025-01-31 06:55:34 +08:00
|
|
|
|
|
|
|
@dataclass
class GRPOTrainingArgs(TrainingArgs):
    """Training arguments for Group Relative Policy Optimization (GRPO).

    Extends ``TrainingArgs`` with the sampling and loss hyper-parameters read
    by the GRPO training loop.
    """

    # Number of completions sampled for every prompt (the GRPO "group").
    group_size: int = field(
        default=4,
        metadata={"help": "Number of responses per prompt."},
    )
    # Weight of the KL(policy || reference) penalty term in the loss.
    beta: float = field(
        default=0.1, metadata={"help": "KL penalty coefficient."}
    )
    # Added to the per-group reward standard deviation before dividing.
    epsilon: float = field(
        default=1e-4, metadata={"help": "The Epsilon for numerical stability."}
    )
    # Upper bound on the number of tokens generated for one completion.
    max_completion_length: int = field(
        default=512,
        metadata={"help": "Maximum number of tokens to generate per completion."},
    )
    reference_model_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to reference model weights. If None, uses the same model."
        },
    )
    temperature: float = field(
        default=1.0,
        metadata={
            "help": "Temperature for sampling. The higher the temperature, the more random the completions."
        },
    )
    reward_weights: Optional[List[float]] = field(
        default=None,
        metadata={
            "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are weighted equally with weight `1.0`."
        },
    )
|
2025-01-31 06:55:34 +08:00
|
|
|
|
2025-02-05 18:30:21 +08:00
|
|
|
|
2025-02-25 03:49:11 +08:00
|
|
|
def _encode_plain(tokenizer, text):
    """Return the token ids of ``text`` without tokenizer-added special tokens.

    A tokenizer that prepends BOS would otherwise yield an id sequence that can
    never appear verbatim inside a generated completion.
    """
    try:
        return list(tokenizer.encode(text, add_special_tokens=False))
    except TypeError:
        # Tokenizer's ``encode`` does not accept ``add_special_tokens``.
        return list(tokenizer.encode(text))


def _ends_with(tokens, suffix):
    """Return True if list ``tokens`` ends with the non-empty list ``suffix``."""
    return bool(suffix) and len(tokens) >= len(suffix) and tokens[-len(suffix):] == suffix


def _sample_token(logits, temperature):
    """Sample token ids from ``logits``; ``temperature <= 0`` selects greedy argmax."""
    if temperature <= 0:
        return mx.argmax(logits, axis=-1)
    return mx.random.categorical(logits / temperature)


def _trim_completion(tokens, end_ids, eos_id):
    """Cut ``tokens`` right after the first ``end_ids`` run, then drop a trailing EOS."""
    width = len(end_ids)
    if width:
        for start in range(len(tokens) - width + 1):
            if tokens[start:start + width] == end_ids:
                tokens = tokens[:start + width]
                break
    if tokens and tokens[-1] == eos_id:
        tokens = tokens[:-1]
    return tokens


def _generate_with_cache(model, prompts, max_tokens, temperature, eos_id, end_ids):
    """Sample one completion per row of ``prompts``, each row with its own KV cache.

    Rows advance in lock-step, one token per round. A row stops after sampling
    EOS, after emitting ``end_ids``, or once it holds ``max_tokens`` tokens; at
    least one token is always sampled.

    Returns:
        One list of Python ints per prompt row, in input order.
    """
    caches = [cache.make_prompt_cache(model) for _ in range(prompts.shape[0])]
    # Prefill each prompt and keep the logits of its last position.
    logits = [model(prompt[None], cache=c)[:, -1] for prompt, c in zip(prompts, caches)]
    mx.eval(logits)

    sequences = [[] for _ in caches]
    active = list(range(len(caches)))
    while active:
        still_active = []
        for row in active:
            # ``.item()`` already forces evaluation; no extra ``mx.eval`` needed.
            token = _sample_token(logits[row], temperature).item()
            seq = sequences[row]
            seq.append(token)
            finished = (
                token == eos_id
                or _ends_with(seq, end_ids)
                or len(seq) >= max_tokens
            )
            if not finished:
                still_active.append(row)
                logits[row] = model(mx.array([[token]]), cache=caches[row])[:, -1]
        mx.eval([logits[row] for row in still_active])
        active = still_active
    # The per-row caches are released when this function returns.
    return sequences


def _generate_with_generate_step(model, prompt, max_tokens, temperature, eos_id, end_ids):
    """Sample one completion for a single 1-D ``prompt`` via ``generate_step``.

    The ``end_ids`` run is kept in the output; an EOS token is not.
    """
    tokens = []
    for token, _ in generate_step(
        prompt,
        model,
        max_tokens=max_tokens,
        sampler=lambda scores: _sample_token(scores, temperature),
    ):
        tokens.append(token)
        if _ends_with(tokens, end_ids):
            break
        if token == eos_id:
            tokens.pop()
            break
    return tokens


def generate_grpo(
    model: nn.Module,
    prompts,
    max_tokens,
    tokenizer,
    group_size,
    is_training=False,
    end_token: str = "</answer>",
    temperature: float = 0.8,
    batch_size: int = 1,
):
    """Sample ``group_size`` completions for every prompt in ``prompts``.

    Args:
        model: Policy model used for sampling.
        prompts: Token ids as an ``mx.array``. A 1-D array is one prompt; a
            2-D array holds one prompt per row.
        max_tokens: Maximum number of tokens generated per completion.
        tokenizer: Provides ``encode`` and ``eos_token_id``.
        group_size: Number of completions sampled per prompt.
        is_training: If True, sample with an explicit per-sample KV-cache
            loop; otherwise use ``generate_step``.
        end_token: Text whose token ids terminate a completion (kept in it).
        temperature: Sampling temperature; ``<= 0`` means greedy decoding.
        batch_size: Number of samples handled per outer iteration.

    Returns:
        A list of 1-D ``mx.array`` completions ordered prompt-major (all
        completions of prompt 0 first). Completions that end up empty are
        omitted, so there may be fewer than ``num_prompts * group_size``.
        Returns ``None`` if the prompts are zero-length or generation raises.
    """
    if len(prompts.shape) == 1:
        prompts = prompts[None, :]
    if prompts.shape[1] == 0:
        return None

    total_samples = prompts.shape[0] * group_size
    # mx.repeat keeps copies adjacent: row r is prompt r // group_size.
    expanded_prompts = mx.repeat(prompts, group_size, axis=0)
    end_ids = _encode_plain(tokenizer, end_token)
    eos_id = tokenizer.eos_token_id
    results = []
    mx.eval(expanded_prompts)

    try:
        for batch_start in range(0, total_samples, batch_size):
            batch_end = min(batch_start + batch_size, total_samples)
            if is_training:
                sequences = _generate_with_cache(
                    model,
                    expanded_prompts[batch_start:batch_end],
                    max_tokens,
                    temperature,
                    eos_id,
                    end_ids,
                )
                sequences = [_trim_completion(seq, end_ids, eos_id) for seq in sequences]
            else:
                sequences = [
                    _generate_with_generate_step(
                        model, expanded_prompts[row], max_tokens, temperature, eos_id, end_ids
                    )
                    for row in range(batch_start, batch_end)
                ]
            results.extend(mx.array(seq) for seq in sequences if seq)

        mx.metal.clear_cache()
        mx.eval(results)
        return results

    except Exception as e:
        # Best effort: report and let the caller skip these prompts.
        print(f"Generation error: {str(e)}")
        return None
|
|
|
|
|
|
|
|
|
2025-02-12 15:57:26 +08:00
|
|
|
def get_per_token_logps(model: nn.Module, inputs, lengths):
    """Return the log-probability of each next token of every sequence.

    Args:
        model: Causal LM; ``model(inputs)`` is indexed as
            ``(batch, position, vocab)`` logits.
        inputs: Right-padded token ids of shape ``(batch, max_len)``.
        lengths: Number of valid (unpadded) tokens in each row.

    Returns:
        A list with one 1-D array per row. Row ``i`` has
        ``max(lengths[i] - 1, 0)`` entries: the log-probability of token
        ``t + 1`` given tokens ``0..t``.
    """
    logits = model(inputs)[:, :-1, :]
    targets = inputs[:, 1:]
    per_token_logps = []
    for i in range(logits.shape[0]):
        # Clamp so a zero-length row gives an empty result rather than ``:-1``.
        seq_len = max(int(lengths[i]) - 1, 0)
        # Normalize in float32: a float16 log-softmax over a large vocabulary
        # loses precision the KL term and policy ratio depend on. Upcasting one
        # row at a time keeps the extra memory bounded.
        seq_logits = logits[i, :seq_len].astype(mx.float32)
        log_probs = nn.log_softmax(seq_logits, axis=-1)
        token_log_probs = mx.take_along_axis(
            log_probs, targets[i, :seq_len][:, None], axis=-1
        ).squeeze(-1)
        per_token_logps.append(token_log_probs)
    return per_token_logps
|
|
|
|
|
|
|
|
|
2025-01-31 06:55:34 +08:00
|
|
|
def _generate_completions(
    model,
    tokenizer,
    prompt_tokens,
    group_size,
    max_tokens,
    temperature,
    is_validation,
    batch_size,
):
    """Sample completions for each prompt with one ``generate_grpo`` call per prompt.

    Generating prompt by prompt avoids right-padding shorter prompts (the model
    would otherwise continue from pad tokens). It also attributes every
    completion to its prompt correctly even when ``generate_grpo`` omits empty
    completions.

    Returns:
        ``(completions, texts, owners)``, where ``owners[k]`` is the index into
        ``prompt_tokens`` of the prompt that produced ``completions[k]``.
        ``owners`` is non-decreasing.
    """
    completions, texts, owners = [], [], []
    for prompt_idx, prompt in enumerate(prompt_tokens):
        try:
            if is_validation:
                generated = generate_grpo(
                    model,
                    mx.array(prompt),
                    max_tokens,
                    tokenizer,
                    group_size,
                    temperature=temperature,
                    batch_size=batch_size,
                )
                # Put the model back in training mode after validation sampling.
                model.train()
            else:
                generated = generate_grpo(
                    model,
                    mx.array(prompt),
                    max_tokens,
                    tokenizer,
                    group_size,
                    is_training=True,
                    temperature=temperature,
                    batch_size=batch_size,
                )
            # ``generate_grpo`` returns None on failure: skip this prompt.
            for completion_ids in generated or []:
                # Decode before recording so a decode error cannot leave the
                # three lists misaligned.
                text = tokenizer.decode(completion_ids.tolist())
                completions.append(completion_ids)
                texts.append(text)
                owners.append(prompt_idx)
        except Exception as e:
            print(f"Generation error: {e}")
        mx.metal.clear_cache()
    return completions, texts, owners


def _group_slices(owners):
    """Split positions of ``owners`` into ``(start, end)`` runs sharing one owner."""
    slices = []
    for pos, owner in enumerate(owners):
        if slices and owner == owners[pos - 1]:
            slices[-1] = (slices[-1][0], pos + 1)
        else:
            slices.append((pos, pos + 1))
    return slices


def _print_validation_sample(
    tokenizer, prompt_tokens, prompt_text, answer_text, prompt_idx, completion_text
):
    """Print prompt, model input, generation and reference answer for one sample."""
    print("\n=== Validation Sample Details ===")
    if prompt_idx < len(prompt_text):
        print(f"\n📋 Raw Prompt:\n{prompt_text[prompt_idx]}")
        print("\n" + "="*10 + "\n")
    if prompt_idx < len(prompt_tokens):
        actual_prompt = tokenizer.decode(prompt_tokens[prompt_idx])
        print(f"\n🔄 Model Input:\n{actual_prompt}")
        print("\n" + "="*10 + "\n")
    print(f"\n📝 Generation:\n{completion_text}")
    print("\n" + "="*10 + "\n")
    if prompt_idx < len(answer_text):
        print(f"\n✅ Answer:\n{answer_text[prompt_idx]}")
        print("\n" + "="*10 + "\n")
    print(f"\n🔍 Extracted Answer:\n{r1_extract_xml_answer(completion_text)}")
    print("\n" + "="*35 + "\n")


def grpo_loss(
    model,
    ref_model,
    tokenizer,
    batch,
    reward_funcs: Optional[List[RewardFunctions]] = None,
    beta: float = 0.1,
    group_size: int = 4,
    epsilon: float = 1e-4,
    max_tokens: int = 64,
    temperature: float = 0.8,
    reward_weights: Optional[List[float]] = None,
    is_validation: bool = False,
    batch_size: int = 1,
):
    """Compute the GRPO loss for one batch of prompts.

    Samples ``group_size`` completions per prompt and scores them with the
    (weighted) reward functions. Rewards are normalized within each prompt's
    group to get advantages. The loss combines the advantage-weighted policy
    term with a ``beta``-weighted KL penalty against ``ref_model``.

    Args:
        model: Policy being trained.
        ref_model: Reference policy for the KL term; ``None`` reuses the
            policy's own log-probabilities (the KL term is then zero).
        tokenizer: Tokenizer used for sampling and decoding.
        batch: ``(prompt_tokens, answer_tokens, prompt_text, answer_text)``.
        reward_funcs: Callables taking ``prompts``, ``completions`` and
            ``answer`` keyword lists and returning one reward per completion.
        beta: KL penalty coefficient.
        group_size: Completions sampled per prompt.
        epsilon: Added to the group reward std before dividing.
        max_tokens: Maximum tokens generated per completion.
        temperature: Sampling temperature.
        reward_weights: Optional per-function weights (default all 1.0).
        is_validation: Sample in inference mode and print one sample.
        batch_size: Samples handled per ``generate_grpo`` iteration.

    Returns:
        ``(loss, ntokens, metrics)``: scalar loss, number of scored completion
        tokens, and a dict of scalar metrics.

    Raises:
        ValueError: if no reward functions are given, if ``reward_weights`` and
            ``reward_funcs`` differ in length, or if nothing was generated.
    """
    prompt_tokens, _, prompt_text, answer_text = batch

    # Validate cheap arguments before the expensive generation step.
    if not reward_funcs:
        raise ValueError("At least one reward function is required.")
    if reward_weights is not None and len(reward_weights) != len(reward_funcs):
        raise ValueError(
            f"Number of reward weights ({len(reward_weights)}) must match number of reward "
            f"functions ({len(reward_funcs)})"
        )

    completions, completion_texts, owners = _generate_completions(
        model,
        tokenizer,
        prompt_tokens,
        group_size,
        max_tokens,
        temperature,
        is_validation,
        batch_size,
    )
    if not completions:
        raise ValueError("No completions were generated. Please check your model and inputs.")

    # Prompt/answer text aligned one-to-one with the completions.
    expanded_prompts = [prompt_text[p] for p in owners]
    expanded_answers = [answer_text[p] for p in owners]
    # ``owners`` is non-decreasing, so each prompt's completions are contiguous.
    group_slices = _group_slices(owners)

    # Right-pad completions with token 0; ``lengths`` keeps the real sizes.
    completion_lengths = [ids.shape[0] for ids in completions]
    max_length = max(completion_lengths)
    padded_completions = []
    for ids in completions:
        pad = max_length - ids.shape[0]
        if pad > 0:
            ids = mx.concatenate([ids, mx.zeros((pad,), dtype=ids.dtype)])
        padded_completions.append(ids)
    inputs = mx.stack(padded_completions)
    lengths = mx.array(completion_lengths)

    # NOTE(review): log-probabilities are computed on the completion tokens
    # alone, without the prompt as context, so they are not pi(o | q) as in the
    # GRPO objective and the first completion token is never scored. Confirm
    # this is intended.
    token_log_probs = get_per_token_logps(model, inputs, lengths)
    mx.eval(token_log_probs)
    if ref_model is None:
        ref_token_log_probs = token_log_probs
    else:
        ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
        mx.eval(ref_token_log_probs)

    max_len = max(x.shape[0] for x in token_log_probs)

    def pad_rows(rows):
        # Zero-pad every row to ``max_len`` and stack into (num_completions, max_len).
        return mx.stack(
            [mx.concatenate([row, mx.zeros((max_len - row.shape[0],))]) for row in rows]
        )

    ref_token_log_probs = pad_rows(ref_token_log_probs)
    token_log_probs = pad_rows(token_log_probs)

    # Evaluate every reward function once; reused below for the metrics.
    func_rewards = [
        mx.array(
            reward_func(
                prompts=expanded_prompts,
                completions=completion_texts,
                answer=expanded_answers,
            )
        )
        for reward_func in reward_funcs
    ]
    if reward_weights is not None:
        weights = mx.array(reward_weights, dtype=mx.float32)
    else:
        weights = mx.ones(len(reward_funcs), dtype=mx.float32)
    # (num_completions, num_funcs) -> weighted sum per completion.
    rewards = (mx.stack(func_rewards, axis=1) * mx.expand_dims(weights, 0)).sum(axis=1)

    # Group-relative advantages; single-completion groups get advantage 0.
    advantage_parts = []
    for start, end in group_slices:
        group = rewards[start:end]
        if end - start > 1:
            advantage_parts.append((group - mx.mean(group)) / (mx.std(group) + epsilon))
        else:
            advantage_parts.append(mx.zeros_like(group))
    advantages = mx.concatenate(advantage_parts)

    # KL(policy || reference) via Schulman's k3 estimator.
    log_ratio_ref = ref_token_log_probs - token_log_probs
    kl_div = mx.exp(log_ratio_ref) - log_ratio_ref - 1

    # Valid positions: a row of length L contributes L - 1 scored tokens.
    length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)

    # Ratio against the detached current policy (pi_theta / pi_theta_old with
    # a single update per batch): equal to 1, with gradient grad(log pi).
    policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))

    per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)

    sequence_lengths = length_mask.sum(axis=1)
    # Completions with no scored token (length 1) would otherwise give 0 / 0.
    safe_lengths = mx.maximum(sequence_lengths, 1)
    loss = (per_token_loss.sum(axis=1) / safe_lengths).mean()
    mean_kl = ((kl_div * length_mask).sum(axis=1) / safe_lengths).mean()

    reward_metrics = {}
    for reward_func, values in zip(reward_funcs, func_rewards):
        func_name = reward_func.__name__
        reward_metrics[f'{func_name}_mean'] = mx.mean(values)
        reward_metrics[f'{func_name}_std'] = mx.std(values)

    # Population std of a single value is 0, so every group yields a scalar.
    grouped_rewards_mean = mx.stack([mx.mean(rewards[s:e]) for s, e in group_slices])
    grouped_rewards_std = mx.stack([mx.std(rewards[s:e]) for s, e in group_slices])

    metrics = {
        'total_rewards_mean': mx.mean(rewards),
        'total_rewards_std': mx.std(rewards),
        'grouped_rewards_mean': mx.mean(grouped_rewards_mean),
        'grouped_rewards_std': mx.mean(grouped_rewards_std),
        'kl': mean_kl,
        **reward_metrics,
    }

    if is_validation:
        _print_validation_sample(
            tokenizer,
            prompt_tokens,
            prompt_text,
            answer_text,
            owners[-1],
            completion_texts[-1],
        )

    mx.metal.clear_cache()

    return loss, sequence_lengths.sum(), metrics
|
2025-01-31 06:55:34 +08:00
|
|
|
|
|
|
|
|
2025-02-12 18:07:53 +08:00
|
|
|
def iterate_grpo_batches(dataset, batch_size, max_seq_length, train=False):
    """Yield ``(prompts_tokens, answers_tokens, prompts_text, answers_text)`` batches.

    Examples are sorted by prompt+answer token length so each batch groups
    similarly sized items. With several distributed workers, each batch keeps
    every ``world_size``-th example of its slice.

    Args:
        dataset: Sequence of ``(prompt_tokens, answer_tokens, prompt_str,
            answer_str)`` tuples.
        batch_size: Examples per batch across all workers; must be divisible
            by the number of workers.
        max_seq_length: Prompts longer than this are truncated to their first
            ``max_seq_length`` tokens, and a warning is printed.
        train: If True, cycle forever over shuffled batches; otherwise make a
            single pass in length order.

    Raises:
        ValueError: (on the first ``next()``, since this is a generator) for a
            malformed or too-small dataset or an indivisible batch size.
    """
    if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
        raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")

    if len(dataset) < batch_size:
        raise ValueError(
            f"Dataset must have at least batch_size={batch_size} "
            f"examples but only has {len(dataset)}."
        )

    step = mx.distributed.init().size()
    if batch_size % step != 0:
        raise ValueError("The batch size must be divisible by the number of workers")

    order = sorted(
        range(len(dataset)),
        key=lambda i: len(dataset[i][0]) + len(dataset[i][1]),
    )
    batches = [
        order[i : i + batch_size : step]
        for i in range(0, len(order) - batch_size + 1, batch_size)
    ]

    while True:
        indices = np.random.permutation(batches) if train else batches

        for batch_idx in indices:
            current_batch = [dataset[j] for j in batch_idx]

            prompts_tokens = [item[0] for item in current_batch]
            answers_tokens = [item[1] for item in current_batch]
            prompts_text = [item[2] for item in current_batch]
            answers_text = [item[3] for item in current_batch]

            if any(len(p) > max_seq_length for p in prompts_tokens):
                print(
                    f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
                    "Long prompts will be truncated."
                )
                # Honor the warning: keep the leading tokens, as the base
                # trainer's batch iterator does.
                prompts_tokens = [p[:max_seq_length] for p in prompts_tokens]

            yield prompts_tokens, answers_tokens, prompts_text, answers_text

        if not train:
            break
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_grpo(
    model: nn.Module,
    ref_model: Optional[nn.Module],
    dataset,
    tokenizer,
    batch_size,
    num_batches,
    beta: float,
    epsilon: float,
    group_size: int,
    max_seq_length: int,
    max_tokens: int,
    temperature: float,
    reward_funcs: Optional[List[RewardFunctions]] = None,
    loss_fn: callable = grpo_loss,
    iterate_batches: callable = iterate_grpo_batches,
):
    """Evaluate the GRPO loss and metrics over (part of) ``dataset``.

    Args:
        num_batches: Number of batches to evaluate; ``-1`` means all.
        reward_funcs: Reward functions; ``None`` selects the default r1 set.
        Other arguments are forwarded to ``iterate_batches`` / ``loss_fn``.

    Returns:
        ``(avg_loss, ntokens, avg_metrics)``. The loss and every metric are
        token-weighted averages, summed across distributed workers.

    Raises:
        ValueError: if no batch was evaluated.
    """
    if reward_funcs is None:
        reward_funcs = [
            r1_accuracy_reward_func,
            r1_int_reward_func,
            r1_strict_format_reward_func,
            r1_soft_format_reward_func,
            r1_count_xml,
        ]

    all_losses = 0
    ntokens = 0
    all_metrics = None

    # ``iter(int, 1)`` never stops (int() is always 0), i.e. evaluate everything.
    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

    for _, batch in zip(
        index_iterator,
        iterate_batches(
            dataset=dataset,
            batch_size=batch_size,
            max_seq_length=max_seq_length,
        ),
    ):
        losses, toks, metrics = loss_fn(
            model=model,
            tokenizer=tokenizer,
            batch=batch,
            reward_funcs=reward_funcs,
            beta=beta,
            group_size=group_size,
            epsilon=epsilon,
            ref_model=ref_model,
            temperature=temperature,
            max_tokens=max_tokens,
            is_validation=True,
        )

        # Weight by token count so the final values are token averages.
        all_losses += losses * toks
        ntokens += toks

        if all_metrics is None:
            all_metrics = {k: v * toks for k, v in metrics.items()}
        else:
            for k, v in metrics.items():
                all_metrics[k] += v * toks

        mx.eval(all_losses, ntokens)

    if all_metrics is None:
        raise ValueError(
            "No validation batches were evaluated; check num_batches and the dataset size."
        )

    # All distributed reductions run on the CPU stream, consistently.
    all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
    ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
    all_metrics = {
        k: mx.distributed.all_sum(v, stream=mx.cpu) for k, v in all_metrics.items()
    }

    avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
    avg_loss = (all_losses / ntokens).item()

    return avg_loss, ntokens, avg_metrics
|
2025-01-31 23:27:31 +08:00
|
|
|
|
2025-01-31 06:55:34 +08:00
|
|
|
|
2025-01-31 23:57:43 +08:00
|
|
|
def _empty_grpo_metrics(reward_funcs):
    """Return a fresh, zeroed accumulator for the metrics train_grpo reports.

    Keys are the aggregate reward/KL statistics read by the training report,
    plus a ``<name>_mean`` / ``<name>_std`` pair for every reward function.
    """
    metrics = {
        'total_rewards_mean': 0,
        'total_rewards_std': 0,
        'grouped_rewards_mean': 0,
        'grouped_rewards_std': 0,
        'kl': 0,
    }
    for reward_func in reward_funcs:
        func_name = reward_func.__name__
        metrics[f'{func_name}_mean'] = 0
        metrics[f'{func_name}_std'] = 0
    return metrics


def train_grpo(
    model: nn.Module,
    ref_model: Optional[nn.Module],
    tokenizer,
    optimizer,
    train_dataset,
    val_dataset,
    reward_funcs: Optional[List[RewardFunctions]] = None,
    args: Optional[GRPOTrainingArgs] = None,
    loss_fn: Callable = grpo_loss,
    iterate_batches: Callable = iterate_grpo_batches,
    training_callback: Optional[TrainingCallback] = None,
):
    """Run the GRPO training loop and save adapter weights.

    Runs ``args.iters`` optimizer steps over batches produced by
    ``iterate_batches(train_dataset, ...)``. Gradients are averaged across
    distributed workers before each update.

    - Validation with ``evaluate_grpo`` runs at iteration 1, every
      ``args.steps_per_eval`` iterations, and at the last iteration.
    - Training statistics are reported every ``args.steps_per_report``
      iterations and at the last iteration. Statistics are averaged over
      the steps since the previous report and over all workers.
    - Adapter weights are saved on rank 0 every ``args.steps_per_save``
      iterations. Each save writes ``args.adapter_file`` and a numbered
      checkpoint in the same directory. A final save happens at the end.

    Args:
        model: Policy model being trained.
        ref_model: Optional reference model, forwarded to ``loss_fn`` and
            ``evaluate_grpo``.
        tokenizer: Tokenizer forwarded to ``loss_fn`` and ``evaluate_grpo``.
        optimizer: MLX optimizer used to update ``model``.
        train_dataset: Dataset consumed by ``iterate_batches`` for training.
        val_dataset: Dataset consumed by ``evaluate_grpo`` for validation.
        reward_funcs: Reward functions. ``None`` selects the default r1
            accuracy/int/strict-format/soft-format/count-xml set.
        args: Training hyper-parameters. ``None`` uses
            ``GRPOTrainingArgs()`` defaults.
        loss_fn: Loss returning ``(loss, ntokens, metrics)``.
        iterate_batches: Batch generator factory.
        training_callback: Optional receiver of val/train report dicts.
    """
    # Resolve defaults at call time rather than sharing one object created
    # at function definition time.
    if reward_funcs is None:
        reward_funcs = [
            r1_accuracy_reward_func,
            r1_int_reward_func,
            r1_strict_format_reward_func,
            r1_soft_format_reward_func,
            r1_count_xml,
        ]
    if args is None:
        args = GRPOTrainingArgs()

    print(f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}")

    world = mx.distributed.init()
    world_size = world.size()
    rank = world.rank()
    if world_size > 1:
        print(f"Node {rank} of {world_size}")

    if args.grad_checkpoint:
        grad_checkpoint(model.layers[0])

    state = [model.state, optimizer.state]
    loss_value_and_grad = nn.value_and_grad(model, loss_fn)

    def step(batch):
        """Compute loss and gradients for one batch and apply an update."""
        (loss, toks, metrics), grad = loss_value_and_grad(
            model,
            tokenizer=tokenizer,
            batch=batch,
            reward_funcs=reward_funcs,
            beta=args.beta,
            group_size=args.group_size,
            epsilon=args.epsilon,
            ref_model=ref_model,
            max_tokens=args.max_completion_length,
            temperature=args.temperature
        )

        # Average gradients across distributed workers before updating.
        grad = average_gradients(grad)

        optimizer.update(model, grad)

        return loss, toks, metrics

    losses = 0
    n_tokens = 0
    steps = 0
    trained_tokens = 0
    accumulated_metrics = _empty_grpo_metrics(reward_funcs)

    start = time.perf_counter()
    for it, batch in zip(
        range(1, args.iters + 1),
        iterate_batches(
            dataset=train_dataset,
            batch_size=args.batch_size,
            max_seq_length=args.max_seq_length,
            train=True,
        ),
    ):
        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
            stop = time.perf_counter()
            val_loss, val_ntokens, val_metrics = evaluate_grpo(
                model=model,
                dataset=val_dataset,
                loss_fn=loss_fn,
                ref_model=ref_model,
                reward_funcs=reward_funcs,
                tokenizer=tokenizer,
                group_size=args.group_size,
                batch_size=args.batch_size,
                num_batches=args.val_batches,
                max_seq_length=args.max_seq_length,
                max_tokens=args.max_completion_length,
                beta=args.beta,
                epsilon=args.epsilon,
                temperature=args.temperature,
                iterate_batches=iterate_batches,
            )
            val_time = time.perf_counter() - stop
            if rank == 0:
                val_metrics_str = (
                    f"Val loss {val_loss:.3f}, "
                    f"Val total_rewards_mean {val_metrics['total_rewards_mean']:.3f}, "
                    f"Val total_rewards_std {val_metrics['total_rewards_std']:.3f}, "
                    f"Val grouped_rewards_mean {val_metrics['grouped_rewards_mean']:.3f}, "
                    f"Val grouped_rewards_std {val_metrics['grouped_rewards_std']:.3f}, "
                    f"Val kl {val_metrics['kl']:.3f}"
                )

                for reward_func in reward_funcs:
                    val_metrics_str += (
                        f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
                        f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
                    )

                print(
                    f"Iter {it}: {val_metrics_str}, "
                    f"Val took {val_time:.3f}s",
                    flush=True,
                )

            if training_callback is not None:
                training_callback.on_val_loss_report({
                    "iteration": it,
                    "val_loss": val_loss,
                    **{f"val_{k}": v for k, v in val_metrics.items()},
                    "val_time": val_time,
                })

            # Exclude validation time from training throughput.
            start = time.perf_counter()

        loss, toks, metrics = step(batch)
        losses += loss
        n_tokens += toks
        steps += 1

        # Tolerate metric keys that were not pre-seeded instead of raising
        # KeyError.
        for k, v in metrics.items():
            accumulated_metrics[k] = accumulated_metrics.get(k, 0) + v

        # Evaluate the accumulated metrics with the model/optimizer state so
        # their lazy graphs do not grow across steps.
        mx.eval(state, losses, n_tokens, accumulated_metrics)

        if it % args.steps_per_report == 0 or it == args.iters:
            stop = time.perf_counter()

            # The collectives below must run on every rank, so they stay
            # outside the rank-0-only printing.
            train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
            train_loss /= steps * world_size
            # Sum per-worker metrics before averaging. Without the all_sum,
            # each worker's value would be divided by world_size and come
            # out too small.
            avg_metrics = {
                k: (
                    mx.distributed.all_sum(mx.array(v), stream=mx.cpu)
                    / (steps * world_size)
                ).item()
                for k, v in accumulated_metrics.items()
            }
            n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
            learning_rate = optimizer.learning_rate.item()
            it_sec = args.steps_per_report / (stop - start)
            tokens_sec = float(n_tokens) / (stop - start)
            trained_tokens += n_tokens
            peak_mem = mx.metal.get_peak_memory() / 1e9

            if rank == 0:
                train_metrics_str = (
                    f"Train loss {train_loss:.3f}, "
                    f"Total rewards mean {avg_metrics['total_rewards_mean']:.3f}, "
                    f"Total rewards std {avg_metrics['total_rewards_std']:.3f}, "
                    f"Grouped rewards mean {avg_metrics['grouped_rewards_mean']:.3f}, "
                    f"Grouped rewards std {avg_metrics['grouped_rewards_std']:.3f}, "
                    f"KL {avg_metrics['kl']:.3f}"
                )

                for reward_func in reward_funcs:
                    func_name = reward_func.__name__
                    train_metrics_str += (
                        f", {func_name} mean {avg_metrics[f'{func_name}_mean']:.3f}, "
                        f"{func_name} std {avg_metrics[f'{func_name}_std']:.3f}"
                    )

                print(
                    f"Iter {it}: {train_metrics_str}, "
                    f"Learning Rate {learning_rate:.3e}, "
                    f"It/sec {it_sec:.3f}, "
                    f"Tokens/sec {tokens_sec:.3f}, "
                    f"Peak mem {peak_mem:.3f} GB",
                    flush=True,
                )

            if training_callback is not None:
                training_callback.on_train_loss_report({
                    "iteration": it,
                    "train_loss": train_loss,
                    **{f"train_{k}": v for k, v in avg_metrics.items()},
                    "learning_rate": learning_rate,
                    "iterations_per_second": it_sec,
                    "tokens_per_second": tokens_sec,
                    "trained_tokens": trained_tokens,
                    "peak_memory": peak_mem,
                })

            # Reset every per-report accumulator, metrics included. If the
            # metrics were not reset, later reports would divide sums taken
            # since the start of training by the steps of one window.
            losses = 0
            n_tokens = 0
            steps = 0
            accumulated_metrics = _empty_grpo_metrics(reward_funcs)
            start = time.perf_counter()

        # Gradients are averaged, so weights are identical on all ranks.
        # Only rank 0 writes, which avoids concurrent writes to one file.
        if it % args.steps_per_save == 0 and rank == 0:
            adapter_weights = dict(tree_flatten(model.trainable_parameters()))
            mx.save_safetensors(str(args.adapter_file), adapter_weights)
            checkpoint = (
                Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
            )
            mx.save_safetensors(str(checkpoint), adapter_weights)
            print(
                f"Iter {it}: Saved adapter weights to "
                f"{args.adapter_file} and {checkpoint}."
            )

    if rank == 0:
        adapter_weights = dict(tree_flatten(model.trainable_parameters()))
        mx.save_safetensors(str(args.adapter_file), adapter_weights)
        print(f"Saved final weights to {args.adapter_file}.")
|