mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-28 17:08:07 +08:00
fix(mlx-lm): sorted probs in top_p implementation. (#610)
* fix(mlx-lm): the top p imp * chore: address comment
This commit is contained in:
@@ -22,7 +22,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
||||
|
||||
# sort probs in ascending order
|
||||
sorted_indices = mx.argsort(probs, axis=-1)
|
||||
sorted_probs = probs[..., sorted_indices]
|
||||
sorted_probs = probs[..., sorted_indices.squeeze(0)]
|
||||
|
||||
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user