Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-26 18:51:18 +08:00
minor speed improvement

commit 2bde97fe13
parent 3dfb21267b
@@ -129,7 +129,7 @@ def generate_step(
             logprobs = logits - mx.logsumexp(logits, keepdims=True)
             y = sampler(logprobs)
-            return y, logprobs.squeeze(0)
+            return mx.stop_gradient(y), mx.stop_gradient(logprobs.squeeze(0))

     with mx.stream(generation_stream):
         total_prompt_tokens = y.size
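Why this helps: mx.stop_gradient detaches the sampled token and its log-probabilities from MLX's computation graph, so the graph built for earlier decode steps can be freed instead of being kept alive across the whole generation loop. A minimal sketch of the pattern, using a hypothetical greedy sampler and random logits rather than the repo's own sampler:

import mlx.core as mx

def sample_step(logits, sampler):
    # Same normalization as in generate_step: logits -> log-probabilities.
    logprobs = logits - mx.logsumexp(logits, keepdims=True)
    y = sampler(logprobs)
    # stop_gradient detaches both outputs from the autodiff graph, so no
    # gradient state is retained across decode steps during generation.
    return mx.stop_gradient(y), mx.stop_gradient(logprobs.squeeze(0))

# Hypothetical usage with a greedy sampler and random logits:
greedy = lambda lp: mx.argmax(lp, axis=-1)
token, logprobs = sample_step(mx.random.normal((1, 32)), greedy)
mx.eval(token, logprobs)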
@@ -186,7 +186,7 @@ def generate_grpo(
     expanded_prompts = mx.repeat(prompts, group_size, axis=0)
     end_sequence = mx.array(tokenizer.encode(end_token))
     results = []
-    mx.eval(expanded_prompts)
+    mx.eval(expanded_prompts, results)

     print(f"Setup time: {time.time() - start_time:.2f}s")
     print(f"Generating {total_samples} samples with max_tokens={max_tokens}")
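MLX arrays are lazy: they are only computed when something forces them. Calling mx.eval here materializes the pending graph up front, so the repeated prompts are ready before the per-sample loop starts; mx.eval accepts trees of arrays, so passing the still-empty results list alongside is harmless. A small self-contained sketch of the same setup step, with made-up shapes:

import time
import mlx.core as mx

prompts = mx.random.randint(0, 1000, (2, 16))  # stand-in for tokenized prompts
group_size = 4
start_time = time.time()

expanded_prompts = mx.repeat(prompts, group_size, axis=0)  # one row per sample
results = []
mx.eval(expanded_prompts, results)  # force the lazy graph now, not inside the loop

print(f"Setup time: {time.time() - start_time:.2f}s")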
@@ -211,6 +211,7 @@ def generate_grpo(
         sample_start_time = time.time()
         current_tokens = []
         prompt_cache = cache.make_prompt_cache(model)
+        mx.eval(current_tokens, prompt_cache)

         # The generate_step function yields one token at a time
         # We'll collect tokens until we hit max_tokens or a stopping condition
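The added mx.eval applies the same idea inside the per-sample loop: evaluate whatever is pending in the fresh per-sample state before decoding begins, rather than lazily on the first token. A hedged sketch of that setup, assuming the mlx_lm cache helpers and an illustrative checkpoint loaded with mlx_lm.load (the trainer's own model setup may differ):

import time
import mlx.core as mx
from mlx_lm import load
from mlx_lm.models import cache

# Illustrative model; any mlx_lm-compatible checkpoint works the same way.
model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")

sample_start_time = time.time()
current_tokens = []                            # decoded token ids for this sample
prompt_cache = cache.make_prompt_cache(model)  # fresh KV cache per sample
# Mirrors the added line in the diff: evaluate any pending arrays in the
# per-sample state before the token-by-token generation loop starts.
mx.eval(current_tokens, prompt_cache)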
@@ -316,6 +317,7 @@ def grpo_loss(

     # Store original training state
     was_training = model.training
+    print(f"Was model in training mode: {was_training}")

     # Set model to eval mode for generation
     model.eval()
@@ -352,9 +354,8 @@ def grpo_loss(
         for j, completion_ids in enumerate(completions):
             # Calculate which prompt this completion belongs to
             prompt_idx = i + (j // group_size)
-            if (
-                prompt_idx < total_samples
-            ): # Make sure we don't go out of bounds
+            if prompt_idx < total_samples: # Make sure we don't go out of bounds
                 batch_indices.append(prompt_idx)
                 completion_text = tokenizer.decode(completion_ids.tolist())
                 all_completions.append(completion_ids)
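Context for the simplified condition: each prompt yields group_size completions in order, so completion j in a chunk that starts at prompt i maps back to prompt i + (j // group_size), and the bounds check only matters for a final, partially filled chunk. A small standalone sketch of the mapping with made-up sizes:

group_size = 4
total_samples = 6   # pretend only 6 prompts exist overall
i = 4               # index of the first prompt covered by this chunk

completions = [f"completion_{j}" for j in range(2 * group_size)]  # dummy data
batch_indices = []
for j, completion in enumerate(completions):
    prompt_idx = i + (j // group_size)
    if prompt_idx < total_samples:  # make sure we don't go out of bounds
        batch_indices.append(prompt_idx)

print(batch_indices)  # -> [4, 4, 4, 4, 5, 5, 5, 5]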
@@ -367,7 +368,6 @@ def grpo_loss(
     # Restore original training state if we're not in validation mode
     if not is_validation and was_training:
         model.train()

     mx.metal.clear_cache()

     # If we didn't generate any completions, return early
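The surrounding pattern: generation runs with the model in eval mode, the original training mode is restored afterwards unless this is a validation pass, and the Metal buffer cache is cleared to release memory held by the generation graphs. A compact sketch of that pattern, with generate_completions standing in for the actual generation code:

import mlx.core as mx

def run_generation(model, generate_completions, is_validation=False):
    was_training = model.training  # remember the current mode
    model.eval()                   # generate without training-time behavior
    try:
        completions = generate_completions(model)
    finally:
        # Restore original training state if we're not in validation mode
        if not is_validation and was_training:
            model.train()
        mx.metal.clear_cache()     # free cached Metal buffers used during generation
    return completions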
@@ -376,7 +376,6 @@ def grpo_loss(
             "No completions were generated. Please check your model and inputs."
         )

-    # The rest of the function remains the same
     # Create expanded prompts and answers based on actual generated completions
     expanded_answers = []
     expanded_prompts = []