mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
updates
This commit is contained in:
@@ -343,6 +343,7 @@ def grpo_loss(
|
|||||||
|
|
||||||
# Convert to tensor
|
# Convert to tensor
|
||||||
prompt_tensor = mx.array(padded_prompts)
|
prompt_tensor = mx.array(padded_prompts)
|
||||||
|
prompt_tensor = mx.stop_gradient(prompt_tensor) # Explicitly stop gradient on input
|
||||||
|
|
||||||
try:
|
try:
|
||||||
mx.metal.clear_cache()
|
mx.metal.clear_cache()
|
||||||
|
|||||||
Reference in New Issue
Block a user