Use async eval (#670)

* Use async eval

* bump

* bump

* remove workaround for bfloat cumsum
This commit is contained in:
Awni Hannun
2024-04-11 13:18:23 -07:00
committed by GitHub
parent 0250f6f38e
commit 9c5554d8ee
4 changed files with 15 additions and 11 deletions

View File

@@ -169,7 +169,8 @@ def generate_step(
if repetition_context_size:
repetition_context = repetition_context[-repetition_context_size:]
while True:
def _step(y):
nonlocal cache, repetition_context
logits, cache = model(y[None], cache=cache)
logits = logits[:, -1, :]
@@ -185,7 +186,16 @@ def generate_step(
if repetition_context_size:
if len(repetition_context) > repetition_context_size:
repetition_context = repetition_context[-repetition_context_size:]
yield y, prob
return y, prob
y, prob = _step(y)
while True:
sync = mx.async_eval(y)
next_out = _step(y)
sync.wait()
yield y.item(), prob
y, prob = next_out
def generate(
@@ -240,7 +250,6 @@ def generate(
),
range(max_tokens),
):
token = token.item()
if n == 0:
prompt_time = time.perf_counter() - tic
tic = time.perf_counter()
@@ -260,8 +269,8 @@ def generate(
detokenizer.finalize()
if verbose:
print(detokenizer.last_segment, flush=True)
gen_time = time.perf_counter() - tic
print(detokenizer.last_segment, flush=True)
print("=" * 10)
if token_count == 0:
print("No tokens generated for this prompt")