mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-15 01:42:31 +08:00
Faster sampling with mx.compile (#937)
* faster sampling with compile
* fix test
This commit is contained in:
@@ -20,7 +20,7 @@ from transformers import PreTrainedTokenizer
|
||||
|
||||
# Local imports
|
||||
from .models.base import KVCache
|
||||
-from .sample_utils import top_p_sampling
|
||||
+from .sample_utils import categorical_sampling, top_p_sampling
|
||||
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
||||
from .tuner.utils import apply_lora_layers
|
||||
from .tuner.utils import dequantize as dequantize_model
|
||||
@@ -169,7 +169,7 @@ def generate_step(
|
||||
if top_p > 0 and top_p < 1.0:
|
||||
token = top_p_sampling(logits, top_p, temp)
|
||||
else:
|
||||
-token = mx.random.categorical(logits * (1 / temp))
|
||||
+token = categorical_sampling(logits, temp)
|
||||
|
||||
return token, logprobs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user