This commit is contained in:
Awni Hannun 2024-12-03 06:28:21 -08:00
parent 21a05d14d2
commit 3c63be8c55
2 changed files with 6 additions and 6 deletions

View File

@ -77,7 +77,7 @@ def setup_arg_parser():
) )
parser.add_argument( parser.add_argument(
"--min-tokens-to-keep", "--min-tokens-to-keep",
type=float, type=int,
default=DEFAULT_MIN_TOKENS_TO_KEEP, default=DEFAULT_MIN_TOKENS_TO_KEEP,
help="Minimum tokens to keep for min-p sampling.", help="Minimum tokens to keep for min-p sampling.",
) )

View File

@ -183,7 +183,7 @@ def generate_step(
prompt: mx.array, prompt: mx.array,
model: nn.Module, model: nn.Module,
*, *,
max_tokens: int = 100, max_tokens: int = 256,
sampler: Optional[Callable[mx.array, mx.array]] = None, sampler: Optional[Callable[mx.array, mx.array]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
max_kv_size: Optional[int] = None, max_kv_size: Optional[int] = None,
@ -206,7 +206,8 @@ def generate_step(
Args: Args:
prompt (mx.array): The input prompt. prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation. model (nn.Module): The model to use for generation.
max_tokens (int): The maximum number of tokens. Default: ``100``. max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
generator. Default: ``256``.
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
token from a vector of log probabilities. Default: ``None``. token from a vector of log probabilities. Default: ``None``.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
@ -296,7 +297,7 @@ def generate_step(
mx.async_eval(y, logprobs) mx.async_eval(y, logprobs)
n = 0 n = 0
while True: while True:
if n < max_tokens: if n != max_tokens:
next_y, next_logprobs = _step(y) next_y, next_logprobs = _step(y)
mx.async_eval(next_y, next_logprobs) mx.async_eval(next_y, next_logprobs)
if n == 0: if n == 0:
@ -307,7 +308,6 @@ def generate_step(
yield y.item(), logprobs yield y.item(), logprobs
if n % 256 == 0: if n % 256 == 0:
mx.metal.clear_cache() mx.metal.clear_cache()
if n < max_tokens:
y, logprobs = next_y, next_logprobs y, logprobs = next_y, next_logprobs
n += 1 n += 1