mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 10:56:38 +08:00
feat: support batch input in generate()
The `prompt` argument can now be either a `str` or `list[str]`. The change to `generate()` is backwards-compatible. The changes to `generate_step()`, `top_p_sampling()`, and `min_p_sampling()` are backwards-incompatible in order to unify shapes; this could be changed by adding a few if-statements, if preferred.
This commit is contained in:
parent
1003a8b2dd
commit
280b3784d4
@ -26,7 +26,10 @@ def min_p_sampling(
|
|||||||
0.99-0.8 range.
|
0.99-0.8 range.
|
||||||
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
||||||
be filtered. Default: ``1``.
|
be filtered. Default: ``1``.
|
||||||
|
temperature: Temperature parameter for softmax distribution reshaping.
|
||||||
|
Returns:
|
||||||
|
token(s) selected based on the min-p criterion.
|
||||||
|
Shape: same as logits, but with the last dimension having size 1.
|
||||||
"""
|
"""
|
||||||
if not (0 <= min_p <= 1.0):
|
if not (0 <= min_p <= 1.0):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -39,14 +42,14 @@ def min_p_sampling(
|
|||||||
# reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605
|
# reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605
|
||||||
|
|
||||||
# Softmax probabilities
|
# Softmax probabilities
|
||||||
probs = mx.softmax(logits * (1 / temperature), axis=-1)
|
probs = mx.softmax(logits / temperature, axis=-1)
|
||||||
|
|
||||||
# Indices sorted in decreasing order
|
# Indices sorted in decreasing order
|
||||||
sorted_indices = mx.argsort(-logits).squeeze(0)
|
sorted_indices = mx.argsort(-logits)
|
||||||
sorted_probs = probs[..., sorted_indices]
|
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)
|
||||||
|
|
||||||
# Top probability
|
# Top probability
|
||||||
top_probs = probs[..., sorted_indices[0]]
|
top_probs = mx.expand_dims(sorted_probs[..., 0], axis=-1)
|
||||||
|
|
||||||
# Calculate the min_p threshold
|
# Calculate the min_p threshold
|
||||||
scaled_min_p = min_p * top_probs
|
scaled_min_p = min_p * top_probs
|
||||||
@ -58,13 +61,18 @@ def min_p_sampling(
|
|||||||
# Create pool of tokens with probability less than scaled min_p
|
# Create pool of tokens with probability less than scaled min_p
|
||||||
selected_probs = mx.where(tokens_to_remove, 0, sorted_probs)
|
selected_probs = mx.where(tokens_to_remove, 0, sorted_probs)
|
||||||
|
|
||||||
# Return sampled token
|
# Return sampled token(s)
|
||||||
sorted_token = mx.random.categorical(mx.log(selected_probs))
|
sampled_indices = mx.random.categorical(mx.log(selected_probs))
|
||||||
return sorted_indices[sorted_token]
|
tokens = mx.take_along_axis(
|
||||||
|
sorted_indices, mx.expand_dims(sampled_indices, axis=-1), axis=-1
|
||||||
|
)
|
||||||
|
return tokens.squeeze(-1)
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||||
def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
|
def top_p_sampling(
|
||||||
|
logits: mx.array, top_p: float, temperature: float, axis: int = -1
|
||||||
|
) -> mx.array:
|
||||||
"""
|
"""
|
||||||
Apply top-p (nucleus) sampling to logits.
|
Apply top-p (nucleus) sampling to logits.
|
||||||
|
|
||||||
@ -72,29 +80,35 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
|
|||||||
logits: The logits from the model's output.
|
logits: The logits from the model's output.
|
||||||
top_p: The cumulative probability threshold for top-p filtering.
|
top_p: The cumulative probability threshold for top-p filtering.
|
||||||
temperature: Temperature parameter for softmax distribution reshaping.
|
temperature: Temperature parameter for softmax distribution reshaping.
|
||||||
|
axis: The axis along which to apply top-p sampling.
|
||||||
Returns:
|
Returns:
|
||||||
token selected based on the top-p criterion.
|
token(s) selected based on the top-p criterion.
|
||||||
"""
|
"""
|
||||||
# referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
|
# Apply temperature and compute softmax
|
||||||
probs = mx.softmax(logits * (1 / temperature), axis=-1)
|
probs = mx.softmax(logits / temperature, axis=axis)
|
||||||
|
|
||||||
# sort probs in ascending order
|
# Sort probs in descending order
|
||||||
sorted_indices = mx.argsort(probs, axis=-1)
|
sorted_indices = mx.argsort(-probs, axis=axis)
|
||||||
sorted_probs = probs[..., sorted_indices.squeeze(0)]
|
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=axis)
|
||||||
|
|
||||||
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
|
# Compute cumulative probabilities
|
||||||
|
cumulative_probs = mx.cumsum(sorted_probs, axis=axis)
|
||||||
|
|
||||||
# select tokens with cumulative probs below threshold
|
# Create a mask for probs above the threshold
|
||||||
top_probs = mx.where(
|
mask = cumulative_probs <= top_p
|
||||||
cumulative_probs > 1 - top_p,
|
|
||||||
sorted_probs,
|
# Apply the mask to the sorted probabilities
|
||||||
0,
|
masked_probs = sorted_probs * mask
|
||||||
|
|
||||||
|
# Sample from the normalized probabilities
|
||||||
|
sampled_indices = mx.random.categorical(mx.log(masked_probs), axis=axis)
|
||||||
|
|
||||||
|
# Gather the original token indices
|
||||||
|
tokens = mx.take_along_axis(
|
||||||
|
sorted_indices, mx.expand_dims(sampled_indices, axis=axis), axis=axis
|
||||||
)
|
)
|
||||||
|
|
||||||
sorted_token = mx.random.categorical(mx.log(top_probs))
|
return tokens.squeeze(axis)
|
||||||
token = sorted_indices.squeeze(0)[sorted_token]
|
|
||||||
|
|
||||||
return token
|
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||||
|
@ -410,7 +410,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
top_tokens = []
|
top_tokens = []
|
||||||
for (token, logprobs), _ in zip(
|
for (token, logprobs), _ in zip(
|
||||||
generate_step(
|
generate_step(
|
||||||
prompt=prompt,
|
prompts=prompt[None],
|
||||||
model=self.model,
|
model=self.model,
|
||||||
temp=self.temperature,
|
temp=self.temperature,
|
||||||
top_p=self.top_p,
|
top_p=self.top_p,
|
||||||
@ -420,6 +420,8 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
),
|
),
|
||||||
range(self.max_tokens),
|
range(self.max_tokens),
|
||||||
):
|
):
|
||||||
|
token = token.item()
|
||||||
|
logprobs = logprobs.squeeze(0)
|
||||||
detokenizer.add_token(token)
|
detokenizer.add_token(token)
|
||||||
logging.debug(detokenizer.text)
|
logging.debug(detokenizer.text)
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
@ -497,7 +499,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
for (token, _), _ in zip(
|
for (token, _), _ in zip(
|
||||||
generate_step(
|
generate_step(
|
||||||
prompt=prompt,
|
prompts=prompt[None],
|
||||||
model=self.model,
|
model=self.model,
|
||||||
temp=self.temperature,
|
temp=self.temperature,
|
||||||
top_p=self.top_p,
|
top_p=self.top_p,
|
||||||
@ -506,6 +508,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
),
|
),
|
||||||
range(self.max_tokens),
|
range(self.max_tokens),
|
||||||
):
|
):
|
||||||
|
token = token.item()
|
||||||
detokenizer.add_token(token)
|
detokenizer.add_token(token)
|
||||||
logging.debug(detokenizer.text)
|
logging.debug(detokenizer.text)
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
|
@ -117,12 +117,12 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f
|
|||||||
logits (mx.array): Logits with repetition penalty applied to generated tokens.
|
logits (mx.array): Logits with repetition penalty applied to generated tokens.
|
||||||
"""
|
"""
|
||||||
if len(generated_tokens) > 0:
|
if len(generated_tokens) > 0:
|
||||||
indices = mx.array([token for token in generated_tokens])
|
indices = generated_tokens
|
||||||
selected_logits = logits[:, indices]
|
selected_logits = mx.take_along_axis(logits, indices, axis=-1)
|
||||||
selected_logits = mx.where(
|
selected_logits = mx.where(
|
||||||
selected_logits < 0, selected_logits * penalty, selected_logits / penalty
|
selected_logits < 0, selected_logits * penalty, selected_logits / penalty
|
||||||
)
|
)
|
||||||
logits[:, indices] = selected_logits
|
logits[mx.arange(indices.shape[0])[:, None], indices] = selected_logits
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
@ -147,7 +147,7 @@ def make_kv_caches(
|
|||||||
|
|
||||||
|
|
||||||
def generate_step(
|
def generate_step(
|
||||||
prompt: mx.array,
|
prompts: mx.array,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
temp: float = 0.0,
|
temp: float = 0.0,
|
||||||
repetition_penalty: Optional[float] = None,
|
repetition_penalty: Optional[float] = None,
|
||||||
@ -164,7 +164,7 @@ def generate_step(
|
|||||||
A generator producing token ids based on the given prompt from the model.
|
A generator producing token ids based on the given prompt from the model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt (mx.array): The input prompt.
|
prompts (mx.array): The input prompt(s). Shape: ``(bs, seq_len)``.
|
||||||
model (nn.Module): The model to use for generation.
|
model (nn.Module): The model to use for generation.
|
||||||
temp (float): The temperature for sampling, if 0 the argmax is used.
|
temp (float): The temperature for sampling, if 0 the argmax is used.
|
||||||
Default: ``0``.
|
Default: ``0``.
|
||||||
@ -185,27 +185,33 @@ def generate_step(
|
|||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
|
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
|
||||||
one token and a vector of log probabilities.
|
one token and a vector of log probabilities per prompt.
|
||||||
|
Shapes: ``(bs, 1), (bs, vocab_size)``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def sample(logits: mx.array) -> Tuple[mx.array, float]:
|
if prompts.ndim != 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"Shape of prompts should be (bs, seq_len), got {prompts.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def sample(logits: mx.array) -> Tuple[mx.array, mx.array]:
|
||||||
if logit_bias:
|
if logit_bias:
|
||||||
indices = mx.array(list(logit_bias.keys()))
|
indices = mx.array(list(logit_bias.keys()))
|
||||||
values = mx.array(list(logit_bias.values()))
|
values = mx.array(list(logit_bias.values()))
|
||||||
logits[:, indices] += values
|
logits[:, indices] += values
|
||||||
logprobs = logits - mx.logsumexp(logits)
|
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
|
||||||
|
|
||||||
if temp == 0:
|
if temp == 0:
|
||||||
token = mx.argmax(logits, axis=-1)
|
tokens = mx.argmax(logits, axis=-1)
|
||||||
else:
|
else:
|
||||||
if top_p > 0 and top_p < 1.0:
|
if top_p > 0 and top_p < 1.0:
|
||||||
token = top_p_sampling(logits, top_p, temp)
|
tokens = top_p_sampling(logits, top_p, temp)
|
||||||
elif min_p != 0.0:
|
elif min_p != 0.0:
|
||||||
token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
|
tokens = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
|
||||||
else:
|
else:
|
||||||
token = categorical_sampling(logits, temp)
|
tokens = categorical_sampling(logits, temp)
|
||||||
|
|
||||||
return token, logprobs
|
return mx.expand_dims(tokens, axis=-1), logprobs
|
||||||
|
|
||||||
if repetition_penalty and (
|
if repetition_penalty and (
|
||||||
repetition_penalty < 0 or not isinstance(repetition_penalty, float)
|
repetition_penalty < 0 or not isinstance(repetition_penalty, float)
|
||||||
@ -214,7 +220,7 @@ def generate_step(
|
|||||||
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
|
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
|
||||||
)
|
)
|
||||||
|
|
||||||
y = prompt
|
y = prompts
|
||||||
|
|
||||||
# Create the KV cache for generation
|
# Create the KV cache for generation
|
||||||
cache = make_kv_caches(model, max_kv_size)
|
cache = make_kv_caches(model, max_kv_size)
|
||||||
@ -229,14 +235,14 @@ def generate_step(
|
|||||||
c.update_and_fetch(h[0], h[1])
|
c.update_and_fetch(h[0], h[1])
|
||||||
mx.eval([c.state for c in cache])
|
mx.eval([c.state for c in cache])
|
||||||
|
|
||||||
repetition_context = prompt.tolist()
|
repetition_context = prompts
|
||||||
|
|
||||||
if repetition_context_size:
|
if repetition_context_size:
|
||||||
repetition_context = repetition_context[-repetition_context_size:]
|
repetition_context = repetition_context[:, -repetition_context_size:]
|
||||||
|
|
||||||
def _step(y):
|
def _step(y):
|
||||||
nonlocal repetition_context
|
nonlocal repetition_context
|
||||||
logits = model(y[None], cache=cache)
|
logits = model(y, cache=cache)
|
||||||
logits = logits[:, -1, :]
|
logits = logits[:, -1, :]
|
||||||
|
|
||||||
if repetition_penalty:
|
if repetition_penalty:
|
||||||
@ -244,27 +250,27 @@ def generate_step(
|
|||||||
logits, repetition_context, repetition_penalty
|
logits, repetition_context, repetition_penalty
|
||||||
)
|
)
|
||||||
y, logprobs = sample(logits)
|
y, logprobs = sample(logits)
|
||||||
repetition_context.append(y.item())
|
repetition_context = mx.concatenate([repetition_context, y], axis=-1)
|
||||||
else:
|
else:
|
||||||
y, logprobs = sample(logits)
|
y, logprobs = sample(logits)
|
||||||
|
|
||||||
if repetition_context_size:
|
if repetition_context_size:
|
||||||
if len(repetition_context) > repetition_context_size:
|
if repetition_context.shape[1] > repetition_context_size:
|
||||||
repetition_context = repetition_context[-repetition_context_size:]
|
repetition_context = repetition_context[:, -repetition_context_size:]
|
||||||
return y, logprobs.squeeze(0)
|
return y, logprobs
|
||||||
|
|
||||||
while y.size > prefill_step_size:
|
while y.shape[1] > prefill_step_size:
|
||||||
model(y[:prefill_step_size][None], cache=cache)
|
model(y[:, :prefill_step_size], cache=cache)
|
||||||
mx.eval([c.state for c in cache])
|
mx.eval([c.state for c in cache])
|
||||||
y = y[prefill_step_size:]
|
y = y[:, prefill_step_size:]
|
||||||
|
|
||||||
y, logprobs = _step(y)
|
y, logprobs = _step(y)
|
||||||
|
|
||||||
mx.async_eval(y)
|
mx.async_eval(y)
|
||||||
while True:
|
while True:
|
||||||
next_y, next_logprobs = _step(y)
|
next_y, next_logprobs = _step(y)
|
||||||
mx.async_eval(next_y)
|
mx.async_eval(next_y)
|
||||||
yield y.item(), logprobs
|
mx.eval(y)
|
||||||
|
yield y, logprobs
|
||||||
y, logprobs = next_y, next_logprobs
|
y, logprobs = next_y, next_logprobs
|
||||||
|
|
||||||
|
|
||||||
@ -296,9 +302,10 @@ def stream_generate(
|
|||||||
|
|
||||||
detokenizer.reset()
|
detokenizer.reset()
|
||||||
for (token, _), n in zip(
|
for (token, _), n in zip(
|
||||||
generate_step(prompt_tokens, model, **kwargs),
|
generate_step(prompt_tokens[None], model, **kwargs),
|
||||||
range(max_tokens),
|
range(max_tokens),
|
||||||
):
|
):
|
||||||
|
token = token.item()
|
||||||
if token == tokenizer.eos_token_id:
|
if token == tokenizer.eos_token_id:
|
||||||
break
|
break
|
||||||
detokenizer.add_token(token)
|
detokenizer.add_token(token)
|
||||||
@ -313,19 +320,19 @@ def stream_generate(
|
|||||||
def generate(
|
def generate(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
||||||
prompt: str,
|
prompt: Union[str, List[str]],
|
||||||
max_tokens: int = 100,
|
max_tokens: int = 100,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
formatter: Optional[Callable] = None,
|
formatter: Optional[Callable] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[str, Generator[str, None, None]]:
|
) -> Union[str, List[str]]:
|
||||||
"""
|
"""
|
||||||
Generate a complete response from the model.
|
Generate a complete response from the model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (nn.Module): The language model.
|
model (nn.Module): The language model.
|
||||||
tokenizer (PreTrainedTokenizer): The tokenizer.
|
tokenizer (PreTrainedTokenizer): The tokenizer.
|
||||||
prompt (str): The string prompt.
|
prompts (str): The string prompt(s).
|
||||||
max_tokens (int): The maximum number of tokens. Default: ``100``.
|
max_tokens (int): The maximum number of tokens. Default: ``100``.
|
||||||
verbose (bool): If ``True``, print tokens and timing information.
|
verbose (bool): If ``True``, print tokens and timing information.
|
||||||
Default: ``False``.
|
Default: ``False``.
|
||||||
@ -334,56 +341,82 @@ def generate(
|
|||||||
kwargs: The remaining options get passed to :func:`generate_step`.
|
kwargs: The remaining options get passed to :func:`generate_step`.
|
||||||
See :func:`generate_step` for more details.
|
See :func:`generate_step` for more details.
|
||||||
"""
|
"""
|
||||||
|
is_batch = isinstance(prompt, list)
|
||||||
if not isinstance(tokenizer, TokenizerWrapper):
|
if not isinstance(tokenizer, TokenizerWrapper):
|
||||||
tokenizer = TokenizerWrapper(tokenizer)
|
tokenizer = TokenizerWrapper(tokenizer)
|
||||||
|
|
||||||
if verbose:
|
if is_batch:
|
||||||
print("=" * 10)
|
tokenizer._tokenizer.padding_side = "left"
|
||||||
print("Prompt:", prompt)
|
if tokenizer.pad_token is None:
|
||||||
|
tokenizer._tokenizer.pad_token = tokenizer.eos_token
|
||||||
prompt_tokens = mx.array(tokenizer.encode(prompt))
|
tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||||
detokenizer = tokenizer.detokenizer
|
prompt_tokens = mx.array(
|
||||||
|
tokenizer._tokenizer(prompt, padding=True)["input_ids"]
|
||||||
|
)
|
||||||
|
output_toks = []
|
||||||
|
else:
|
||||||
|
prompt_tokens = mx.array(tokenizer.encode(prompt))[None]
|
||||||
|
detokenizer = tokenizer.detokenizer
|
||||||
|
detokenizer.reset()
|
||||||
|
if verbose:
|
||||||
|
print("=" * 10)
|
||||||
|
print("Prompt:", prompt)
|
||||||
|
|
||||||
tic = time.perf_counter()
|
tic = time.perf_counter()
|
||||||
detokenizer.reset()
|
|
||||||
|
|
||||||
for (token, logprobs), n in zip(
|
for (tokens, logprobs), n in zip(
|
||||||
generate_step(prompt_tokens, model, **kwargs),
|
generate_step(prompt_tokens, model, **kwargs),
|
||||||
range(max_tokens),
|
range(max_tokens),
|
||||||
):
|
):
|
||||||
if n == 0:
|
if n == 0:
|
||||||
prompt_time = time.perf_counter() - tic
|
prompt_time = time.perf_counter() - tic
|
||||||
tic = time.perf_counter()
|
tic = time.perf_counter()
|
||||||
if token == tokenizer.eos_token_id:
|
if (tokens == tokenizer.eos_token_id).all():
|
||||||
break
|
break
|
||||||
detokenizer.add_token(token)
|
if is_batch:
|
||||||
|
output_toks.append(tokens)
|
||||||
|
else:
|
||||||
|
token = tokens.item()
|
||||||
|
logprobs = logprobs.squeeze(0)
|
||||||
|
detokenizer.add_token(token)
|
||||||
|
if verbose:
|
||||||
|
if formatter:
|
||||||
|
# We have to finalize so that the prob corresponds to the last segment
|
||||||
|
detokenizer.finalize()
|
||||||
|
formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
|
||||||
|
else:
|
||||||
|
print(detokenizer.last_segment, end="", flush=True)
|
||||||
|
|
||||||
if verbose:
|
if is_batch:
|
||||||
if formatter:
|
output_toks = mx.concatenate(output_toks, axis=1)
|
||||||
# We have to finalize so that the prob corresponds to the last segment
|
token_count = output_toks.size
|
||||||
detokenizer.finalize()
|
response = [
|
||||||
formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
|
response.split(tokenizer.eos_token)[0].split(tokenizer.pad_token)[0]
|
||||||
else:
|
for response in tokenizer.batch_decode(output_toks.tolist())
|
||||||
print(detokenizer.last_segment, end="", flush=True)
|
]
|
||||||
|
else:
|
||||||
token_count = n + 1
|
token_count = n
|
||||||
detokenizer.finalize()
|
detokenizer.finalize()
|
||||||
|
response = detokenizer.text
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
gen_time = time.perf_counter() - tic
|
gen_time = time.perf_counter() - tic
|
||||||
print(detokenizer.last_segment, flush=True)
|
if token_count <= 0:
|
||||||
print("=" * 10)
|
|
||||||
if token_count == 0:
|
|
||||||
print("No tokens generated for this prompt")
|
print("No tokens generated for this prompt")
|
||||||
return
|
if is_batch:
|
||||||
|
for p, resp in zip(prompt, response):
|
||||||
|
print("=" * 10)
|
||||||
|
print("Prompt:", p)
|
||||||
|
print(resp)
|
||||||
|
else:
|
||||||
|
print(detokenizer.last_segment, flush=True)
|
||||||
prompt_tps = prompt_tokens.size / prompt_time
|
prompt_tps = prompt_tokens.size / prompt_time
|
||||||
gen_tps = (token_count - 1) / gen_time
|
gen_tps = token_count / gen_time
|
||||||
print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
|
print("=" * 10)
|
||||||
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
|
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
||||||
peak_mem = mx.metal.get_peak_memory() / 2**30
|
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
||||||
print(f"Peak memory: {peak_mem:.3f} GB")
|
|
||||||
|
|
||||||
return detokenizer.text
|
return response
|
||||||
|
|
||||||
|
|
||||||
def load_config(model_path: Path) -> dict:
|
def load_config(model_path: Path) -> dict:
|
||||||
|
Loading…
Reference in New Issue
Block a user