diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 4d79b1ac..edf4cf69 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -374,7 +374,9 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
         max_seq_length=args.max_seq_length,
         beta=args.beta,
         group_size=args.group_size,
-        epsilon=args.epsilon
+        epsilon=args.epsilon,
+        temperature=args.temperature,
+        max_tokens=args.max_seq_length
     )

     test_ppl = math.exp(test_loss)
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index ea59ed06..f215c0ed 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -74,15 +74,14 @@ def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kw
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
     return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]

-
 def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
-    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
-    matches = [bool(re.search(pattern, r)) if r else False for r in completions]
+    has_think = r"<think>.*</think>"
+    has_answer = r"<answer>.*</answer>"
+    matches = [(bool(re.search(has_think, r)) and bool(re.search(has_answer, r))) if r else False for r in completions]
     return [0.5 if match else 0.0 for match in matches]

-
 def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
@@ -114,44 +113,95 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
     return scores


-def generate_grpo(model: nn.Module, prompt, max_tokens, tokenizer, temperature):
-    if len(prompt.shape) == 1:
-        prompt = prompt[None, :]
-    if prompt.shape[1] == 0:
+def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, temperature, group_size):
+    if len(prompts.shape) == 1:
+        prompts = prompts[None, :]
+    if prompts.shape[1] == 0:
         return None
-    end_sequence = tokenizer.encode("</answer>")
-    end_sequence_length = len(end_sequence)
-    initial_length = prompt.shape[1]
-    output = mx.zeros((initial_length + max_tokens,), dtype=mx.int32)
-    output[:initial_length] = prompt[0]
-    current_length = initial_length
+
+    start_time = time.perf_counter()
+    tokens_generated = 0
+    batch_size = prompts.shape[0] * group_size
+
+    # Repeat each prompt group_size times
+    expanded_prompts = mx.repeat(prompts, group_size, axis=0)
+
+    end_sequence = mx.array(tokenizer.encode("</answer>"))
+    end_len = len(end_sequence)
+    initial_length = prompts.shape[1]
+
+    # Initialize output tensor for all sequences
+    output = mx.zeros((batch_size, initial_length + max_tokens), dtype=mx.int32)
+    output = mx.concatenate([expanded_prompts, mx.zeros((batch_size, max_tokens), dtype=mx.int32)], axis=1)
+    current_lengths = mx.array([initial_length] * batch_size)
+
+    temp_factor = 1/temperature if temperature > 0 else float('inf')
+
     try:
-        def sample(logits):
-            if temperature > 0:
-                logits /= temperature
-            logprobs = logits - mx.logsumexp(logits, keepdims=True)
-            return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]
+        not_finished = mx.ones((batch_size,), dtype=mx.bool_)
+
         for _ in range(max_tokens):
-            current_input = output[:current_length][None, :]
-            logits = model(current_input)
-            token_logits = logits[0, -1]
-            next_token = sample(token_logits)
-            token_value = next_token.item()
-            output[current_length] = token_value
-            current_length += 1
-            if token_value == tokenizer.eos_token_id:
+            # Check if all sequences are finished
+            if not mx.sum(not_finished).item():
                 break
-            if current_length >= end_sequence_length:
-                last_tokens = output[current_length - end_sequence_length:current_length].tolist()
-                if last_tokens == end_sequence:
-                    break
-        if current_length > initial_length:
-            return output[:current_length]
+
+            # Get model outputs for all sequences
+            max_len = mx.max(current_lengths).item()
+            batch_inputs = output[:, :max_len]
+            logits = model(batch_inputs)[:, -1]
+
+            # Apply mask to logits
+            logits = logits * mx.expand_dims(not_finished, -1)
+
+            # Sample next tokens
+            logits *= temp_factor
+            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
+            next_tokens = mx.random.categorical(logprobs)
+
+            # Update outputs for active sequences
+            for idx in range(batch_size):
+                if not_finished[idx].item():
+                    curr_len = current_lengths[idx].item()
+                    token_value = next_tokens[idx].item()
+
+                    # Create new arrays with updates
+                    output = mx.array(output.tolist())  # Make a copy
+                    output[idx, curr_len] = token_value
+                    current_lengths = mx.array([
+                        l + 1 if i == idx else l
+                        for i, l in enumerate(current_lengths.tolist())
+                    ])
+                    tokens_generated += 1
+
+                    # Check end conditions
+                    if token_value == tokenizer.eos_token_id:
+                        not_finished = mx.array([
+                            False if i == idx else nf
+                            for i, nf in enumerate(not_finished.tolist())
+                        ])
+                    elif curr_len >= end_len:
+                        last_tokens = output[idx, curr_len-end_len+1:curr_len+1]
+                        if mx.array_equal(last_tokens, end_sequence):
+                            not_finished = mx.array([
+                                False if i == idx else nf
+                                for i, nf in enumerate(not_finished.tolist())
+                            ])
+
+            if _ % 32 == 0:
+                mx.eval(output, current_lengths, not_finished)
+
+        end_time = time.perf_counter()
+        generation_time = end_time - start_time
+        tokens_per_second = tokens_generated / generation_time
+        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_per_second:.2f} tokens/s)")
+
+        # Return only the valid part of each sequence
+        results = [output[i, :current_lengths[i].item()] for i in range(batch_size)]
+        return results
+
     except Exception as e:
         print(f"Generation error: {str(e)}")
         return None
-
-    return None


 def get_per_token_logps(model: nn.Module, inputs, lengths):
@@ -185,7 +235,8 @@ def grpo_loss(
     epsilon=1e-4,
     max_tokens=64,
     temperature=1.0,
-    reward_weights=None
+    reward_weights=None,
+    is_validation=False
 ):
     prompt_tokens, _, prompt_text, answer_text = batch
     batch_size = len(prompt_tokens)
@@ -195,22 +246,27 @@ def grpo_loss(

     for i in range(0, batch_size, batch_size):
         batch_prompts = prompt_tokens[i:i+batch_size]
-        for prompt in batch_prompts:
-            prompt_tensor = mx.array(prompt)
-            for _ in range(group_size):
-                try:
-                    completion_ids = generate_grpo(model, prompt_tensor, max_tokens, tokenizer, temperature)
-                    if completion_ids is not None:
-                        completion_text = tokenizer.decode(completion_ids.tolist())
-                        all_completions.append(completion_ids)
-                        all_completion_texts.append(completion_text)
-                        mx.eval(completion_ids)
-                        del completion_ids
-                except Exception as e:
-                    print(f"Generation error: {e}")
-                    continue
-
-        mx.metal.clear_cache()
+        prompt_tensor = mx.stack([mx.array(p) for p in batch_prompts])
+
+        try:
+            completions = generate_grpo(
+                model,
+                prompt_tensor,
+                max_tokens,
+                tokenizer,
+                temperature,
+                group_size
+            )
+
+            if completions is not None:
+                for completion_ids in completions:
+                    completion_text = tokenizer.decode(completion_ids.tolist())
+                    all_completions.append(completion_ids)
+                    all_completion_texts.append(completion_text)
+                    mx.eval(completion_ids)
+        except Exception as e:
print(f"Generation error: {e}") + continue expanded_answers = [] expanded_prompts = [] @@ -242,15 +298,12 @@ def grpo_loss( token_log_probs = get_per_token_logps(model, inputs, lengths) mx.eval(token_log_probs) - mx.metal.clear_cache() - - # Reference policy probabilities + if ref_model is None: ref_token_log_probs = token_log_probs else: ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths) mx.eval(ref_token_log_probs) - mx.metal.clear_cache() max_len = max(x.shape[0] for x in token_log_probs) padded_log_probs = [] @@ -339,6 +392,10 @@ def grpo_loss( 'kl': mean_kl, **reward_metrics } + + if is_validation: + print(f"\nValidation sample generation:\n{all_completion_texts[-1]}\n") + mx.metal.clear_cache() return loss, sequence_lengths.sum(), metrics @@ -412,7 +469,7 @@ def evaluate_grpo( ): all_losses = 0 ntokens = 0 - all_metrics = None + all_metrics = None index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) @@ -434,7 +491,8 @@ def evaluate_grpo( epsilon=epsilon, ref_model=ref_model, temperature=temperature, - max_tokens=max_tokens + max_tokens=max_tokens, + is_validation=True ) all_losses += losses * toks