Mirror of https://github.com/ml-explore/mlx-examples.git
minor speed improvement

Commit: 2bde97fe13
Parent: 3dfb21267b
@@ -129,7 +129,7 @@ def generate_step(
         logprobs = logits - mx.logsumexp(logits, keepdims=True)
         y = sampler(logprobs)
-        return y, logprobs.squeeze(0)
+        return mx.stop_gradient(y), mx.stop_gradient(logprobs.squeeze(0))

     with mx.stream(generation_stream):
         total_prompt_tokens = y.size
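Why the stop_gradient change can help (a hedged reading, not a statement of the repository's rationale): generation is inference-only, so detaching the sampled token and its log-probabilities lets MLX drop the per-step computation graph instead of carrying it through the decoding loop. A minimal sketch of the pattern, using a hypothetical greedy sampler and random logits as stand-ins for the real ones:

import mlx.core as mx

def greedy_sampler(logprobs):
    # Stand-in sampler: pick the most likely token.
    return mx.argmax(logprobs, axis=-1)

logits = mx.random.normal((1, 32000))  # one decoding step over a 32k vocabulary
logprobs = logits - mx.logsumexp(logits, keepdims=True)
y = greedy_sampler(logprobs)

# Detach both outputs so later operations on them are not traced for gradients,
# keeping per-token state small during long generations.
y, logprobs = mx.stop_gradient(y), mx.stop_gradient(logprobs.squeeze(0))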
@@ -186,7 +186,7 @@ def generate_grpo(
     expanded_prompts = mx.repeat(prompts, group_size, axis=0)
     end_sequence = mx.array(tokenizer.encode(end_token))
     results = []
-    mx.eval(expanded_prompts)
+    mx.eval(expanded_prompts, results)

     print(f"Setup time: {time.time() - start_time:.2f}s")
     print(f"Generating {total_samples} samples with max_tokens={max_tokens}")
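Reading the second hunk with the same hedging: MLX evaluates lazily, and mx.eval accepts several (possibly nested) arguments, so the setup state can be materialized in one pass before the timing print. A small self-contained sketch with made-up token ids:

import mlx.core as mx

prompts = mx.array([[1, 2, 3], [4, 5, 6]])  # hypothetical prompt token ids
group_size = 4

# Each prompt is repeated group_size times so every prompt gets a group of completions.
expanded_prompts = mx.repeat(prompts, group_size, axis=0)

# Force the lazy computation now; mx.eval also walks lists, so the (initially
# empty) results list can be passed alongside the arrays.
results = []
mx.eval(expanded_prompts, results)
print(expanded_prompts.shape)  # (8, 3)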
@@ -211,6 +211,7 @@ def generate_grpo(
         sample_start_time = time.time()
         current_tokens = []
         prompt_cache = cache.make_prompt_cache(model)
+        mx.eval(current_tokens, prompt_cache)

         # The generate_step function yields one token at a time
         # We'll collect tokens until we hit max_tokens or a stopping condition
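The added mx.eval in the third hunk, again as a hedged sketch: creating a fresh prompt cache per sample is itself lazy work, and evaluating it up front keeps that cost out of the token-by-token loop that follows. Assuming model is an already loaded mlx_lm model (the helper below is illustrative, not the repository's code):

import mlx.core as mx
from mlx_lm.models import cache

def start_sample(model):
    # Per-sample state: the collected tokens plus a fresh KV cache.
    current_tokens = []
    prompt_cache = cache.make_prompt_cache(model)
    # Evaluate eagerly so the lazy setup graph does not grow into the
    # generation loop that consumes tokens one at a time.
    mx.eval(current_tokens, prompt_cache)
    return current_tokens, prompt_cache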
@@ -316,6 +317,7 @@ def grpo_loss(

     # Store original training state
     was_training = model.training
+    print(f"Was model in training mode: {was_training}")

     # Set model to eval mode for generation
     model.eval()
@@ -352,9 +354,8 @@ def grpo_loss(
         for j, completion_ids in enumerate(completions):
             # Calculate which prompt this completion belongs to
             prompt_idx = i + (j // group_size)
-            if (
-                prompt_idx < total_samples
-            ):  # Make sure we don't go out of bounds
+
+            if prompt_idx < total_samples:  # Make sure we don't go out of bounds
                 batch_indices.append(prompt_idx)
                 completion_text = tokenizer.decode(completion_ids.tolist())
                 all_completions.append(completion_ids)
@@ -367,7 +368,6 @@ def grpo_loss(
     # Restore original training state if we're not in validation mode
     if not is_validation and was_training:
         model.train()
-
     mx.metal.clear_cache()

     # If we didn't generate any completions, return early
@@ -376,7 +376,6 @@ def grpo_loss(
             "No completions were generated. Please check your model and inputs."
         )

-    # The rest of the function remains the same
     # Create expanded prompts and answers based on actual generated completions
     expanded_answers = []
     expanded_prompts = []