Use async eval (#670)

* Use async eval

* bump

* bump

* remove workaround for bfloat cumsum
This commit is contained in:
Awni Hannun
2024-04-11 13:18:23 -07:00
committed by GitHub
parent 0250f6f38e
commit 9c5554d8ee
4 changed files with 15 additions and 11 deletions

View File

@@ -12,11 +12,6 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
Returns:
token selected based on the top-p criterion.
"""
if (
logits.dtype == mx.bfloat16
): # workaround for unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
logits = logits.astype(mx.float32)
# reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
probs = mx.softmax(logits / temperature, axis=-1)