mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-15 09:48:54 +08:00
Use async eval (#670)
* Use async eval * bump * bump * remove workaround for bfloat cumsum
This commit is contained in:
@@ -169,7 +169,8 @@ def generate_step(
|
||||
if repetition_context_size:
|
||||
repetition_context = repetition_context[-repetition_context_size:]
|
||||
|
||||
while True:
|
||||
def _step(y):
|
||||
nonlocal cache, repetition_context
|
||||
logits, cache = model(y[None], cache=cache)
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
@@ -185,7 +186,16 @@ def generate_step(
|
||||
if repetition_context_size:
|
||||
if len(repetition_context) > repetition_context_size:
|
||||
repetition_context = repetition_context[-repetition_context_size:]
|
||||
yield y, prob
|
||||
return y, prob
|
||||
|
||||
y, prob = _step(y)
|
||||
|
||||
while True:
|
||||
sync = mx.async_eval(y)
|
||||
next_out = _step(y)
|
||||
sync.wait()
|
||||
yield y.item(), prob
|
||||
y, prob = next_out
|
||||
|
||||
|
||||
def generate(
|
||||
@@ -240,7 +250,6 @@ def generate(
|
||||
),
|
||||
range(max_tokens),
|
||||
):
|
||||
token = token.item()
|
||||
if n == 0:
|
||||
prompt_time = time.perf_counter() - tic
|
||||
tic = time.perf_counter()
|
||||
@@ -260,8 +269,8 @@ def generate(
|
||||
detokenizer.finalize()
|
||||
|
||||
if verbose:
|
||||
print(detokenizer.last_segment, flush=True)
|
||||
gen_time = time.perf_counter() - tic
|
||||
print(detokenizer.last_segment, flush=True)
|
||||
print("=" * 10)
|
||||
if token_count == 0:
|
||||
print("No tokens generated for this prompt")
|
||||
|
||||
Reference in New Issue
Block a user