From 710bc1490e7f9907679627e2348af3ade182ccbb Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Fri, 21 Feb 2025 22:42:15 +0100
Subject: [PATCH] training mode working too got from 2 toks/sec to 30 toks/sec
 with raw 1.5B model

---
 llms/mlx_lm/tuner/grpo_trainer.py | 98 ++++++++++++++++++++++---------
 1 file changed, 70 insertions(+), 28 deletions(-)

diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index d8ab0fa5..3e581d13 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -112,51 +112,90 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
     return scores
 
 
-def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size):
+def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False):
     if len(prompts.shape) == 1:
         prompts = prompts[None, :]
     if prompts.shape[1] == 0:
         return None
-    model.eval()
     batch_size = prompts.shape[0] * group_size
     expanded_prompts = mx.repeat(prompts, group_size, axis=0)
+    mx.eval(expanded_prompts)
     results = []
     tokens_generated = 0
     start_time = time.perf_counter()
 
-    try:
-        for idx in range(batch_size):
-            current_tokens = []
-            generator = generate_step(
-                expanded_prompts[idx],
-                model,
-                max_tokens=max_tokens,
-                sampler=lambda x: mx.argmax(x, axis=-1)
-            )
+    for idx in range(batch_size):
+        current_prompt = expanded_prompts[idx:idx+1]
+        mx.eval(current_prompt)
+
+        current_tokens = []
+        try:
+            if is_training:
+                # Initialize with prompt
+                current_input = current_prompt[0]
+                mx.eval(current_input)
+
+                while len(current_tokens) < max_tokens:
+                    # Generate one token at a time
+                    logits = model(current_input[None])
+                    next_token = mx.random.categorical(logits[:, -1, :])
+                    token = next_token.item()
+                    current_tokens.append(token)
+                    tokens_generated += 1
+
+                    # Clear intermediate results
+                    mx.eval(next_token)
+                    del logits
+
+                    if token == tokenizer.eos_token_id:
+                        break
+
+                    # Update input for next iteration
+                    current_input = mx.array([token])
+                    mx.eval(current_input)
+
+                    # Clear cache periodically
+                    if len(current_tokens) % 8 == 0:
+                        mx.metal.clear_cache()
+            else:
+                generator = generate_step(
+                    current_prompt[0],
+                    model,
+                    max_tokens=max_tokens,
+                    sampler=lambda x: mx.random.categorical(x)
+                )
+
+                for token, _ in generator:
+                    current_tokens.append(token)
+                    tokens_generated += 1
+                    if token == tokenizer.eos_token_id:
+                        break
 
-            # Collect all tokens first
-            for tokens, _ in generator:
-                current_tokens.append(tokens)
-                tokens_generated += 1
-                if tokens == tokenizer.eos_token_id:
-                    break
+            if current_tokens:
+                token_array = mx.array(current_tokens)
+                mx.eval(token_array)
+                results.append(token_array)
+                del token_array
 
-            # Convert to array after collection
-            results.append(mx.array(current_tokens))
-            mx.metal.clear_cache()
+        except Exception as e:
+            print(f"Generation failed for sequence {idx}: {e}")
+            continue
 
-        # Final evaluation of all results
-        mx.eval(results)
-        generation_time = time.perf_counter() - start_time
-        print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
-        return results
+        mx.metal.clear_cache()
 
-    except Exception as e:
-        print(f"Generation error: {str(e)}")
+    if not results:
+        print("No successful generations")
         return None
 
+    mx.eval(results)
+
+    generation_time = time.perf_counter() - start_time
+    print(f"Generated {tokens_generated} tokens in {generation_time:.2f}s ({tokens_generated/generation_time:.2f} tokens/s)")
+
+    return results
+
 
 def get_per_token_logps(model: nn.Module, inputs, lengths):
     logits = model(inputs).astype(mx.float16)
@@ -209,7 +248,8 @@ def grpo_loss(
                 prompt_tensor,
                 max_tokens,
                 tokenizer,
-                group_size
+                group_size,
+                True
             )
 
             if completions is not None:
@@ -221,6 +261,8 @@ def grpo_loss(
         except Exception as e:
             print(f"Generation error: {e}")
             continue
+
+        mx.metal.clear_cache()
 
     expanded_answers = []
     expanded_prompts = []
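
Usage sketch (illustrative, not part of the patch): the new is_training flag makes
generate_grpo sample group_size completions per prompt one token at a time with
mx.random.categorical, evaluating each step eagerly and clearing the Metal cache
every 8 tokens to keep memory bounded during training. Assuming this patched branch
is installed as mlx_lm, and using a hypothetical model path and prompt, it could be
exercised like this:

    import mlx.core as mx
    from mlx_lm import load
    from mlx_lm.tuner.grpo_trainer import generate_grpo

    # Hypothetical checkpoint; any mlx_lm-compatible model should behave the same.
    model, tokenizer = load("Qwen/Qwen2.5-1.5B-Instruct")

    # Encode one prompt and add a batch dimension (generate_grpo also accepts a
    # 1-D array and expands it itself).
    prompt_ids = mx.array(tokenizer.encode("What is 2 + 2?"))[None, :]

    # is_training=True takes the manual per-token sampling branch added in this
    # patch; is_training=False (the default) keeps the generate_step path.
    completions = generate_grpo(
        model,
        prompt_ids,
        max_tokens=64,
        tokenizer=tokenizer,
        group_size=4,
        is_training=True,
    )

    if completions is not None:
        for ids in completions:
            print(tokenizer.decode(ids.tolist()))

grpo_loss passes True for this flag (second hunk above), so generation during
training always goes through the per-token loop rather than generate_step.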