From e6031a78e450b0c0b4ed86321baf1b7a7cab0e0c Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Mon, 10 Mar 2025 11:27:14 -0400 Subject: [PATCH] Remove unnecessary mx.where --- llms/mlx_lm/sample_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py index d62c7f75..efc5b556 100644 --- a/llms/mlx_lm/sample_utils.py +++ b/llms/mlx_lm/sample_utils.py @@ -216,7 +216,7 @@ def apply_top_p(logits: mx.array, top_p: float) -> mx.array: original_order_probs = mx.take_along_axis(top_probs, inverse_indices, axis=-1) # Convert back to logits and return - return mx.log(mx.where(original_order_probs > 0, original_order_probs, 0)) + return mx.log(original_order_probs) @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)