This commit is contained in:
Goekdeniz-Guelmez 2025-03-05 14:49:56 +01:00
parent 9a36452519
commit d723ddfeda

View File

@ -343,6 +343,7 @@ def grpo_loss(
# Convert to tensor # Convert to tensor
prompt_tensor = mx.array(padded_prompts) prompt_tensor = mx.array(padded_prompts)
prompt_tensor = mx.stop_gradient(prompt_tensor) # Explicitly stop gradient on input
try: try:
mx.metal.clear_cache() mx.metal.clear_cache()