diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 4bb39832..4d79b1ac 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -1,21 +1,21 @@
# Copyright © 2024 Apple Inc.
+from pathlib import Path
import argparse
+import types
import math
import os
import re
-import types
-from pathlib import Path
-import mlx.nn as nn
import mlx.optimizers as optim
+import mlx.nn as nn
import numpy as np
import yaml
+from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
+from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
-from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
-from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
from .tuner.utils import (
build_schedule,
linear_to_lora_layers,
@@ -73,6 +73,8 @@ CONFIG_DEFAULTS = {
"max_completion_length": 512,
"use_chat_template": False,
"use_prompt": False,
+ "temperature": 1.0,
+ "reward_weights": None,
}
@@ -224,6 +226,18 @@ def build_parser():
help="Rather to use the prompt from the R1 paper.",
default=None,
)
+ parser.add_argument(
+ "--temperature",
+ type=float,
+ help="Temperature for sampling. The higher the temperature, the more random the completions.",
+ default=1.0,
+ )
+ parser.add_argument(
+ "--reward-weights",
+ type=str,
+ help="Weights for each reward function. Must match the number of reward functions and be in this format [0.1, 0.2, 0.3, 0.4, 0.5]. If not given, all rewards are weighted equally with weight `1.0`.",
+ default=None,
+ )
return parser
def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_file, training_callback):
@@ -241,7 +255,9 @@ def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_
beta=args.beta,
group_size=args.group_size,
epsilon=args.epsilon,
- reference_model_path=args.reference_model_path
+ reference_model_path=args.reference_model_path,
+ temperature=args.temperature,
+ reward_weights=[float(x) for x in args.reward_weights.strip('[]').split(',')] if args.reward_weights else None
)
if args.reference_model_path:
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index d0fa5fae..e96b8f29 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -35,13 +35,24 @@ class GRPOTrainingArgs(TrainingArgs):
"help": "Path to reference model weights. If None, uses the same model."
}
)
+ temperature: float = field(
+ default=1.0,
+ metadata={
+ "help": "Temperature for sampling. The higher the temperature, the more random the completions."
+ }
+ )
+ reward_weights: Optional[List[float]] = field(
+ default=None,
+ metadata={
+ "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are weighted equally with weight `1.0`."
+ }
+ )
RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
def r1_extract_xml_answer(text: str) -> str:
- """Extracts the answer from an XML formatted text string."""
try:
answer = text.split("")[-1]
answer = answer.split("")[0]
@@ -52,14 +63,12 @@ def r1_extract_xml_answer(text: str) -> str:
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
- """Ensures we always return a list of floats."""
if not completions:
return [0.0] * len(prompts)
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
- """Ensures we always return a list of floats."""
if not completions or not answer:
return [0.0] * len(prompts)
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
@@ -67,7 +76,6 @@ def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kw
def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
- """Ensures we always return a list of floats."""
if not completions:
return [0.0] * len(prompts)
pattern = r".*?\s*.*?"
@@ -76,7 +84,6 @@ def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, *
def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
- """Ensures we always return a list of floats."""
if not completions:
return [0.0] * len(prompts)
pattern = r"^\n.*?\n\n\n.*?\n\n$"
@@ -85,16 +92,13 @@ def r1_strict_format_reward_func(prompts: list, completions: list, answer: list,
def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
- """Ensures we always return a list of floats."""
if not completions:
return [0.0] * len(prompts)
-
scores = []
for text in completions:
if not text:
scores.append(0.0)
continue
-
count = 0.0
if text.count("\n") == 1:
count += 0.125
@@ -104,13 +108,9 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
count += 0.125
if text.count("\n\n") == 1:
count += 0.125
-
- # Penalize extra text after
end_text = text.split("\n\n")[-1]
count -= len(end_text) * 0.001 if len(end_text) > 0 else 0
-
- scores.append(max(0.0, count)) # Ensure non-negative score
-
+ scores.append(max(0.0, count))
return scores
@@ -119,22 +119,18 @@ def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
prompt = prompt[None, :]
if prompt.shape[1] == 0:
return None
-
end_sequence = tokenizer.encode("")
end_sequence_length = len(end_sequence)
-
initial_length = prompt.shape[1]
output = mx.zeros((initial_length + max_tokens,), dtype=mx.int32)
output[:initial_length] = prompt[0]
current_length = initial_length
-
try:
def sample(logits):
if temperature > 0:
logits /= temperature
logprobs = logits - mx.logsumexp(logits, keepdims=True)
return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]
-
for _ in range(max_tokens):
current_input = output[:current_length][None, :]
logits = model(current_input)
@@ -143,18 +139,14 @@ def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
token_value = next_token.item()
output[current_length] = token_value
current_length += 1
-
if token_value == tokenizer.eos_token_id:
break
-
if current_length >= end_sequence_length:
last_tokens = output[current_length - end_sequence_length:current_length].tolist()
if last_tokens == end_sequence:
break
-
if current_length > initial_length:
return output[:current_length]
-
except Exception as e:
print(f"Generation error: {str(e)}")
return None
@@ -192,9 +184,10 @@ def grpo_loss(
group_size=4,
epsilon=1e-4,
max_tokens=64,
- temperature=1.0
+ temperature=1.0,
+ reward_weights=None
):
- prompt_tokens, answer_tokens, prompt_text, answer_text = batch
+ prompt_tokens, _, prompt_text, answer_text = batch
batch_size = len(prompt_tokens)
all_completions = []
@@ -273,18 +266,34 @@ def grpo_loss(
token_log_probs = mx.stack(padded_log_probs)
ref_token_log_probs = mx.stack(padded_ref_log_probs)
- # Rewards and advantages
- rewards = mx.zeros((len(all_completions),))
+ # Create array to store rewards from each function
+ all_func_rewards = []
+
+ # Collect rewards from each function separately
for reward_func in reward_funcs:
func_rewards = mx.array(reward_func(
prompts=expanded_prompts,
completions=all_completion_texts,
answer=expanded_answers
))
- rewards += func_rewards
+ all_func_rewards.append(func_rewards)
- if len(reward_funcs) > 1:
- rewards /= len(reward_funcs)
+ # Stack rewards to shape (num_samples, num_funcs)
+ rewards = mx.stack(all_func_rewards, axis=1)
+ print(f"Rewards: {rewards}")
+
+ # Apply weights and sum
+ if reward_weights is not None:
+ if len(reward_weights) != len(reward_funcs):
+ raise ValueError(
+ f"Number of reward weights ({len(reward_weights)}) must match number of reward "
+ f"functions ({len(reward_funcs)})"
+ )
+ reward_weights = mx.array(reward_weights, dtype=mx.float32)
+ else:
+ reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
+ rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
+ print(f"Rewards after weights: {rewards}")
# Reshape rewards and compute advantages
rewards_reshaped = rewards.reshape(batch_size, group_size)
@@ -397,15 +406,11 @@ def evaluate_grpo(
epsilon: float,
group_size: int,
max_seq_length,
+ temperature: float,
reward_funcs: Optional[List[RewardFunctions]] = None,
loss_fn: callable = grpo_loss,
iterate_batches: callable = iterate_grpo_batches
):
- """
- Evaluate model using GRPO loss.
- Returns:
- tuple: (average loss, number of tokens, average metrics)
- """
all_losses = 0
ntokens = 0
all_metrics = None
@@ -428,7 +433,8 @@ def evaluate_grpo(
beta=beta,
group_size=group_size,
epsilon=epsilon,
- ref_model=ref_model
+ ref_model=ref_model,
+ temperature=temperature
)
all_losses += losses * toks
@@ -442,12 +448,10 @@ def evaluate_grpo(
mx.eval(all_losses, ntokens)
- # Aggregate across distributed workers
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
- # Calculate averages
avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
avg_loss = (all_losses / ntokens).item()
@@ -486,8 +490,6 @@ def train_grpo(
state = [model.state, optimizer.state]
def step(batch):
-
- # Forward and backward pass
(loss, toks, metrics), grad = loss_value_and_grad(
model,
tokenizer=tokenizer,
@@ -498,12 +500,11 @@ def train_grpo(
epsilon=args.epsilon,
ref_model=ref_model,
max_tokens=args.max_completion_length,
+ temperature=args.temperature
)
- # All reduce the gradients if running in distributed mode
grad = average_gradients(grad)
- # Model update
optimizer.update(model, grad)
return loss, toks, metrics
@@ -536,8 +537,6 @@ def train_grpo(
train=True,
),
):
- # Report validation loss if needed, the first validation loss
- # is always measured before any training.
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
stop = time.perf_counter()
val_loss, val_ntokens, val_metrics = evaluate_grpo(
@@ -553,6 +552,7 @@ def train_grpo(
max_seq_length=args.max_seq_length,
beta=args.beta,
epsilon=args.epsilon,
+ temperature=args.temperature,
iterate_batches=iterate_batches,
)
val_time = time.perf_counter() - stop
@@ -566,7 +566,6 @@ def train_grpo(
f"Val kl {val_metrics['kl']:.3f}"
)
- # Add reward function specific metrics
for i, reward_func in enumerate(reward_funcs):
val_metrics_str += (
f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
@@ -622,7 +621,6 @@ def train_grpo(
f"KL {avg_metrics['kl']:.3f}"
)
- # Add reward function specific metrics
for i, reward_func in enumerate(reward_funcs):
func_name = reward_func.__name__
train_metrics_str += (
@@ -656,7 +654,6 @@ def train_grpo(
steps = 0
start = time.perf_counter()
- # Save adapter weights
if it % args.steps_per_save == 0:
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
@@ -669,7 +666,6 @@ def train_grpo(
f"{args.adapter_file} and {checkpoint}."
)
- # Save final weights
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
print(f"Saved final weights to {args.adapter_file}.")
\ No newline at end of file