diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 005c877a..b894b5c4 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -252,7 +252,8 @@ def generate_step( prompt tokens processed so far and the total number of prompt tokens. Yields: - Tuple[mx.array, mx.array]: One token, a vector of log probabilities, and token metadata. + Tuple[mx.array, mx.array, TokenMetadata]: One token, a vector of log + probabilities, and token metadata. """ y = prompt @@ -368,7 +369,8 @@ def speculative_generate_step( when ``kv_bits`` is non-None. Default: ``0``. Yields: - Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities, and token metadata. + Tuple[mx.array, mx.array, TokenMetadata]: One token, a vector of log + probabilities, and token metadata. """ y = prompt @@ -466,7 +468,7 @@ def speculative_generate_step( y = mx.array([tokens[n]], mx.uint32) draft_y = y - # If we accpeted all the draft tokens, include the last + # If we accepted all the draft tokens, include the last # draft token in the next draft step since it hasn't been # processed yet by the draft model if n == num_draft: