Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-28 23:49:43 +08:00)
comments

commit 3c63be8c55
parent 21a05d14d2
@@ -77,7 +77,7 @@ def setup_arg_parser():
     )
     parser.add_argument(
         "--min-tokens-to-keep",
-        type=float,
+        type=int,
         default=DEFAULT_MIN_TOKENS_TO_KEEP,
         help="Minimum tokens to keep for min-p sampling.",
     )
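Not part of the commit: a minimal, self-contained sketch of why the type fix matters, using a hypothetical parser with the same argument. With type=int the CLI value is parsed as a whole token count rather than a float:

import argparse

DEFAULT_MIN_TOKENS_TO_KEEP = 1  # placeholder default, assumed for this sketch

parser = argparse.ArgumentParser()
parser.add_argument(
    "--min-tokens-to-keep",
    type=int,  # was float; the value is a token count, so an integer type fits
    default=DEFAULT_MIN_TOKENS_TO_KEEP,
    help="Minimum tokens to keep for min-p sampling.",
)

args = parser.parse_args(["--min-tokens-to-keep", "3"])
print(args.min_tokens_to_keep)  # 3 (an int), not 3.0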
@@ -183,7 +183,7 @@ def generate_step(
     prompt: mx.array,
     model: nn.Module,
     *,
-    max_tokens: int = 100,
+    max_tokens: int = 256,
     sampler: Optional[Callable[mx.array, mx.array]] = None,
     logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
     max_kv_size: Optional[int] = None,
@@ -206,7 +206,8 @@ def generate_step(
     Args:
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
-        max_tokens (int): The maximum number of tokens. Default: ``100``.
+        max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
+            generator. Default: ``256``.
         sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
             token from a vector of log probabilities. Default: ``None``.
         logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
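A hedged usage sketch (not from the commit) of the behavior the new docstring describes: with max_tokens=-1 the generator never stops on its own, so the caller supplies the stopping condition. The import paths, model repo name, and eos check below are assumptions based on typical mlx_lm usage:

import mlx.core as mx
from mlx_lm import load
from mlx_lm.utils import generate_step  # import path assumed; adjust to wherever generate_step lives

# Any MLX-converted model works here; this repo name is only an example.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
prompt = mx.array(tokenizer.encode("Write a haiku about autumn."))

# max_tokens=-1: the step counter never equals -1, so generation continues until we break.
for n, (token, logprobs) in enumerate(generate_step(prompt, model, max_tokens=-1)):
    if token == tokenizer.eos_token_id or n >= 512:  # caller-side stopping criteria
        break
    print(tokenizer.decode([token]), end="", flush=True)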
@@ -296,7 +297,7 @@ def generate_step(
     mx.async_eval(y, logprobs)
     n = 0
     while True:
-        if n < max_tokens:
+        if n != max_tokens:
             next_y, next_logprobs = _step(y)
             mx.async_eval(next_y, next_logprobs)
         if n == 0:
@@ -307,8 +308,7 @@ def generate_step(
         yield y.item(), logprobs
         if n % 256 == 0:
             mx.metal.clear_cache()
-        if n < max_tokens:
-            y, logprobs = next_y, next_logprobs
+        y, logprobs = next_y, next_logprobs
         n += 1
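For intuition, a toy, self-contained model (all names invented, not the project's code) of the control flow the two hunks above produce: the `n != max_tokens` guard plus the unconditional hand-off mean a limit of -1 is simply never reached, while a positive limit stops after exactly that many yields. The explicit break on n == max_tokens mirrors surrounding code that these hunks do not show:

def toy_generate_step(max_tokens: int = 256):
    """Toy stand-in for generate_step's loop: yields step indices."""
    y = 0                       # stands in for the current token / state
    n = 0
    while True:
        if n != max_tokens:     # with max_tokens == -1 this is always true
            next_y = y + 1      # "prefetch" the next step (stands in for _step / async_eval)
        if n == max_tokens:     # never hit when max_tokens == -1, so the generator is infinite
            break
        yield y
        y = next_y              # unconditional hand-off, as in the second hunk above
        n += 1


print(list(toy_generate_step(max_tokens=3)))      # [0, 1, 2]
infinite = toy_generate_step(max_tokens=-1)
print([next(infinite) for _ in range(5)])         # [0, 1, 2, 3, 4] and would keep going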