From 717384028366d5020e3aee29a380c1bef9affa2f Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Tue, 4 Feb 2025 09:18:45 +0100
Subject: [PATCH] first successful training run

---
 llms/mlx_lm/lora.py               | 30 ++++++++++--
 llms/mlx_lm/tuner/datasets.py     | 25 ++++++++--
 llms/mlx_lm/tuner/grpo_trainer.py | 79 ++++++++-----------------
 3 files changed, 68 insertions(+), 66 deletions(-)

diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 1f684d27..1e4fe3d5 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -63,11 +63,16 @@ CONFIG_DEFAULTS = {
     "config": None,
     "grad_checkpoint": False,
     "lr_schedule": None,
+    "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
+
+    # GRPO args
     "reference_model_path": None,
     "group_size": 4,
     "beta": 0.1,
     "epsilon": 1e-4,
-    "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
+    "max_completion_length": 512,
+    "use_chat_template": False,
+    "use_prompt": False,
 }
@@ -178,9 +183,15 @@ def build_parser():
     parser.add_argument(
         "--group-size",
         type=int,
-        help="Number of responses per prompt.",
+        help="Number of generations per prompt.",
         default=4,
     )
+    parser.add_argument(
+        "--max-completion-length",
+        type=int,
+        help="Maximum number of tokens to generate per completion.",
+        default=512,
+    )
     parser.add_argument(
         "--beta",
         type=float,
@@ -193,6 +204,18 @@
         help="The Epsilon for numerical stability.",
         default=1e-4,
     )
+    parser.add_argument(
+        "--use-chat-template",
+        type=bool,
+        help="If the model is a chat model, use the chat template.",
+        default=False,
+    )
+    parser.add_argument(
+        "--use-prompt",
+        type=bool,
+        help="Whether to use the prompt from the R1 paper.",
+        default=False,
+    )
     return parser
@@ -262,6 +285,7 @@ def train_model(
         steps_per_save=args.save_every,
         adapter_file=adapter_file,
         max_seq_length=args.max_seq_length,
+        max_completion_length=args.max_completion_length,
         grad_checkpoint=args.grad_checkpoint,
         beta=args.beta,
         group_size=args.group_size,
@@ -273,7 +297,7 @@ def train_model(
         reference_model, _ = load(args.reference_model_path)
         reference_model = reference_model.freeze()
     else:
-        reference_model, _ = None, None
+        reference_model, _ = load(args.model)
 
     train_grpo(
         model=model,
diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 163b1be7..b31656c6 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -16,14 +16,33 @@ class GRPODataset:
         data: List[Dict[str, str]],
         tokenizer: PreTrainedTokenizer,
         prompt_key: str = "prompt",
-        answer_key: str = "answer"
+        answer_key: str = "answer",
+        use_chat_template: bool = False,
+        use_prompt: bool = False
     ):
         self._data = []
         for item in data:
             prompt_str = str(item[prompt_key])
             answer_str = str(item[answer_key])
-            prompt_tokens = tokenizer.encode(prompt_str)
-            answer_tokens = tokenizer.encode(answer_str)
+            if use_chat_template:
+                prompt_tokens = tokenizer.apply_chat_template(
+                    [
+                        {'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
+                        The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
+                        The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>."""},
+                        {'role': 'user', 'content': prompt_str}
+                    ],
+                )
+                answer_tokens = tokenizer.encode(answer_str)
+            else:
+                if use_prompt:
+                    prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant.
+                    The user asks a question, and the Assistant solves it.
+                    The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
+                    The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.
+                    User: {prompt_str}. Assistant: """)
+                else:
+                    prompt_tokens = tokenizer.encode(prompt_str)
+                answer_tokens = tokenizer.encode(answer_str)
             self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
 
     def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 29518d8f..3b5dc2e9 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -12,8 +12,6 @@ from mlx.utils import tree_flatten
 
 from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
 
-from mlx_lm.utils import generate_step
-
 
 @dataclass
 class GRPOTrainingArgs(TrainingArgs):
@@ -27,6 +25,9 @@ class GRPOTrainingArgs(TrainingArgs):
     epsilon: float = field(
         default=1e-4, metadata={"help": "The Epsilon for numerical stability."}
     )
+    max_completion_length: int = field(
+        default=512, metadata={"help": "Maximum number of tokens to generate per completion."}
+    )
     reference_model_path: str = field(
         default=None,
         metadata={
@@ -36,7 +37,6 @@
 
 
 def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
-    model.eval()
     if len(prompt.shape) == 1:
         prompt = prompt[None, :]
@@ -58,11 +58,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
         token_value = next_token.item()
         generated.append(next_token)
-
-        # Clear intermediate tensors
-        del logits, token_logits, probs
-        mx.metal.clear_cache()
-
+        current_prompt = mx.concatenate([current_prompt, next_token[None]])
 
         if token_value == tokenizer.eos_token_id:
             break
@@ -72,12 +68,6 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
     result = mx.concatenate([prompt[0], mx.stack(generated)])
     mx.eval(result)
-    model.train()
-
-    # Clear generated tokens
-    del generated
-    mx.metal.clear_cache()
-
     return result
@@ -192,11 +182,6 @@ def get_per_token_logps(model, inputs, lengths):
         ).squeeze(-1)  # [seq_len]
         per_token_logps.append(token_log_probs)
-
-        # Clean up intermediates
-        del seq_logits, seq_targets, log_probs, token_log_probs
-        mx.metal.clear_cache()
-
     return per_token_logps
@@ -232,15 +217,9 @@ def grpo_loss(
             all_completions.append(completion_ids)
             all_completion_texts.append(completion_text)
 
-            del completion_ids
-            mx.metal.clear_cache()
-
         except Exception as e:
             print(f"Generation error: {e}")
             continue
-
-        del prompt_tensor
-        mx.metal.clear_cache()
 
     # Prepare inputs
     expanded_answers = []
@@ -264,25 +243,11 @@ def grpo_loss(
             mask = mx.ones_like(completion_ids)
         padded_completions.append(padded_ids)
         attention_masks.append(mask)
-
-        del completion_ids
-        if padding_length > 0:
-            del padding
-        del mask
-        mx.metal.clear_cache()
 
     inputs = mx.stack(padded_completions)
     attention_mask = mx.stack(attention_masks)
     lengths = attention_mask.sum(axis=1)
-    del padded_completions, attention_masks
-    mx.metal.clear_cache()
-
-    # Get logits and compute log probabilities
-    logits = model(inputs).astype(mx.float32)
-    log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)
-    targets = inputs[:, 1:]
-
     # Current policy probabilities
     token_log_probs = get_per_token_logps(model, inputs, lengths)
@@ -302,9 +267,6 @@ def grpo_loss(
             padded_log_probs.append(mx.concatenate([token_log_probs[i],
padding])) padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding])) - - del padding - mx.metal.clear_cache() token_log_probs = mx.stack(padded_log_probs) ref_token_log_probs = mx.stack(padded_ref_log_probs) @@ -360,10 +322,6 @@ def grpo_loss( reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards) reward_metrics[f'{func_name}_std'] = mx.std(func_rewards) - # Clean up - del all_completions - mx.metal.clear_cache() - metrics = { 'total_rewards_mean': mx.mean(rewards), 'total_rewards_std': mx.std(rewards), @@ -440,7 +398,7 @@ def evaluate_grpo( group_size: int, max_seq_length, reward_funcs = None, - loss: callable = grpo_loss, + loss_fn: callable = grpo_loss, iterate_batches: callable = iterate_grpo_batches ): """ @@ -466,7 +424,7 @@ def evaluate_grpo( ), ): # Calculate loss for current batch - losses, toks, metrics = loss( + losses, toks, metrics = loss_fn( model=model, tokenizer=tokenizer, batch=batch, @@ -518,7 +476,7 @@ def train_grpo( r1_count_xml ], args: GRPOTrainingArgs = GRPOTrainingArgs(), - loss: callable = grpo_loss, + loss_fn: callable = grpo_loss, iterate_batches: callable = iterate_grpo_batches, training_callback: TrainingCallback = None, ): @@ -546,7 +504,7 @@ def train_grpo( group_size=args.group_size, epsilon=args.epsilon, ref_model=ref_model, - max_tokens=args.max_seq_length, + max_tokens=args.max_completion_length, ) # All reduce the gradients if running in distributed mode @@ -557,22 +515,23 @@ def train_grpo( return loss, toks, metrics - loss_value_and_grad = nn.value_and_grad(model, loss) + loss_value_and_grad = nn.value_and_grad(model, loss_fn) losses = 0 n_tokens = 0 steps = 0 trained_tokens = 0 accumulated_metrics = { - 'rewards': 0, - 'rewards_std': 0, - 'grouped_rewards': 0, + 'total_rewards_mean': 0, + 'total_rewards_std': 0, + 'grouped_rewards_mean': 0, 'grouped_rewards_std': 0, 'kl': 0 } - for i in range(len(reward_funcs)): - accumulated_metrics[f'reward_func_{i}_mean'] = 0 - accumulated_metrics[f'reward_func_{i}_std'] = 0 + for reward_func in reward_funcs: + func_name = reward_func.__name__ + accumulated_metrics[f'{func_name}_mean'] = 0 + accumulated_metrics[f'{func_name}_std'] = 0 start = time.perf_counter() for it, batch in zip( @@ -592,7 +551,7 @@ def train_grpo( val_loss, val_ntokens, val_metrics = evaluate_grpo( model=model, dataset=val_dataset, - loss=loss, + loss_fn=loss_fn, ref_model=ref_model, reward_funcs=reward_funcs, tokenizer=tokenizer, @@ -675,8 +634,8 @@ def train_grpo( for i, reward_func in enumerate(reward_funcs): func_name = reward_func.__name__ train_metrics_str += ( - f", Reward func {reward_func.__name__} mean {avg_metrics[f'reward_func_{reward_func.__name__}_mean']:.3f}, " - f"Reward func {reward_func.__name__} std {avg_metrics[f'reward_func_{reward_func.__name__}_std']:.3f}" + f", {func_name} mean {avg_metrics[f'{func_name}_mean']:.3f}, " + f"{func_name} std {avg_metrics[f'{func_name}_std']:.3f}" ) print(
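
Below is an illustrative usage sketch (not part of the patch) showing how the new GRPODataset options added above might be exercised; the model name and the sample record are placeholders, not values taken from this PR.

# usage_sketch.py -- assumes this branch of mlx_lm is installed
from mlx_lm import load
from mlx_lm.tuner.datasets import GRPODataset

# Placeholder model id; any chat model supported by mlx_lm.load would do.
model, tokenizer = load("Qwen/Qwen2.5-0.5B-Instruct")

# Minimal prompt/answer record in the format GRPODataset expects.
data = [{"prompt": "What is 2 + 2?", "answer": "4"}]

# use_chat_template=True wraps each prompt in the R1-style system prompt via the
# tokenizer's chat template; use_prompt=True would instead prepend it as raw text.
dataset = GRPODataset(data, tokenizer, use_chat_template=True)

prompt_tokens, answer_tokens, prompt_str, answer_str = dataset[0]
print(len(prompt_tokens), answer_str)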