minor speed improvement

Goekdeniz-Guelmez 2025-03-05 13:55:24 +01:00
parent 3dfb21267b
commit 2bde97fe13


@@ -129,7 +129,7 @@ def generate_step(
         logprobs = logits - mx.logsumexp(logits, keepdims=True)
         y = sampler(logprobs)
-        return y, logprobs.squeeze(0)
+        return mx.stop_gradient(y), mx.stop_gradient(logprobs.squeeze(0))

     with mx.stream(generation_stream):
         total_prompt_tokens = y.size
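
Wrapping the sampled token and its log-probabilities in mx.stop_gradient detaches
them from MLX's autograd graph, so the generation loop no longer carries the
sampling computation's history from step to step. A minimal sketch of the idea
(the greedy argmax sampler here is a stand-in, not the project's sampler):

    import mlx.core as mx

    def sample_detached(logits: mx.array):
        # Normalize logits to log-probabilities, as in generate_step
        logprobs = logits - mx.logsumexp(logits, keepdims=True)
        y = mx.argmax(logprobs, axis=-1)  # greedy stand-in sampler
        # stop_gradient marks both outputs as constants for autodiff,
        # so nothing upstream is retained for backpropagation
        return mx.stop_gradient(y), mx.stop_gradient(logprobs.squeeze(0))

    token, logprobs = sample_detached(mx.random.normal((1, 32)))
    mx.eval(token, logprobs)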
@@ -186,7 +186,7 @@ def generate_grpo(
     expanded_prompts = mx.repeat(prompts, group_size, axis=0)
     end_sequence = mx.array(tokenizer.encode(end_token))
     results = []
-    mx.eval(expanded_prompts)
+    mx.eval(expanded_prompts, results)

     print(f"Setup time: {time.time() - start_time:.2f}s")
     print(f"Generating {total_samples} samples with max_tokens={max_tokens}")
@@ -211,6 +211,7 @@ def generate_grpo(
         sample_start_time = time.time()
         current_tokens = []
         prompt_cache = cache.make_prompt_cache(model)
+        mx.eval(current_tokens, prompt_cache)

         # The generate_step function yields one token at a time
         # We'll collect tokens until we hit max_tokens or a stopping condition
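
Evaluating the freshly built prompt cache up front moves its allocation out of
the per-token loop. For context, this is roughly how such a cache is created
with mlx-lm (the model id below is only an example):

    from mlx_lm import load
    from mlx_lm.models import cache

    # Any local mlx-lm model works; this id is illustrative
    model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")

    # One KV-cache entry per transformer layer; generate_step fills it while
    # decoding so earlier tokens are not re-processed on every step
    prompt_cache = cache.make_prompt_cache(model)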
@@ -316,6 +317,7 @@ def grpo_loss(
     # Store original training state
     was_training = model.training
+    print(f"Was model in training mode: {was_training}")

     # Set model to eval mode for generation
     model.eval()
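
The new print surfaces the train/eval toggle that grpo_loss performs around
generation. The pattern being inspected looks roughly like this (a sketch
assuming an mlx.nn.Module, not the actual function body):

    import mlx.nn as nn

    def run_generation(model: nn.Module, is_validation: bool = False):
        was_training = model.training   # store original training state
        model.eval()                    # generation should not use dropout etc.
        try:
            pass  # ... generate completions here ...
        finally:
            # restore training state unless we are validating
            if not is_validation and was_training:
                model.train()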
@@ -352,9 +354,8 @@ def grpo_loss(
         for j, completion_ids in enumerate(completions):
             # Calculate which prompt this completion belongs to
             prompt_idx = i + (j // group_size)
-            if (
-                prompt_idx < total_samples
-            ):  # Make sure we don't go out of bounds
+            if prompt_idx < total_samples:  # Make sure we don't go out of bounds
                 batch_indices.append(prompt_idx)
                 completion_text = tokenizer.decode(completion_ids.tolist())
                 all_completions.append(completion_ids)
@@ -367,7 +368,6 @@ def grpo_loss(
     # Restore original training state if we're not in validation mode
     if not is_validation and was_training:
         model.train()
-    mx.metal.clear_cache()

     # If we didn't generate any completions, return early
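
Dropping the mx.metal.clear_cache() call is the likely source of the speedup in
this hunk: clearing MLX's Metal buffer cache on every loss computation forces the
next step to re-allocate those buffers from scratch, whereas leaving the cache
warm lets allocations be reused.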
@@ -376,7 +376,6 @@ def grpo_loss(
             "No completions were generated. Please check your model and inputs."
         )

-    # The rest of the function remains the same
     # Create expanded prompts and answers based on actual generated completions
     expanded_answers = []
     expanded_prompts = []