mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-15 01:42:31 +08:00
Use async eval (#670)
* Use async eval * bump * bump * remove workaround for bfloat cumsum
This commit is contained in:
@@ -12,11 +12,6 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
||||
Returns:
|
||||
token selected based on the top-p criterion.
|
||||
"""
|
||||
if (
|
||||
logits.dtype == mx.bfloat16
|
||||
): # workaround for unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
|
||||
logits = logits.astype(mx.float32)
|
||||
|
||||
# referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
|
||||
probs = mx.softmax(logits / temperature, axis=-1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user