mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
comments
This commit is contained in:
parent
21a05d14d2
commit
3c63be8c55
@ -77,7 +77,7 @@ def setup_arg_parser():
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--min-tokens-to-keep",
|
"--min-tokens-to-keep",
|
||||||
type=float,
|
type=int,
|
||||||
default=DEFAULT_MIN_TOKENS_TO_KEEP,
|
default=DEFAULT_MIN_TOKENS_TO_KEEP,
|
||||||
help="Minimum tokens to keep for min-p sampling.",
|
help="Minimum tokens to keep for min-p sampling.",
|
||||||
)
|
)
|
||||||
|
@ -183,7 +183,7 @@ def generate_step(
|
|||||||
prompt: mx.array,
|
prompt: mx.array,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
*,
|
*,
|
||||||
max_tokens: int = 100,
|
max_tokens: int = 256,
|
||||||
sampler: Optional[Callable[mx.array, mx.array]] = None,
|
sampler: Optional[Callable[mx.array, mx.array]] = None,
|
||||||
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||||
max_kv_size: Optional[int] = None,
|
max_kv_size: Optional[int] = None,
|
||||||
@ -206,7 +206,8 @@ def generate_step(
|
|||||||
Args:
|
Args:
|
||||||
prompt (mx.array): The input prompt.
|
prompt (mx.array): The input prompt.
|
||||||
model (nn.Module): The model to use for generation.
|
model (nn.Module): The model to use for generation.
|
||||||
max_tokens (int): The maximum number of tokens. Default: ``100``.
|
max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite
|
||||||
|
generator. Default: ``256``.
|
||||||
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
|
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
|
||||||
token from a vector of log probabilities. Default: ``None``.
|
token from a vector of log probabilities. Default: ``None``.
|
||||||
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||||
@ -296,7 +297,7 @@ def generate_step(
|
|||||||
mx.async_eval(y, logprobs)
|
mx.async_eval(y, logprobs)
|
||||||
n = 0
|
n = 0
|
||||||
while True:
|
while True:
|
||||||
if n < max_tokens:
|
if n != max_tokens:
|
||||||
next_y, next_logprobs = _step(y)
|
next_y, next_logprobs = _step(y)
|
||||||
mx.async_eval(next_y, next_logprobs)
|
mx.async_eval(next_y, next_logprobs)
|
||||||
if n == 0:
|
if n == 0:
|
||||||
@ -307,7 +308,6 @@ def generate_step(
|
|||||||
yield y.item(), logprobs
|
yield y.item(), logprobs
|
||||||
if n % 256 == 0:
|
if n % 256 == 0:
|
||||||
mx.metal.clear_cache()
|
mx.metal.clear_cache()
|
||||||
if n < max_tokens:
|
|
||||||
y, logprobs = next_y, next_logprobs
|
y, logprobs = next_y, next_logprobs
|
||||||
n += 1
|
n += 1
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user