Mirror of https://github.com/ml-explore/mlx-examples.git
Synced 2025-06-24 17:31:18 +08:00
Use async eval (#670)

* Use async eval
* bump
* bump
* remove workaround for bfloat cumsum
parent 0250f6f38e
commit 9c5554d8ee
@@ -1,4 +1,4 @@
-mlx>=0.8
+mlx>=0.10
 numpy
 transformers>=4.39.3
 protobuf

@@ -12,11 +12,6 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
     Returns:
         token selected based on the top-p criterion.
     """
-    if (
-        logits.dtype == mx.bfloat16
-    ):  # workaround for unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
-        logits = logits.astype(mx.float32)
-
     # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
     probs = mx.softmax(logits / temperature, axis=-1)

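For context, top-p (nucleus) sampling needs an inclusive cumulative sum over the sorted probabilities, which is why the bfloat16 cumsum kernel mattered here; with mlx>=0.10 the cast-to-float32 workaround can go. A minimal sketch of the technique, following the transformers reference linked above (names and exact indexing are assumptions for illustration, not the mlx-lm implementation):

    import mlx.core as mx

    def top_p_sample(logits: mx.array, top_p: float, temperature: float) -> mx.array:
        # Temperature-scaled probabilities.
        probs = mx.softmax(logits / temperature, axis=-1)
        # Sort ascending and take the inclusive cumulative sum
        # (the cumsum that previously needed a float32 cast for bfloat16 inputs).
        sorted_indices = mx.argsort(probs, axis=-1)
        sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)
        cumulative = mx.cumsum(sorted_probs, axis=-1)
        # Zero out tokens outside the top-p nucleus, then sample from the rest.
        top_probs = mx.where(cumulative > 1 - top_p, sorted_probs, 0)
        choice = mx.random.categorical(mx.log(top_probs))
        return mx.take_along_axis(sorted_indices, choice[..., None], axis=-1)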
@@ -169,7 +169,8 @@ def generate_step(
     if repetition_context_size:
         repetition_context = repetition_context[-repetition_context_size:]

-    while True:
+    def _step(y):
+        nonlocal cache, repetition_context
         logits, cache = model(y[None], cache=cache)
         logits = logits[:, -1, :]

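Because the new _step closure reassigns cache and repetition_context, which live in the enclosing generate_step scope, it must declare them nonlocal; otherwise Python would treat the assignments as fresh locals. A tiny standalone illustration (hypothetical names):

    def make_counter():
        count = 0

        def step():
            nonlocal count  # without this, `count += 1` raises UnboundLocalError
            count += 1
            return count

        return step

    step = make_counter()
    assert step() == 1
    assert step() == 2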
@@ -185,7 +186,16 @@ def generate_step(
         if repetition_context_size:
             if len(repetition_context) > repetition_context_size:
                 repetition_context = repetition_context[-repetition_context_size:]
-        yield y, prob
+        return y, prob
+
+    y, prob = _step(y)
+
+    while True:
+        sync = mx.async_eval(y)
+        next_out = _step(y)
+        sync.wait()
+        yield y.item(), prob
+        y, prob = next_out


 def generate(
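The restructuring above pipelines decoding: the runtime evaluates the current token asynchronously while the host builds the compute graph for the next step, hiding graph-construction overhead behind the device work. A minimal sketch of the same pattern (the `step` callable is an assumption for illustration; in recent MLX releases mx.async_eval returns None, so this sketch uses `.item()` rather than an explicit wait as the synchronization point):

    import mlx.core as mx

    def pipelined_decode(step, y, max_tokens):
        # `step` stands in for one forward pass returning the next-token array.
        y = step(y)
        for _ in range(max_tokens):
            mx.async_eval(y)    # start evaluating the current token
            next_y = step(y)    # meanwhile, build the graph for the next step
            yield y.item()      # .item() blocks until the evaluation finishes
            y = next_y

    # Toy usage: a "model" that just increments the token id.
    toks = list(pipelined_decode(lambda y: y + 1, mx.array(0), max_tokens=5))
    assert toks == [1, 2, 3, 4, 5]

Note that _step now returns rather than yields, and the generator itself yields y.item(), a Python int; that is why the next hunk drops the caller-side `token = token.item()` conversion in generate().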
@@ -240,7 +250,6 @@ def generate(
         ),
         range(max_tokens),
     ):
-        token = token.item()
         if n == 0:
             prompt_time = time.perf_counter() - tic
             tic = time.perf_counter()
@@ -260,8 +269,8 @@ def generate(
     detokenizer.finalize()

     if verbose:
-        print(detokenizer.last_segment, flush=True)
         gen_time = time.perf_counter() - tic
+        print(detokenizer.last_segment, flush=True)
         print("=" * 10)
         if token_count == 0:
             print("No tokens generated for this prompt")
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.

-__version__ = "0.8.0"
+__version__ = "0.9.0"