# Copyright © 2024 Apple Inc.

import argparse
import math
import os
import re
import types
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import yaml

from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import (
    build_schedule,
    linear_to_lora_layers,
    load_adapters,
    print_trainable_parameters,
)
from .utils import load, save_config

# Teach the YAML loader to recognize scientific-notation floats (e.g. 1e-5) as floats.
yaml_loader = yaml.SafeLoader
yaml_loader.add_implicit_resolver(
    "tag:yaml.org,2002:float",
    re.compile(
        """^(?:
     [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
    |\\.[0-9_]+(?:[eE][-+][0-9]+)?
    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
    |[-+]?\\.(?:inf|Inf|INF)
    |\\.(?:nan|NaN|NAN))$""",
        re.X,
    ),
    list("-+0123456789."),
)

CONFIG_DEFAULTS = {
    "model": "mlx_model",
    "train": False,
    "training_mode": "normal",
    "fine_tune_type": "lora",
    "optimizer": "adam",
    "optimizer_config": {
        "adam": {},
        "adamw": {},
    },
    "data": "data/",
    "seed": 0,
    "num_layers": 16,
    "batch_size": 4,
    "iters": 1000,
    "val_batches": 25,
    "learning_rate": 1e-5,
    "steps_per_report": 10,
    "steps_per_eval": 200,
    "resume_adapter_file": None,
    "adapter_path": "adapters",
    "save_every": 100,
    "test": False,
    "test_batches": 500,
    "max_seq_length": 2048,
    "config": None,
    "grad_checkpoint": False,
    "lr_schedule": None,
    "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
    "mask_prompt": False,
    # GRPO args
    "reference_model_path": None,
    "group_size": 4,
    "beta": 0.1,
    "epsilon": 1e-4,
    "max_completion_length": 512,
    "use_chat_template": False,
    "use_prompt": False,
    "temperature": 1.0,
    "reward_weights": None,
}
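
# Example of a YAML config that can be passed with `--config` (illustrative values,
# not recommendations); keys mirror CONFIG_DEFAULTS above, and any option given on
# the command line takes precedence over the file:
#
#   model: "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
#   train: true
#   fine_tune_type: "lora"
#   data: "data/"
#   num_layers: 16
#   batch_size: 4
#   iters: 1000
#   learning_rate: 1e-5
#   lora_parameters:
#     rank: 8
#     alpha: 16
#     dropout: 0.0
#     scale: 10.0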


def build_parser():
    parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
    parser.add_argument(
        "--model",
        type=str,
        help="The path to the local model directory or Hugging Face repo.",
    )

    # Training args
    parser.add_argument(
        "--train",
        action="store_true",
        help="Do training",
        default=None,
    )
    parser.add_argument(
        "--data",
        type=str,
        help=(
            "Directory with {train, valid, test}.jsonl files or the name "
            "of a Hugging Face dataset (e.g., 'mlx-community/wikisql')"
        ),
    )
    parser.add_argument(
        "--fine-tune-type",
        type=str,
        choices=["lora", "dora", "full"],
        help="Type of fine-tuning to perform: lora, dora, or full.",
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        choices=["adam", "adamw"],
        default=None,
        help="Optimizer to use for training: adam or adamw",
    )
    parser.add_argument(
        "--mask-prompt",
        action="store_true",
        help="Mask the prompt in the loss when training",
        default=None,
    )
    parser.add_argument(
        "--training-mode",
        type=str,
        choices=["normal", "grpo"],
        help="Training mode: normal or GRPO",
    )
    parser.add_argument(
        "--num-layers",
        type=int,
        help="Number of layers to fine-tune. Default is 16, use -1 for all.",
    )
    parser.add_argument("--batch-size", type=int, help="Minibatch size.")
    parser.add_argument("--iters", type=int, help="Iterations to train for.")
    parser.add_argument(
        "--val-batches",
        type=int,
        help="Number of validation batches, -1 uses the entire validation set.",
    )
    parser.add_argument("--learning-rate", type=float, help="Adam learning rate.")
    parser.add_argument(
        "--steps-per-report",
        type=int,
        help="Number of training steps between loss reporting.",
    )
    parser.add_argument(
        "--steps-per-eval",
        type=int,
        help="Number of training steps between validations.",
    )
    parser.add_argument(
        "--resume-adapter-file",
        type=str,
        help="Load path to resume training from the given fine-tuned weights.",
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Save/load path for the fine-tuned weights.",
    )
    parser.add_argument(
        "--save-every",
        type=int,
        help="Save the model every N iterations.",
    )
    parser.add_argument(
        "--test",
        action="store_true",
        help="Evaluate on the test set after training",
        default=None,
    )
    parser.add_argument(
        "--test-batches",
        type=int,
        help="Number of test set batches, -1 uses the entire test set.",
    )
    parser.add_argument(
        "--max-seq-length",
        type=int,
        help="Maximum sequence length.",
    )
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        help="A YAML configuration file with the training options",
    )
    parser.add_argument(
        "--grad-checkpoint",
        action="store_true",
        help="Use gradient checkpointing to reduce memory use.",
        default=None,
    )
    parser.add_argument("--seed", type=int, help="The PRNG seed")

    # GRPO args
    parser.add_argument(
        "--group-size",
        type=int,
        help="Number of completions generated per prompt.",
        default=4,
    )
    parser.add_argument(
        "--max-completion-length",
        type=int,
        help="Maximum number of tokens to generate for each completion.",
        default=512,
    )
    parser.add_argument(
        "--beta",
        type=float,
        help="KL penalty coefficient.",
        default=0.1,
    )
    parser.add_argument(
        "--epsilon",
        type=float,
        help="Epsilon for numerical stability.",
        default=1e-4,
    )
    parser.add_argument(
        "--use-chat-template",
        action="store_true",
        help="If the model is a chat model, use the chat template.",
        default=None,
    )
    parser.add_argument(
        "--use-prompt",
        action="store_true",
        help="Whether to use the system prompt from the R1 paper.",
        default=None,
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="Temperature for sampling. The higher the temperature, the more random the completions.",
        default=1.0,
    )
    parser.add_argument(
        "--reward-weights",
        type=str,
        help=(
            "Weights for each reward function, e.g. '[0.1, 0.2, 0.3]'. Must match the "
            "number of reward functions. If not given, all rewards are weighted "
            "equally with weight `1.0`."
        ),
        default=None,
    )
    return parser
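
# Example invocations (illustrative; assumes this module is exposed as `mlx_lm.lora`):
#
#   python -m mlx_lm.lora --model <path-or-hf-repo> --train --data data/ --iters 1000
#   python -m mlx_lm.lora --model <path-or-hf-repo> --train --training-mode grpo \
#       --group-size 4 --beta 0.1 --max-completion-length 512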


def train_model_grpo(
    model, tokenizer, args, opt, train_set, valid_set, adapter_file, training_callback
):
    training_args = GRPOTrainingArgs(
        batch_size=args.batch_size,
        iters=args.iters,
        val_batches=args.val_batches,
        steps_per_report=args.steps_per_report,
        steps_per_eval=args.steps_per_eval,
        steps_per_save=args.save_every,
        adapter_file=adapter_file,
        max_seq_length=args.max_seq_length,
        max_completion_length=args.max_completion_length,
        grad_checkpoint=args.grad_checkpoint,
        beta=args.beta,
        group_size=args.group_size,
        epsilon=args.epsilon,
        reference_model_path=args.reference_model_path,
        temperature=args.temperature,
        reward_weights=(
            [float(x) for x in args.reward_weights.strip("[]").split(",")]
            if args.reward_weights
            else None
        ),
    )

    if args.reference_model_path:
        reference_model, _ = load(args.reference_model_path)
    else:
        reference_model, _ = load(args.model)

    train_grpo(
        model=model,
        ref_model=reference_model.freeze(),
        tokenizer=tokenizer,
        optimizer=opt,
        train_dataset=train_set,
        val_dataset=valid_set,
        args=training_args,
        training_callback=training_callback,
    )


def train_model(
    args,
    model: nn.Module,
    tokenizer: TokenizerWrapper,
    train_set,
    valid_set,
    training_callback: TrainingCallback = None,
):
    model.freeze()
    if args.num_layers > len(model.layers):
        raise ValueError(
            f"Requested to train {args.num_layers} layers "
            f"but the model only has {len(model.layers)} layers."
        )

    if args.fine_tune_type == "full":
        for l in model.layers[-max(args.num_layers, 0) :]:
            l.unfreeze()
    elif args.fine_tune_type in ["lora", "dora"]:
        # Convert linear layers to lora/dora layers and unfreeze in the process
        linear_to_lora_layers(
            model,
            args.num_layers,
            args.lora_parameters,
            use_dora=(args.fine_tune_type == "dora"),
        )
    else:
        raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")

    # Resume from weights if provided
    if args.resume_adapter_file is not None:
        print(f"Loading fine-tuned weights from {args.resume_adapter_file}")
        model.load_weights(args.resume_adapter_file, strict=False)

    print_trainable_parameters(model)

    adapter_path = Path(args.adapter_path)
    adapter_path.mkdir(parents=True, exist_ok=True)

    adapter_file = adapter_path / "adapters.safetensors"
    save_config(vars(args), adapter_path / "adapter_config.json")

    model.train()

    # Initialize the selected optimizer
    lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate

    optimizer_name = args.optimizer.lower()
    optimizer_config = args.optimizer_config.get(optimizer_name, {})

    if optimizer_name == "adam":
        opt_class = optim.Adam
    elif optimizer_name == "adamw":
        opt_class = optim.AdamW
    else:
        raise ValueError(f"Unsupported optimizer: {optimizer_name}")

    opt = opt_class(learning_rate=lr, **optimizer_config)

    # Train model
    if args.training_mode == "grpo":
        train_model_grpo(
            model,
            tokenizer,
            args,
            opt,
            train_set,
            valid_set,
            adapter_file,
            training_callback,
        )
    else:
        training_args = TrainingArgs(
            batch_size=args.batch_size,
            iters=args.iters,
            val_batches=args.val_batches,
            steps_per_report=args.steps_per_report,
            steps_per_eval=args.steps_per_eval,
            steps_per_save=args.save_every,
            adapter_file=adapter_file,
            max_seq_length=args.max_seq_length,
            grad_checkpoint=args.grad_checkpoint,
        )

        train(
            model=model,
            tokenizer=tokenizer,
            args=training_args,
            optimizer=opt,
            train_dataset=train_set,
            val_dataset=valid_set,
            training_callback=training_callback,
        )


def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
    model.eval()

    if args.training_mode == "grpo":
        if args.reference_model_path:
            reference_model, _ = load(args.reference_model_path)
        else:
            reference_model, _ = load(args.model)

        test_loss, _, test_rewards = evaluate_grpo(
            model=model,
            ref_model=reference_model.freeze(),
            dataset=test_set,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            num_batches=args.test_batches,
            max_seq_length=args.max_seq_length,
            beta=args.beta,
            group_size=args.group_size,
            epsilon=args.epsilon,
            temperature=args.temperature,
            max_tokens=args.max_seq_length,
        )

        test_ppl = math.exp(test_loss)
        rewards_str = ", ".join([f"{k}: {v:.3f}" for k, v in test_rewards.items()])
        print(
            f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}, Rewards: {rewards_str}"
        )
    else:
        test_loss = evaluate(
            model=model,
            dataset=test_set,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            num_batches=args.test_batches,
            max_seq_length=args.max_seq_length,
        )

        test_ppl = math.exp(test_loss)
        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")


def run(args, training_callback: TrainingCallback = None):
    np.random.seed(args.seed)

    print("Loading pretrained model")
    model, tokenizer = load(args.model)

    print("Loading datasets")
    train_set, valid_set, test_set = load_dataset(args, tokenizer)

    if args.test and not args.train:
        # Allow testing without LoRA layers by providing empty path
        if args.adapter_path != "":
            load_adapters(model, args.adapter_path)
    elif args.train:
        print("Training")
        train_model(args, model, tokenizer, train_set, valid_set, training_callback)
    else:
        raise ValueError("Must provide at least one of --train or --test")

    if args.test:
        print("Testing")
        evaluate_model(args, model, tokenizer, test_set)


def main():
    os.environ["TOKENIZERS_PARALLELISM"] = "true"
    parser = build_parser()
    args = parser.parse_args()
    config = args.config
    args = vars(args)
    if config:
        print("Loading configuration file", config)
        with open(config, "r") as file:
            config = yaml.load(file, yaml_loader)
        # Prefer parameters from command-line arguments
        for k, v in config.items():
            if args.get(k, None) is None:
                args[k] = v

    # Update defaults for unspecified parameters
    for k, v in CONFIG_DEFAULTS.items():
        if args.get(k, None) is None:
            args[k] = v
    run(types.SimpleNamespace(**args))


if __name__ == "__main__":
    main()


def compute_grpo_loss_and_grad(
    model,
    ref_model,
    tokenizer,
    completion_tensors,
    prompt_texts,
    answer_texts,
    beta=0.1,
    epsilon=1e-4,
    reward_funcs=None,
    reward_weights=None,
):
    """
    Compute the GRPO loss and per-sequence metrics from pre-generated completions.

    Args:
        model: The policy model.
        ref_model: The reference model, or None to reuse the policy's detached log-probs.
        tokenizer: Tokenizer used to decode completions for the reward functions.
        completion_tensors: List of token-id tensors containing generated completions.
        prompt_texts: List of prompt texts.
        answer_texts: List of answer texts.
        beta: KL penalty coefficient.
        epsilon: Numerical stability constant.
        reward_funcs: List of reward functions.
        reward_weights: Optional weights for the reward functions.
    """
    # Ensure model is in training mode for gradient computation
    model.train()

    # Get completion texts for reward calculation
    completion_texts = [tokenizer.decode(comp.tolist()) for comp in completion_tensors]

    # Pad completions to a common length and build attention masks
    max_length = max(tensor.shape[0] for tensor in completion_tensors)
    padded_completions = []
    attention_masks = []
    for completion_tensor in completion_tensors:
        padding_length = max_length - completion_tensor.shape[0]
        if padding_length > 0:
            padding = mx.zeros((padding_length,), dtype=completion_tensor.dtype)
            padded_ids = mx.concatenate([completion_tensor, padding])
            mask = mx.concatenate(
                [mx.ones_like(completion_tensor), mx.zeros_like(padding)]
            )
        else:
            padded_ids = completion_tensor
            mask = mx.ones_like(completion_tensor)
        padded_completions.append(padded_ids)
        attention_masks.append(mask)

    inputs = mx.stack(padded_completions)
    attention_mask = mx.stack(attention_masks)
    lengths = attention_mask.sum(axis=1)

    # Compute per-token log probabilities for both models.
    # `get_per_token_logps` is assumed to be provided by the GRPO trainer module.
    token_log_probs = get_per_token_logps(model, inputs, lengths)
    if ref_model is None:
        ref_token_log_probs = [mx.stop_gradient(tlp) for tlp in token_log_probs]
    else:
        ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
        ref_token_log_probs = [mx.stop_gradient(tlp) for tlp in ref_token_log_probs]

    # Pad log probabilities to the same length
    max_len = max(x.shape[0] for x in token_log_probs)
    padded_log_probs = []
    padded_ref_log_probs = []
    for i in range(len(token_log_probs)):
        seq_len = token_log_probs[i].shape[0]
        padding = mx.zeros((max_len - seq_len,))
        padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
        padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))

    token_log_probs = mx.stack(padded_log_probs)
    ref_token_log_probs = mx.stack(padded_ref_log_probs)

    # Calculate rewards
    all_func_rewards = []
    for reward_func in reward_funcs:
        func_rewards = mx.array(
            reward_func(
                prompts=prompt_texts,
                completions=completion_texts,
                answer=answer_texts,
            )
        )
        all_func_rewards.append(func_rewards)

    # Stack rewards and apply weights
    rewards = mx.stack(all_func_rewards, axis=1)
    if reward_weights is not None:
        if len(reward_weights) != len(reward_funcs):
            raise ValueError(
                f"Number of reward weights ({len(reward_weights)}) must match number of reward "
                f"functions ({len(reward_funcs)})"
            )
        reward_weights = mx.array(reward_weights, dtype=mx.float32)
    else:
        reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
    rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)

    # Group rewards by prompt (completions are assumed to be grouped by prompt)
    group_size = len(completion_tensors) // len(prompt_texts)
    if len(completion_tensors) % len(prompt_texts) != 0:
        raise ValueError("Number of completions must be divisible by number of prompts")
    rewards_by_group = []
    for i in range(0, len(rewards), group_size):
        rewards_by_group.append(rewards[i : i + group_size])

    # Calculate advantages by normalizing rewards within each group
    advantages = mx.zeros_like(rewards)
    for i, group_rewards in enumerate(rewards_by_group):
        if len(group_rewards) > 1:  # Only normalize if we have multiple samples
            mean_reward = mx.mean(group_rewards)
            std_reward = mx.std(group_rewards)
            for j in range(group_size):
                idx = i * group_size + j
                advantages[idx] = (group_rewards[j] - mean_reward) / (std_reward + epsilon)
        else:
            # If only one sample, the advantage is 0
            advantages[i * group_size] = 0.0

    # Compute KL divergence between the policy and the reference model
    kl_div = (
        mx.exp(ref_token_log_probs - token_log_probs)
        - (ref_token_log_probs - token_log_probs)
        - 1
    )

    # Create mask for valid tokens
    length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)

    # Compute policy ratio
    policy_ratio = mx.exp(token_log_probs - ref_token_log_probs)

    # Compute per-token loss
    per_token_loss = -(
        (policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
    )

    # Average over tokens
    sequence_sums = per_token_loss.sum(axis=1)
    sequence_lengths = length_mask.sum(axis=1)
    loss = (sequence_sums / sequence_lengths).mean()

    # Calculate metrics for reporting
    mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()
    metrics = {
        "total_rewards_mean": mx.mean(rewards),
        "total_rewards_std": mx.std(rewards),
        "kl": mean_kl,
    }
    for i, reward_func in enumerate(reward_funcs):
        func_name = reward_func.__name__
        func_rewards = all_func_rewards[i]
        metrics[f"{func_name}_mean"] = mx.mean(func_rewards)
        metrics[f"{func_name}_std"] = mx.std(func_rewards)

    return loss, sequence_lengths.sum(), metrics
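
# Illustrative usage sketch (assumes `completion_tensors`, `prompt_texts`,
# `answer_texts`, and `reward_funcs` were prepared by the caller, e.g. by a
# generation step in the GRPO trainer):
#
#   loss_value_and_grad = nn.value_and_grad(model, compute_grpo_loss_and_grad)
#   (loss, num_tokens, metrics), grads = loss_value_and_grad(
#       model,
#       ref_model,
#       tokenizer,
#       completion_tensors,
#       prompt_texts,
#       answer_texts,
#       beta=0.1,
#       epsilon=1e-4,
#       reward_funcs=reward_funcs,
#   )
#   opt.update(model, grads)
#   mx.eval(model.parameters(), opt.state)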