diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 9d2dfd42..9d9051f0 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -132,21 +132,15 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
 def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size, is_training=False, end_token: str = "", temperature: float = 0.8):
-    if model.training == False:
-        print("Model is in training mode", model.training, "Manually setting to eval mode")
-        model.train()
-
     if len(prompts.shape) == 1:
         prompts = prompts[None, :]
     if prompts.shape[1] == 0:
         return None
-
     batch_size = prompts.shape[0] * group_size
     expanded_prompts = mx.repeat(prompts, group_size, axis=0)
     end_sequence = mx.array(tokenizer.encode(end_token))
-    results = []
-
+    mx.eval(expanded_prompts)
     try:
         for idx in range(batch_size):
             current_tokens = []
@@ -154,8 +148,8 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
             if is_training:
                 current_input = expanded_prompts[idx]
                 prompt_cache = cache.make_prompt_cache(model)
-
                 logits = model(current_input[None], cache=prompt_cache)[:, -1]
+                mx.eval(logits, prompt_cache)
 
                 while len(current_tokens) < max_tokens:
                     logits_temp = logits / temperature
@@ -169,16 +163,16 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
                             mx.array(test_sequence[-len(end_sequence):]),
                             end_sequence
                         )):
+                        current_tokens.append(token)
                         break
                     if token == tokenizer.eos_token_id:
                         break
-
+                    current_tokens.append(token)
                     current_input = mx.array([token])
                     logits = model(current_input[None], cache=prompt_cache)[:, -1]
-                    mx.eval(current_input)
-                    mx.metal.clear_cache()
+                    mx.eval(current_input, logits, probs, next_token)
             else:
                 generator = generate_step(
                     expanded_prompts[idx],