diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 28aa3420..a5eb8ffe 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -43,6 +43,7 @@ yaml_loader.add_implicit_resolver(
 CONFIG_DEFAULTS = {
     "model": "mlx_model",
     "train": False,
+    "training_mode": "normal",
     "fine_tune_type": "lora",
     "data": "data/",
     "seed": 0,
@@ -62,6 +63,10 @@ CONFIG_DEFAULTS = {
     "config": None,
     "grad_checkpoint": False,
     "lr_schedule": None,
+    "reference_model_path": None,
+    "group_size": 4,
+    "beta": 0.1,
+    "epsilon": 1e-4,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
 }
 
@@ -95,6 +100,12 @@ def build_parser():
         choices=["lora", "dora", "full"],
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
+    parser.add_argument(
+        "--training-mode",
+        type=str,
+        choices=["normal", "grpo"],
+        help="Training mode: normal or GRPO",
+    )
     parser.add_argument(
         "--num-layers",
         type=int,
@@ -162,6 +173,25 @@
         default=None,
     )
     parser.add_argument("--seed", type=int, help="The PRNG seed")
+
+    parser.add_argument(
+        "--group-size",
+        type=int,
+        help="Number of responses per prompt.",
+        default=4,
+    )
+    parser.add_argument(
+        "--beta",
+        type=float,
+        help="KL penalty coefficient.",
+        default=0.1,
+    )
+    parser.add_argument(
+        "--epsilon",
+        type=float,
+        help="The Epsilon for numerical stability.",
+        default=1e-4,
+    )
     return parser
 
 
@@ -221,32 +251,98 @@ def train_model(
         )
     )
     # Train model
-    train(
-        model=model,
-        tokenizer=tokenizer,
-        args=training_args,
-        optimizer=opt,
-        train_dataset=train_set,
-        val_dataset=valid_set,
-        training_callback=training_callback,
-    )
+    if args.training_mode == "grpo":
+        training_args = GRPOTrainingArgs(
+            batch_size=args.batch_size,
+            iters=args.iters,
+            val_batches=args.val_batches,
+            steps_per_report=args.steps_per_report,
+            steps_per_eval=args.steps_per_eval,
+            steps_per_save=args.save_every,
+            adapter_file=adapter_file,
+            max_seq_length=args.max_seq_length,
+            grad_checkpoint=args.grad_checkpoint,
+            beta=args.beta,
+            group_size=args.group_size,
+            epsilon=args.epsilon,
+            reference_model_path=args.reference_model_path
+        )
+
+        if args.reference_model_path:
+            reference_model, _ = load(args.reference_model_path)
+        else:
+            reference_model, _ = load(args.model)
+
+        train_grpo(
+            model=model,
+            reference_model=reference_model.freeze(),
+            tokenizer=tokenizer,
+            optimizer=opt,
+            train_dataset=train_set,
+            val_dataset=valid_set,
+            args=training_args,
+            training_callback=training_callback,
+        )
+    else:
+        training_args = TrainingArgs(
+            batch_size=args.batch_size,
+            iters=args.iters,
+            val_batches=args.val_batches,
+            steps_per_report=args.steps_per_report,
+            steps_per_eval=args.steps_per_eval,
+            steps_per_save=args.save_every,
+            adapter_file=adapter_file,
+            max_seq_length=args.max_seq_length,
+            grad_checkpoint=args.grad_checkpoint
+        )
+
+        train(
+            model=model,
+            tokenizer=tokenizer,
+            args=training_args,
+            optimizer=opt,
+            train_dataset=train_set,
+            val_dataset=valid_set,
+            training_callback=training_callback,
+        )
 
 
 def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
     model.eval()
 
-    test_loss = evaluate(
-        model=model,
-        dataset=test_set,
-        tokenizer=tokenizer,
-        batch_size=args.batch_size,
-        num_batches=args.test_batches,
-        max_seq_length=args.max_seq_length,
-    )
+    if args.training_mode == "grpo":
+        if args.reference_model_path:
+            reference_model, _ = load(args.reference_model_path)
+        else:
+            reference_model = model
 
-    test_ppl = math.exp(test_loss)
+        test_loss, test_rewards = evaluate_grpo(
+            model=model,
+            reference_model=reference_model,
+            dataset=test_set,
+            tokenizer=tokenizer,
+            batch_size=args.batch_size,
+            num_batches=args.test_batches,
+            max_seq_length=args.max_seq_length,
+            beta=args.beta,
+            group_size=args.group_size,
+            epsilon=args.epsilon,
+            reference_model_path=args.reference_model_path
+        )
+        print(f"Test loss {test_loss:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
+    else:
+        test_loss = evaluate(
+            model=model,
+            dataset=test_set,
+            tokenizer=tokenizer,
+            batch_size=args.batch_size,
+            num_batches=args.test_batches,
+            max_seq_length=args.max_seq_length,
+        )
 
-    print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
+        test_ppl = math.exp(test_loss)
+
+        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
 
 
 def run(args, training_callback: TrainingCallback = None):
@@ -297,4 +393,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 720533b1..c3b3007e 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -22,13 +22,7 @@ generate()
 class GRPOTrainingArgs(TrainingArgs):
     group_size: int = field(
         default=4,
-        metadata={"help": "Number of response sper prompt."},
-    )
-    is_reference_free: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether to use reference-free DPO training."
-        }
+        metadata={"help": "Number of responses per prompt."},
     )
     beta: float = field(
         default=0.1, metadata={"help": "KL penalty coefficient."}