diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 185a0698..0925e469 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -296,9 +296,9 @@ def generate_step( prompt_progress_callback = prompt_progress_callback or (lambda *_: None) def _step(y): + if y.ndim == 1: + y = mx.expand_dims(y, axis=-1) with mx.stream(generation_stream): - if y.ndim == 1: - y = mx.expand_dims(y, axis=-1) logits = model( y, cache=prompt_cache, @@ -514,6 +514,7 @@ def batch_generate( # we have texttext. Should involve taking `prompt_cache_lens` # to extend `mask` below, and handling position_ids (see TODO below) raise ValueError("Batch generation does not support prompt_cache yet.") + tokenizer = copy.deepcopy(tokenizer) if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) # TODO: left-shift position_ids for absolute/rotary positional encodings