mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Remove unnecessary mx.where
This commit is contained in:
parent
956da0ddc7
commit
e6031a78e4
@ -216,7 +216,7 @@ def apply_top_p(logits: mx.array, top_p: float) -> mx.array:
|
|||||||
original_order_probs = mx.take_along_axis(top_probs, inverse_indices, axis=-1)
|
original_order_probs = mx.take_along_axis(top_probs, inverse_indices, axis=-1)
|
||||||
|
|
||||||
# Convert back to logits and return
|
# Convert back to logits and return
|
||||||
return mx.log(mx.where(original_order_probs > 0, original_order_probs, 0))
|
return mx.log(original_order_probs)
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||||
|
Loading…
Reference in New Issue
Block a user