From 30e98c85c12865d56bcf28171f7105b21c385f42 Mon Sep 17 00:00:00 2001 From: L Lllvvuu Date: Fri, 27 Dec 2024 15:52:42 -0800 Subject: [PATCH] generate_step: hoist ndim expansion out of generation stream; batch_generate: deep-copy tokenizer --- llms/mlx_lm/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 185a0698..0925e469 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -296,9 +296,9 @@ def generate_step( prompt_progress_callback = prompt_progress_callback or (lambda *_: None) def _step(y): + if y.ndim == 1: + y = mx.expand_dims(y, axis=-1) with mx.stream(generation_stream): - if y.ndim == 1: - y = mx.expand_dims(y, axis=-1) logits = model( y, cache=prompt_cache, @@ -514,6 +514,7 @@ def batch_generate( # we have texttext. Should involve taking `prompt_cache_lens` # to extend `mask` below, and handling position_ids (see TODO below) raise ValueError("Batch generation does not support prompt_cache yet.") + tokenizer = copy.deepcopy(tokenizer) if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) # TODO: left-shift position_ids for absolute/rotary positional encodings