Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-29 07:30:06 +08:00)
tweaks

commit 30e98c85c1
parent fdd16caf7a
@@ -296,9 +296,9 @@ def generate_step(
     prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
 
     def _step(y):
-        if y.ndim == 1:
-            y = mx.expand_dims(y, axis=-1)
         with mx.stream(generation_stream):
+            if y.ndim == 1:
+                y = mx.expand_dims(y, axis=-1)
             logits = model(
                 y,
                 cache=prompt_cache,
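For context, here is a minimal, self-contained sketch of how a decode step shaped like `_step` above typically works: promote the incoming tokens to a (batch, 1) array, run the model on a dedicated stream, and sample the next token. The `sampler` argument and the way `generation_stream` is constructed are assumptions for illustration, not part of this commit.

import mlx.core as mx

# Assumed setup: a separate stream for decoding (the name is taken from the
# diff; the construction here is illustrative).
generation_stream = mx.new_stream(mx.default_device())

def decode_step(y, model, prompt_cache, sampler):
    # y holds the last sampled token per sequence, shape (batch,);
    # the model expects (batch, seq), so add a length-1 sequence axis.
    if y.ndim == 1:
        y = mx.expand_dims(y, axis=-1)
    with mx.stream(generation_stream):
        # The forward pass reuses the KV cache built while processing the prompt.
        logits = model(y, cache=prompt_cache)
        logits = logits[:, -1, :]      # keep only the newest position
        next_tokens = sampler(logits)  # hypothetical sampler, e.g. an argmax
    return next_tokens, logits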
@@ -514,6 +514,7 @@ def batch_generate(
         # we have <pad>text<pad>text. Should involve taking `prompt_cache_lens`
         # to extend `mask` below, and handling position_ids (see TODO below)
         raise ValueError("Batch generation does not support prompt_cache yet.")
     tokenizer = copy.deepcopy(tokenizer)
     if not isinstance(tokenizer, TokenizerWrapper):
         tokenizer = TokenizerWrapper(tokenizer)
+    # TODO: left-shift position_ids for absolute/rotary positional encodings
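The TODO above concerns left-padded batches: with absolute or rotary positional encodings, each row's positions should start at 0 on its first real token rather than at the pad columns. A hypothetical sketch (the `make_position_ids` helper is not part of the repository) of how such left-shifted position ids could be derived from an attention mask:

import mlx.core as mx

# Two prompts of different lengths, left-padded to the same width:
#   [pad, pad, t0, t1, t2]
#   [t0,  t1,  t2, t3, t4]
def make_position_ids(attention_mask):
    # attention_mask: (batch, seq) with 1 for real tokens, 0 for padding.
    # cumsum over the sequence axis yields 1..n on real tokens; subtracting 1
    # and clamping at 0 gives positions that start at 0 after the padding.
    positions = mx.cumsum(attention_mask, axis=-1) - 1
    return mx.maximum(positions, 0)

mask = mx.array([[0, 0, 1, 1, 1],
                 [1, 1, 1, 1, 1]])
print(make_position_ids(mask))
# [[0, 0, 0, 1, 2],
#  [0, 1, 2, 3, 4]]

With a prompt cache in the mix, these offsets would also have to account for the length of the cached prefix, which is the complication the `prompt_cache_lens` comment alludes to.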