Remove unnecessary mx.where

This commit is contained in:
Neil Mehta 2025-03-10 11:27:14 -04:00
parent 956da0ddc7
commit e6031a78e4

View File

@ -216,7 +216,7 @@ def apply_top_p(logits: mx.array, top_p: float) -> mx.array:
original_order_probs = mx.take_along_axis(top_probs, inverse_indices, axis=-1) original_order_probs = mx.take_along_axis(top_probs, inverse_indices, axis=-1)
# Convert back to logits and return # Convert back to logits and return
return mx.log(mx.where(original_order_probs > 0, original_order_probs, 0)) return mx.log(original_order_probs)
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)