diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt
index 02788d97..85dfaa53 100644
--- a/llms/mlx_lm/requirements.txt
+++ b/llms/mlx_lm/requirements.txt
@@ -1,4 +1,4 @@
-mlx>=0.8
+mlx>=0.10
 numpy
 transformers>=4.39.3
 protobuf
diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py
index 1953aeea..f22ce2d7 100644
--- a/llms/mlx_lm/sample_utils.py
+++ b/llms/mlx_lm/sample_utils.py
@@ -12,11 +12,6 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
     Returns:
       token selected based on the top-p criterion.
     """
-    if (
-        logits.dtype == mx.bfloat16
-    ):  # workaround for unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
-        logits = logits.astype(mx.float32)
-
     # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
     probs = mx.softmax(logits / temperature, axis=-1)
 
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 4a7b8142..d5a03270 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -169,7 +169,8 @@ def generate_step(
     if repetition_context_size:
         repetition_context = repetition_context[-repetition_context_size:]
 
-    while True:
+    def _step(y):
+        nonlocal cache, repetition_context
         logits, cache = model(y[None], cache=cache)
         logits = logits[:, -1, :]
 
@@ -185,7 +186,16 @@
         if repetition_context_size:
             if len(repetition_context) > repetition_context_size:
                 repetition_context = repetition_context[-repetition_context_size:]
-        yield y, prob
+        return y, prob
+
+    y, prob = _step(y)
+
+    mx.async_eval(y)
+    while True:
+        next_y, next_prob = _step(y)
+        mx.async_eval(next_y)
+        yield y.item(), prob
+        y, prob = next_y, next_prob
 
 
 def generate(
@@ -240,7 +250,6 @@ def generate(
         ),
         range(max_tokens),
     ):
-        token = token.item()
         if n == 0:
             prompt_time = time.perf_counter() - tic
             tic = time.perf_counter()
@@ -260,8 +269,8 @@
     detokenizer.finalize()
 
     if verbose:
-        print(detokenizer.last_segment, flush=True)
         gen_time = time.perf_counter() - tic
+        print(detokenizer.last_segment, flush=True)
         print("=" * 10)
         if token_count == 0:
             print("No tokens generated for this prompt")
diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py
index e339bd95..1e0aa30d 100644
--- a/llms/mlx_lm/version.py
+++ b/llms/mlx_lm/version.py
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.8.0"
+__version__ = "0.9.0"
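
Note: the core change above pipelines decoding with mx.async_eval, so that evaluation of the current token overlaps with building the graph for the next one, and .item() only blocks on work already dispatched. Below is a minimal, self-contained sketch of that double-buffered decode loop for context. It uses only public MLX calls (mx.async_eval, mx.array.item); toy_step and decode are hypothetical stand-ins for the model step and generate_step, not code from this patch.

    import mlx.core as mx

    def toy_step(y):
        # Hypothetical stand-in for one transformer decode step.
        return (y + 1) % 10

    def decode(y, max_tokens):
        y = toy_step(y)
        mx.async_eval(y)  # dispatch evaluation without blocking Python
        for _ in range(max_tokens):
            next_y = toy_step(y)   # build the next graph while y computes
            mx.async_eval(next_y)  # overlap its evaluation with the yield
            yield y.item()         # blocks only until y is ready
            y = next_y

    print(list(decode(mx.array(0), 5)))  # [1, 2, 3, 4, 5]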