This commit is contained in:
Awni Hannun 2024-12-03 06:28:21 -08:00
parent 21a05d14d2
commit 3c63be8c55
2 changed files with 6 additions and 6 deletions

View File

@ -77,7 +77,7 @@ def setup_arg_parser():
)
parser.add_argument(
"--min-tokens-to-keep",
type=float,
type=int,
default=DEFAULT_MIN_TOKENS_TO_KEEP,
help="Minimum tokens to keep for min-p sampling.",
)

View File

@ -183,7 +183,7 @@ def generate_step(
prompt: mx.array,
model: nn.Module,
*,
max_tokens: int = 100,
max_tokens: int = 256,
sampler: Optional[Callable[mx.array, mx.array]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
max_kv_size: Optional[int] = None,
@ -206,7 +206,8 @@ def generate_step(
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
max_tokens (int): The maximum number of tokens. Default: ``100``.
max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite
generator. Default: ``256``.
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
token from a vector of log probabilities. Default: ``None``.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
@ -296,7 +297,7 @@ def generate_step(
mx.async_eval(y, logprobs)
n = 0
while True:
if n < max_tokens:
if n != max_tokens:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y, next_logprobs)
if n == 0:
@ -307,8 +308,7 @@ def generate_step(
yield y.item(), logprobs
if n % 256 == 0:
mx.metal.clear_cache()
if n < max_tokens:
y, logprobs = next_y, next_logprobs
y, logprobs = next_y, next_logprobs
n += 1