From baeb9f117f0107c1064e059cbc91df9979145ade Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Fri, 14 Feb 2025 09:09:59 +0100 Subject: [PATCH] reduncancy fix + nits --- llms/mlx_lm/lora.py | 86 +++++++++++++++++------------------ llms/mlx_lm/tuner/datasets.py | 9 +--- 2 files changed, 44 insertions(+), 51 deletions(-) diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 4751123c..4bb39832 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -226,6 +226,39 @@ def build_parser(): ) return parser +def train_model_grpo(model, tokenizer, args, opt, train_set, valid_set, adapter_file, training_callback): + training_args = GRPOTrainingArgs( + batch_size=args.batch_size, + iters=args.iters, + val_batches=args.val_batches, + steps_per_report=args.steps_per_report, + steps_per_eval=args.steps_per_eval, + steps_per_save=args.save_every, + adapter_file=adapter_file, + max_seq_length=args.max_seq_length, + max_completion_length=args.max_completion_length, + grad_checkpoint=args.grad_checkpoint, + beta=args.beta, + group_size=args.group_size, + epsilon=args.epsilon, + reference_model_path=args.reference_model_path + ) + + if args.reference_model_path: + reference_model, _ = load(args.reference_model_path) + else: + reference_model, _ = load(args.model) + + train_grpo( + model=model, + ref_model=reference_model.freeze(), + tokenizer=tokenizer, + optimizer=opt, + train_dataset=train_set, + val_dataset=valid_set, + args=training_args, + training_callback=training_callback, + ) def train_model( args, @@ -263,19 +296,6 @@ def train_model( adapter_file = adapter_path / "adapters.safetensors" save_config(vars(args), adapter_path / "adapter_config.json") - # init training args - training_args = TrainingArgs( - batch_size=args.batch_size, - iters=args.iters, - val_batches=args.val_batches, - steps_per_report=args.steps_per_report, - steps_per_eval=args.steps_per_eval, - steps_per_save=args.save_every, - adapter_file=adapter_file, - max_seq_length=args.max_seq_length, - grad_checkpoint=args.grad_checkpoint, - ) - model.train() opt = optim.Adam( learning_rate=( @@ -285,37 +305,15 @@ def train_model( # Train model if args.training_mode == "grpo": - training_args = GRPOTrainingArgs( - batch_size=args.batch_size, - iters=args.iters, - val_batches=args.val_batches, - steps_per_report=args.steps_per_report, - steps_per_eval=args.steps_per_eval, - steps_per_save=args.save_every, - adapter_file=adapter_file, - max_seq_length=args.max_seq_length, - max_completion_length=args.max_completion_length, - grad_checkpoint=args.grad_checkpoint, - beta=args.beta, - group_size=args.group_size, - epsilon=args.epsilon, - reference_model_path=args.reference_model_path - ) - - if args.reference_model_path: - reference_model, _ = load(args.reference_model_path) - else: - reference_model, _ = load(args.model) - - train_grpo( - model=model, - ref_model=reference_model.freeze(), - tokenizer=tokenizer, - optimizer=opt, - train_dataset=train_set, - val_dataset=valid_set, - args=training_args, - training_callback=training_callback, + train_model_grpo( + model, + tokenizer, + args, + opt, + train_set, + valid_set, + adapter_file, + training_callback ) else: training_args = TrainingArgs( diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index df5c4588..678c4f41 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -29,9 +29,7 @@ class GRPODataset: if use_chat_template: prompt_tokens = tokenizer.apply_chat_template( [ - {'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. - The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. - The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here ."""}, + {'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here ."""}, {'role': 'user', 'content': prompt_str} ], add_generation_prompt=True @@ -39,10 +37,7 @@ class GRPODataset: answer_tokens = tokenizer.encode(answer_str) else: if use_prompt: - prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. - The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. - The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . - User: {prompt_str} Assistant: """) + prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within and tags, respectively, i.e., reasoning process here answer here . User: {prompt_str} Assistant: """) else: prompt_tokens = tokenizer.encode(prompt_str) answer_tokens = tokenizer.encode(answer_str)