mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-11 19:56:42 +08:00
updates
This commit is contained in:
parent
9a36452519
commit
d723ddfeda
@ -343,6 +343,7 @@ def grpo_loss(
|
|||||||
|
|
||||||
# Convert to tensor
|
# Convert to tensor
|
||||||
prompt_tensor = mx.array(padded_prompts)
|
prompt_tensor = mx.array(padded_prompts)
|
||||||
|
prompt_tensor = mx.stop_gradient(prompt_tensor) # Explicitly stop gradient on input
|
||||||
|
|
||||||
try:
|
try:
|
||||||
mx.metal.clear_cache()
|
mx.metal.clear_cache()
|
||||||
|
Loading…
Reference in New Issue
Block a user