This commit is contained in:
Matt Clayton 2025-02-10 12:02:14 -05:00
parent 93591970cf
commit fff5daeb85

View File

@@ -252,7 +252,8 @@ def generate_step(
prompt tokens processed so far and the total number of prompt tokens. prompt tokens processed so far and the total number of prompt tokens.
Yields: Yields:
Tuple[mx.array, mx.array]: One token, a vector of log probabilities, and token metadata. Tuple[mx.array, mx.array, TokenMetadata]: One token, a vector of log
probabilities, and token metadata.
""" """
y = prompt y = prompt
@@ -368,7 +369,8 @@ def speculative_generate_step(
when ``kv_bits`` is non-None. Default: ``0``. when ``kv_bits`` is non-None. Default: ``0``.
Yields: Yields:
Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities, and token metadata. Tuple[mx.array, mx.array, TokenMetadata]: One token, a vector of log
probabilities, and token metadata.
""" """
y = prompt y = prompt
@@ -466,7 +468,7 @@ def speculative_generate_step(
y = mx.array([tokens[n]], mx.uint32) y = mx.array([tokens[n]], mx.uint32)
draft_y = y draft_y = y
# If we accpeted all the draft tokens, include the last # If we accepted all the draft tokens, include the last
# draft token in the next draft step since it hasn't been # draft token in the next draft step since it hasn't been
# processed yet by the draft model # processed yet by the draft model
if n == num_draft: if n == num_draft: