Use async eval (#670)

* Use async eval

* bump

* bump

* remove workaround for bfloat cumsum
This commit is contained in:
Awni Hannun 2024-04-11 13:18:23 -07:00 committed by GitHub
parent 0250f6f38e
commit 9c5554d8ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 15 additions and 11 deletions

View File

@ -1,4 +1,4 @@
mlx>=0.8 mlx>=0.10
numpy numpy
transformers>=4.39.3 transformers>=4.39.3
protobuf protobuf

View File

@ -12,11 +12,6 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
Returns: Returns:
token selected based on the top-p criterion. token selected based on the top-p criterion.
""" """
if (
logits.dtype == mx.bfloat16
): # workaround for unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
logits = logits.astype(mx.float32)
# referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460 # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
probs = mx.softmax(logits / temperature, axis=-1) probs = mx.softmax(logits / temperature, axis=-1)

View File

@ -169,7 +169,8 @@ def generate_step(
if repetition_context_size: if repetition_context_size:
repetition_context = repetition_context[-repetition_context_size:] repetition_context = repetition_context[-repetition_context_size:]
while True: def _step(y):
nonlocal cache, repetition_context
logits, cache = model(y[None], cache=cache) logits, cache = model(y[None], cache=cache)
logits = logits[:, -1, :] logits = logits[:, -1, :]
@ -185,7 +186,16 @@ def generate_step(
if repetition_context_size: if repetition_context_size:
if len(repetition_context) > repetition_context_size: if len(repetition_context) > repetition_context_size:
repetition_context = repetition_context[-repetition_context_size:] repetition_context = repetition_context[-repetition_context_size:]
yield y, prob return y, prob
y, prob = _step(y)
while True:
sync = mx.async_eval(y)
next_out = _step(y)
sync.wait()
yield y.item(), prob
y, prob = next_out
def generate( def generate(
@ -240,7 +250,6 @@ def generate(
), ),
range(max_tokens), range(max_tokens),
): ):
token = token.item()
if n == 0: if n == 0:
prompt_time = time.perf_counter() - tic prompt_time = time.perf_counter() - tic
tic = time.perf_counter() tic = time.perf_counter()
@ -260,8 +269,8 @@ def generate(
detokenizer.finalize() detokenizer.finalize()
if verbose: if verbose:
print(detokenizer.last_segment, flush=True)
gen_time = time.perf_counter() - tic gen_time = time.perf_counter() - tic
print(detokenizer.last_segment, flush=True)
print("=" * 10) print("=" * 10)
if token_count == 0: if token_count == 0:
print("No tokens generated for this prompt") print("No tokens generated for this prompt")

View File

@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
__version__ = "0.8.0" __version__ = "0.9.0"