commit 326935be49 (parent 2d2f39f96e)
Author: Goekdeniz-Guelmez
Date:   2025-03-05 14:40:23 +01:00


@@ -175,8 +175,6 @@ def generate_grpo(
try:
import time
model.freeze()
start_time = time.time()
if len(prompts.shape) == 1:
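Note on the hunk above: freeze() brackets generation so rollouts run with non-trainable parameters, and the matching unfreeze() appears near the end of generate_grpo further down. A minimal sketch of that pattern, assuming an mlx.nn.Module; rollout and sample_fn are illustrative names, not part of this commit:

    import mlx.nn as nn

    def rollout(model: nn.Module, sample_fn, max_tokens: int):
        model.freeze()  # exclude parameters from gradient computation while sampling
        try:
            tokens = []
            for i, tok in enumerate(sample_fn()):
                tokens.append(tok)
                if i >= max_tokens - 1:
                    break
            return tokens
        finally:
            model.unfreeze()  # restore trainability for the GRPO update
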
@@ -213,7 +211,6 @@ def generate_grpo(
sample_start_time = time.time()
current_tokens = []
prompt_cache = cache.make_prompt_cache(model)
mx.eval(current_tokens, prompt_cache)
# The generate_step function yields one token at a time
# We'll collect tokens until we hit max_tokens or a stopping condition
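make_prompt_cache builds a fresh KV cache for each sample so completions do not share attention state. A minimal sketch, assuming the mlx_lm cache helper; the import path matches current mlx_lm but should be checked against the pinned version:

    from mlx_lm.models.cache import make_prompt_cache

    # One cache per sampled completion; reusing a cache across prompts
    # would leak KV state between samples.
    prompt_cache = make_prompt_cache(model)
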
@@ -226,14 +223,13 @@ def generate_grpo(
prompt_cache=prompt_cache,
)
):
print(token)
# Check for EOS token
if token == tokenizer.eos_token_id:
break
current_tokens.append(token)
mx.eval(current_tokens[-1])
print(token)
# Check for end token
if len(current_tokens) >= len(end_sequence) and mx.array_equal(
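The stopping check above compares the tail of current_tokens against end_sequence. Isolated for clarity, with made-up token ids (the end_sequence values here are hypothetical, not from this commit):

    import mlx.core as mx

    end_sequence = mx.array([2, 3])   # hypothetical end-tag token ids
    current_tokens = [5, 7, 2, 3]

    done = len(current_tokens) >= len(end_sequence) and bool(
        mx.array_equal(mx.array(current_tokens[-len(end_sequence):]), end_sequence)
    )
    print(done)  # True: the last two tokens match the end sequence
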
@@ -245,6 +241,8 @@ def generate_grpo(
if i >= max_tokens - 1:
break
mx.eval(current_tokens)
if current_tokens:
results.append(mx.array(current_tokens))
total_tokens_generated += len(current_tokens)
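The added mx.eval(current_tokens) matters because MLX is lazy: operations only build a compute graph until something forces evaluation. A self-contained illustration:

    import mlx.core as mx

    a = mx.array([1, 2, 3])
    b = a * 2   # builds a graph node; nothing is computed yet
    mx.eval(b)  # materializes b and everything it depends on
    print(b)    # array([2, 4, 6], dtype=int32)
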
@@ -273,7 +271,6 @@ def generate_grpo(
results = [mx.stop_gradient(r) for r in results]
mx.eval(results)
model.unfreeze()
return results
except Exception as e:
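mx.stop_gradient detaches the sampled completions from the autograd graph before they reach the loss computation. A small demonstration of the effect under mx.grad:

    import mlx.core as mx

    def f(x):
        y = mx.stop_gradient(x * 2.0)  # y is treated as a constant w.r.t. x
        return mx.sum(x * y)

    g = mx.grad(f)(mx.array([1.0, 2.0]))
    print(g)  # array([2, 4], dtype=float32): only the direct x factor contributes
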
@@ -885,6 +882,8 @@ def train_grpo(
n_tokens += toks
steps += 1
mx.metal.clear_cache()
for k, v in metrics.items():
accumulated_metrics[k] += v
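The train_grpo hunk pairs a Metal cache flush with running metric sums. A sketch of the accumulation pattern, assuming metrics maps names to Python scalars; the example values are made up:

    from collections import defaultdict

    import mlx.core as mx

    accumulated_metrics = defaultdict(float)
    for metrics in ({"loss": 0.5, "kl": 0.1}, {"loss": 0.4, "kl": 0.2}):
        for k, v in metrics.items():
            accumulated_metrics[k] += v
        mx.metal.clear_cache()  # release cached GPU buffers between steps

    print(dict(accumulated_metrics))  # {'loss': 0.9, 'kl': ~0.3}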