mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 02:33:23 +08:00
updates
This commit is contained in:
parent
1f89453295
commit
2d2f39f96e
@ -175,6 +175,8 @@ def generate_grpo(
|
|||||||
try:
|
try:
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
model.freeze()
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
if len(prompts.shape) == 1:
|
if len(prompts.shape) == 1:
|
||||||
@ -229,8 +231,9 @@ def generate_grpo(
|
|||||||
# Check for EOS token
|
# Check for EOS token
|
||||||
if token == tokenizer.eos_token_id:
|
if token == tokenizer.eos_token_id:
|
||||||
break
|
break
|
||||||
|
|
||||||
current_tokens.append(token)
|
current_tokens.append(token)
|
||||||
|
mx.eval(current_tokens[-1])
|
||||||
|
|
||||||
# Check for end token
|
# Check for end token
|
||||||
if len(current_tokens) >= len(end_sequence) and mx.array_equal(
|
if len(current_tokens) >= len(end_sequence) and mx.array_equal(
|
||||||
@ -268,7 +271,9 @@ def generate_grpo(
|
|||||||
)
|
)
|
||||||
print(f"Average generation speed: {avg_tokens_per_second:.2f} tokens/sec")
|
print(f"Average generation speed: {avg_tokens_per_second:.2f} tokens/sec")
|
||||||
|
|
||||||
|
results = [mx.stop_gradient(r) for r in results]
|
||||||
mx.eval(results)
|
mx.eval(results)
|
||||||
|
model.unfreeze()
|
||||||
return results
|
return results
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -307,6 +312,7 @@ def grpo_loss(
|
|||||||
temperature: float = 0.8,
|
temperature: float = 0.8,
|
||||||
reward_weights: Optional[List[float]] = None,
|
reward_weights: Optional[List[float]] = None,
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
|
is_validation: bool = False
|
||||||
):
|
):
|
||||||
prompt_tokens, _, prompt_text, answer_text = batch
|
prompt_tokens, _, prompt_text, answer_text = batch
|
||||||
total_samples = len(prompt_tokens)
|
total_samples = len(prompt_tokens)
|
||||||
@ -321,6 +327,7 @@ def grpo_loss(
|
|||||||
|
|
||||||
# Set model to eval mode for generation
|
# Set model to eval mode for generation
|
||||||
model.eval()
|
model.eval()
|
||||||
|
print(f"Is model now in training mode: {model.training}")
|
||||||
|
|
||||||
# Process in smaller batches
|
# Process in smaller batches
|
||||||
for i in range(0, total_samples, batch_size):
|
for i in range(0, total_samples, batch_size):
|
||||||
@ -340,6 +347,7 @@ def grpo_loss(
|
|||||||
prompt_tensor = mx.array(padded_prompts)
|
prompt_tensor = mx.array(padded_prompts)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
mx.metal.clear_cache()
|
||||||
completions = generate_grpo(
|
completions = generate_grpo(
|
||||||
model,
|
model,
|
||||||
prompt_tensor,
|
prompt_tensor,
|
||||||
@ -363,6 +371,7 @@ def grpo_loss(
|
|||||||
mx.eval(completion_ids)
|
mx.eval(completion_ids)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Generation error: {e}")
|
print(f"Generation error: {e}")
|
||||||
|
print(f"Is model in training mode after generation: {model.training}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Restore original training state if we're not in validation mode
|
# Restore original training state if we're not in validation mode
|
||||||
@ -712,7 +721,8 @@ def evaluate_grpo(
|
|||||||
epsilon=epsilon,
|
epsilon=epsilon,
|
||||||
ref_model=ref_model,
|
ref_model=ref_model,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens
|
max_tokens=max_tokens,
|
||||||
|
is_validation=True
|
||||||
)
|
)
|
||||||
|
|
||||||
all_losses += losses * toks
|
all_losses += losses * toks
|
||||||
|
Loading…
Reference in New Issue
Block a user