This commit is contained in:
Goekdeniz-Guelmez 2025-02-22 02:12:02 +01:00
parent 235348c211
commit d653371e3d

View File

@ -157,7 +157,6 @@ def generate_grpo(model: nn.Module, prompts, max_tokens, tokenizer, group_size,
current_input = expanded_prompts[idx]
prompt_cache = cache.make_prompt_cache(model)
# Initial forward pass with the prompt
logits = model(current_input[None], cache=prompt_cache)[:, -1]
while len(current_tokens) < max_tokens: