Lllvvuu 2024-12-27 15:52:42 -08:00
parent fdd16caf7a
commit 30e98c85c1

@@ -296,9 +296,9 @@ def generate_step(
     prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
 
     def _step(y):
-        if y.ndim == 1:
-            y = mx.expand_dims(y, axis=-1)
         with mx.stream(generation_stream):
+            if y.ndim == 1:
+                y = mx.expand_dims(y, axis=-1)
             logits = model(
                 y,
                 cache=prompt_cache,
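
The hunk above moves the dimension bump onto the generation stream: a 1-D token vector is expanded to (batch, 1) inside `with mx.stream(generation_stream):`, so the reshape is scheduled on the same stream as the model's forward pass rather than on the default stream. A minimal sketch of that pattern, assuming only MLX's public stream API (the doubling stands in for the model call, and `generation_stream` mirrors the module-level stream used in this file):

import mlx.core as mx

# Dedicated stream for generation work.
generation_stream = mx.new_stream(mx.default_device())

def _step(y):
    with mx.stream(generation_stream):
        # Ops issued in this block run on generation_stream.
        if y.ndim == 1:
            y = mx.expand_dims(y, axis=-1)  # (batch,) -> (batch, 1)
        return y * 2  # stand-in for model(y, cache=prompt_cache)

print(_step(mx.array([1, 2, 3])).shape)  # (3, 1)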
@@ -514,6 +514,7 @@ def batch_generate(
         # we have <pad>text<pad>text. Should involve taking `prompt_cache_lens`
         # to extend `mask` below, and handling position_ids (see TODO below)
         raise ValueError("Batch generation does not support prompt_cache yet.")
+    tokenizer = copy.deepcopy(tokenizer)
     if not isinstance(tokenizer, TokenizerWrapper):
         tokenizer = TokenizerWrapper(tokenizer)
     # TODO: left-shift position_ids for absolute/rotary positional encodings
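
The added `copy.deepcopy(tokenizer)` makes batch_generate work on a private copy, presumably so that any in-place adjustments it needs for batching (padding configuration, detokenizer state) do not leak back into the caller's tokenizer. A hedged sketch of that idea, using illustrative Hugging Face-style attributes (`padding_side`, `pad_token`) rather than the exact fields the real function touches:

import copy

def batch_generate_sketch(tokenizer, prompts):
    # Work on a private copy so the caller's tokenizer is left untouched.
    tokenizer = copy.deepcopy(tokenizer)
    tokenizer.padding_side = "left"  # left-pad so generation starts aligned
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer(prompts, padding=True)

After the call, the caller's tokenizer keeps its original padding side and pad token.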