Mirror of https://github.com/ml-explore/mlx-examples.git
feat: support batch input in generate()
The `prompt` argument can now be either a `str` or `list[str]`. The change to `generate()` is backwards-compatible. The changes to `generate_step()`, `top_p_sampling()`, and `min_p_sampling()` are backwards-incompatible in order to unify shapes; this could be changed by adding a few if-statements, if preferred.
Parent: 1003a8b2dd · Commit: 280b3784d4
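For reference, a short sketch of how the batched API is meant to be called after this change. The model name is only illustrative; the single-prompt call is unchanged and the list form is the new behavior described above.

```python
from mlx_lm import load, generate

# Illustrative model; any mlx-community model should behave the same way.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Single prompt: behaves exactly as before and returns a str.
text = generate(model, tokenizer, prompt="What is MLX?", max_tokens=64)

# Batch of prompts: new in this commit; returns a list[str], one response per
# prompt, in the same order as the inputs.
texts = generate(
    model,
    tokenizer,
    prompt=["What is MLX?", "Write a haiku about the GPU."],
    max_tokens=64,
)
```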
@@ -26,7 +26,10 @@ def min_p_sampling(
             0.99-0.8 range.
         min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
             be filtered. Default: ``1``.
         temperature: Temperature parameter for softmax distribution reshaping.
+    Returns:
+        token(s) selected based on the min-p criterion.
+        Shape: same as logits, but with the last dimension having size 1.
     """
     if not (0 <= min_p <= 1.0):
         raise ValueError(
@@ -39,14 +42,14 @@ def min_p_sampling(
     # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605

     # Softmax probabilities
-    probs = mx.softmax(logits * (1 / temperature), axis=-1)
+    probs = mx.softmax(logits / temperature, axis=-1)

     # Indices sorted in decreasing order
-    sorted_indices = mx.argsort(-logits).squeeze(0)
-    sorted_probs = probs[..., sorted_indices]
+    sorted_indices = mx.argsort(-logits)
+    sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)

     # Top probability
-    top_probs = probs[..., sorted_indices[0]]
+    top_probs = mx.expand_dims(sorted_probs[..., 0], axis=-1)

     # Calculate the min_p threshold
     scaled_min_p = min_p * top_probs
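The batching idiom in this hunk is `mx.take_along_axis`: the old fancy indexing only made sense for a single row, while `take_along_axis` reorders every row of a `(bs, vocab_size)` array by that row's own sorted indices. A small standalone sketch with invented values:

```python
import mlx.core as mx

# Toy "probabilities" for two sequences over a 4-token vocabulary.
probs = mx.array([[0.1, 0.6, 0.2, 0.1],
                  [0.4, 0.1, 0.1, 0.4]])

# Per-row indices that sort each row in descending order.
sorted_indices = mx.argsort(-probs, axis=-1)

# Old-style fancy indexing (probs[..., sorted_indices]) would broadcast the
# 2-D index array and produce a (2, 2, 4) result; take_along_axis keeps the
# pairing row-by-row and returns (2, 4).
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)
print(sorted_probs)  # each row now starts with its own largest probability
```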
@@ -58,13 +61,18 @@ def min_p_sampling(
     # Create pool of tokens with probability less than scaled min_p
     selected_probs = mx.where(tokens_to_remove, 0, sorted_probs)

-    # Return sampled token
-    sorted_token = mx.random.categorical(mx.log(selected_probs))
-    return sorted_indices[sorted_token]
+    # Return sampled token(s)
+    sampled_indices = mx.random.categorical(mx.log(selected_probs))
+    tokens = mx.take_along_axis(
+        sorted_indices, mx.expand_dims(sampled_indices, axis=-1), axis=-1
+    )
+    return tokens.squeeze(-1)


 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
-def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
+def top_p_sampling(
+    logits: mx.array, top_p: float, temperature: float, axis: int = -1
+) -> mx.array:
     """
     Apply top-p (nucleus) sampling to logits.
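With these changes both samplers take logits of shape `(bs, vocab_size)` and return one token id per row. A rough usage sketch, assuming the new signatures in this diff and that the helpers live in `mlx_lm.sample_utils` (the logits are made up):

```python
import mlx.core as mx
from mlx_lm.sample_utils import min_p_sampling, top_p_sampling

# Fake logits for a batch of three sequences over a tiny vocabulary.
logits = mx.random.normal((3, 8))

# Each call returns one sampled token id per row, i.e. shape (3,).
min_p_tokens = min_p_sampling(logits, 0.1, 1, 0.8)   # min_p, min_tokens_to_keep, temperature
top_p_tokens = top_p_sampling(logits, 0.9, 0.8)      # top_p, temperature

print(min_p_tokens.shape, top_p_tokens.shape)  # (3,) (3,)
```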
@@ -72,29 +80,35 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
         logits: The logits from the model's output.
         top_p: The cumulative probability threshold for top-p filtering.
         temperature: Temperature parameter for softmax distribution reshaping.
+        axis: The axis along which to apply top-p sampling.

     Returns:
-        token selected based on the top-p criterion.
+        token(s) selected based on the top-p criterion.
     """
-    # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
-    probs = mx.softmax(logits * (1 / temperature), axis=-1)
+    # Apply temperature and compute softmax
+    probs = mx.softmax(logits / temperature, axis=axis)

-    # sort probs in ascending order
-    sorted_indices = mx.argsort(probs, axis=-1)
-    sorted_probs = probs[..., sorted_indices.squeeze(0)]
+    # Sort probs in descending order
+    sorted_indices = mx.argsort(-probs, axis=axis)
+    sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=axis)

-    cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
+    # Compute cumulative probabilities
+    cumulative_probs = mx.cumsum(sorted_probs, axis=axis)

-    # select tokens with cumulative probs below threshold
-    top_probs = mx.where(
-        cumulative_probs > 1 - top_p,
-        sorted_probs,
-        0,
-    )
-
-    sorted_token = mx.random.categorical(mx.log(top_probs))
-    token = sorted_indices.squeeze(0)[sorted_token]
-
-    return token
+    # Create a mask for probs above the threshold
+    mask = cumulative_probs <= top_p
+
+    # Apply the mask to the sorted probabilities
+    masked_probs = sorted_probs * mask
+
+    # Sample from the normalized probabilities
+    sampled_indices = mx.random.categorical(mx.log(masked_probs), axis=axis)
+
+    # Gather the original token indices
+    tokens = mx.take_along_axis(
+        sorted_indices, mx.expand_dims(sampled_indices, axis=axis), axis=axis
+    )
+
+    return tokens.squeeze(axis)


 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
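The rewritten body replaces the `mx.where` construction with an explicit nucleus mask that works row-by-row on a batch. A small numeric sketch of that masking step (numbers invented):

```python
import mlx.core as mx

# Already-sorted (descending) probabilities for two rows.
sorted_probs = mx.array([[0.5, 0.3, 0.15, 0.05],
                         [0.4, 0.3, 0.2, 0.1]])

cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
# row 0: [0.5, 0.8, 0.95, 1.0], row 1: [0.4, 0.7, 0.9, 1.0]

# Keep the smallest prefix whose cumulative probability stays within top_p.
top_p = 0.9
mask = cumulative_probs <= top_p
masked_probs = sorted_probs * mask
# Row 0 keeps its first two entries, row 1 its first three; the rest become 0
# and are eliminated by mx.log(0) = -inf before mx.random.categorical.
print(masked_probs)
```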
@@ -410,7 +410,7 @@ class APIHandler(BaseHTTPRequestHandler):
         top_tokens = []
         for (token, logprobs), _ in zip(
             generate_step(
-                prompt=prompt,
+                prompts=prompt[None],
                 model=self.model,
                 temp=self.temperature,
                 top_p=self.top_p,
@@ -420,6 +420,8 @@ class APIHandler(BaseHTTPRequestHandler):
             ),
             range(self.max_tokens),
         ):
+            token = token.item()
+            logprobs = logprobs.squeeze(0)
             detokenizer.add_token(token)
             logging.debug(detokenizer.text)
             tokens.append(token)
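Because `generate_step()` now operates on a batch, the single-prompt callers in the server wrap the prompt in a leading batch dimension and strip it again from the outputs. A minimal sketch of that pattern (token ids are arbitrary):

```python
import mlx.core as mx

# A single prompt as a 1-D array of token ids.
prompt = mx.array([1, 15043, 29892])

# The new generate_step() expects a (bs, seq_len) batch, so single-prompt
# callers pass prompt[None] to add a batch dimension of 1.
prompts = prompt[None]
print(prompt.shape, prompts.shape)  # (3,) (1, 3)

# Per step, generate_step() now yields tokens of shape (1, 1) and logprobs of
# shape (1, vocab_size); token.item() and logprobs.squeeze(0) restore the old
# scalar-token / 1-D-logprobs view used by the server handlers above.
```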
@@ -497,7 +499,7 @@ class APIHandler(BaseHTTPRequestHandler):

         for (token, _), _ in zip(
             generate_step(
-                prompt=prompt,
+                prompts=prompt[None],
                 model=self.model,
                 temp=self.temperature,
                 top_p=self.top_p,
@@ -506,6 +508,7 @@ class APIHandler(BaseHTTPRequestHandler):
             ),
             range(self.max_tokens),
         ):
+            token = token.item()
             detokenizer.add_token(token)
             logging.debug(detokenizer.text)
             tokens.append(token)
@@ -117,12 +117,12 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float):
         logits (mx.array): Logits with repetition penalty applied to generated tokens.
     """
     if len(generated_tokens) > 0:
-        indices = mx.array([token for token in generated_tokens])
-        selected_logits = logits[:, indices]
+        indices = generated_tokens
+        selected_logits = mx.take_along_axis(logits, indices, axis=-1)
         selected_logits = mx.where(
             selected_logits < 0, selected_logits * penalty, selected_logits / penalty
         )
-        logits[:, indices] = selected_logits
+        logits[mx.arange(indices.shape[0])[:, None], indices] = selected_logits
     return logits
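The batched penalty relies on two idioms: `mx.take_along_axis` gathers each row's own previously generated tokens, and advanced indexing with a `(bs, 1)` row index scatters the penalized values back. A toy sketch with invented values:

```python
import mlx.core as mx

logits = mx.zeros((2, 5))
logits[0, 3] = 2.0
logits[1, 1] = -1.0

# One previously generated token per row (shape (bs, n_generated)).
generated = mx.array([[3], [1]])

# Gather each row's logits at its own generated positions...
selected = mx.take_along_axis(logits, generated, axis=-1)
penalty = 1.5
selected = mx.where(selected < 0, selected * penalty, selected / penalty)

# ...and scatter them back row-by-row: mx.arange(bs)[:, None] pairs row i
# with the column indices stored in generated[i].
logits[mx.arange(generated.shape[0])[:, None], generated] = selected
print(logits)
```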
@@ -147,7 +147,7 @@ def make_kv_caches(


 def generate_step(
-    prompt: mx.array,
+    prompts: mx.array,
     model: nn.Module,
     temp: float = 0.0,
     repetition_penalty: Optional[float] = None,
@@ -164,7 +164,7 @@ def generate_step(
     A generator producing token ids based on the given prompt from the model.

     Args:
-        prompt (mx.array): The input prompt.
+        prompts (mx.array): The input prompt(s). Shape: ``(bs, seq_len)``.
         model (nn.Module): The model to use for generation.
         temp (float): The temperature for sampling, if 0 the argmax is used.
             Default: ``0``.
@@ -185,27 +185,33 @@ def generate_step(

     Yields:
         Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
-        one token and a vector of log probabilities.
+        one token and a vector of log probabilities per prompt.
+        Shapes: ``(bs, 1), (bs, vocab_size)``.
     """

-    def sample(logits: mx.array) -> Tuple[mx.array, float]:
+    if prompts.ndim != 2:
+        raise ValueError(
+            f"Shape of prompts should be (bs, seq_len), got {prompts.shape}"
+        )
+
+    def sample(logits: mx.array) -> Tuple[mx.array, mx.array]:
         if logit_bias:
             indices = mx.array(list(logit_bias.keys()))
             values = mx.array(list(logit_bias.values()))
             logits[:, indices] += values
-        logprobs = logits - mx.logsumexp(logits)
+        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)

         if temp == 0:
-            token = mx.argmax(logits, axis=-1)
+            tokens = mx.argmax(logits, axis=-1)
         else:
             if top_p > 0 and top_p < 1.0:
-                token = top_p_sampling(logits, top_p, temp)
+                tokens = top_p_sampling(logits, top_p, temp)
             elif min_p != 0.0:
-                token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
+                tokens = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
             else:
-                token = categorical_sampling(logits, temp)
+                tokens = categorical_sampling(logits, temp)

-        return token, logprobs
+        return mx.expand_dims(tokens, axis=-1), logprobs

     if repetition_penalty and (
         repetition_penalty < 0 or not isinstance(repetition_penalty, float)
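Two details in `sample()` matter for batching: `mx.logsumexp(..., axis=-1, keepdims=True)` normalizes each row of the logits independently, and the sampled token ids get an explicit trailing dimension so every yielded step has shape `(bs, 1)`. A standalone sketch:

```python
import mlx.core as mx

logits = mx.random.normal((2, 6))  # (bs, vocab_size)

# Per-row log-softmax: without axis/keepdims, logsumexp would reduce over the
# whole batch and mix the rows together.
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)

# Greedy choice per row, then add the trailing dim so the step output is (bs, 1).
tokens = mx.argmax(logits, axis=-1)
step_out = mx.expand_dims(tokens, axis=-1)
print(logprobs.shape, step_out.shape)  # (2, 6) (2, 1)
```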
@@ -214,7 +220,7 @@ def generate_step(
             f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
         )

-    y = prompt
+    y = prompts

     # Create the KV cache for generation
     cache = make_kv_caches(model, max_kv_size)
@@ -229,14 +235,14 @@ def generate_step(
             c.update_and_fetch(h[0], h[1])
         mx.eval([c.state for c in cache])

-    repetition_context = prompt.tolist()
+    repetition_context = prompts

     if repetition_context_size:
-        repetition_context = repetition_context[-repetition_context_size:]
+        repetition_context = repetition_context[:, -repetition_context_size:]

     def _step(y):
         nonlocal repetition_context
-        logits = model(y[None], cache=cache)
+        logits = model(y, cache=cache)
         logits = logits[:, -1, :]

         if repetition_penalty:
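With `prompts` of shape `(bs, seq_len)`, the model is called on the batch directly (no more `y[None]`), and the repetition context stays a `(bs, context)` array that is trimmed with a slice on the time axis instead of a Python list slice. A tiny shape sketch with placeholder token ids:

```python
import mlx.core as mx

bs, seq_len, context_size = 2, 7, 4
prompts = mx.zeros((bs, seq_len), dtype=mx.int32)  # stand-in for real token ids

# One context row per sequence; trimming keeps the last `context_size` tokens
# of every row at once.
repetition_context = prompts[:, -context_size:]
print(repetition_context.shape)  # (2, 4)
```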
@@ -244,27 +250,27 @@ def generate_step(
                 logits, repetition_context, repetition_penalty
             )
             y, logprobs = sample(logits)
-            repetition_context.append(y.item())
+            repetition_context = mx.concatenate([repetition_context, y], axis=-1)
         else:
             y, logprobs = sample(logits)

         if repetition_context_size:
-            if len(repetition_context) > repetition_context_size:
-                repetition_context = repetition_context[-repetition_context_size:]
-        return y, logprobs.squeeze(0)
+            if repetition_context.shape[1] > repetition_context_size:
+                repetition_context = repetition_context[:, -repetition_context_size:]
+        return y, logprobs

-    while y.size > prefill_step_size:
-        model(y[:prefill_step_size][None], cache=cache)
+    while y.shape[1] > prefill_step_size:
+        model(y[:, :prefill_step_size], cache=cache)
         mx.eval([c.state for c in cache])
-        y = y[prefill_step_size:]
+        y = y[:, prefill_step_size:]

     y, logprobs = _step(y)

     mx.async_eval(y)
     while True:
         next_y, next_logprobs = _step(y)
         mx.async_eval(next_y)
-        yield y.item(), logprobs
+        mx.eval(y)
+        yield y, logprobs
         y, logprobs = next_y, next_logprobs
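After these changes the generator yields `(tokens, logprobs)` with shapes `(bs, 1)` and `(bs, vocab_size)` at every step. A rough sketch of driving it directly for a batch; the model name is illustrative, and the padding mirrors what the updated `generate()` below does:

```python
import mlx.core as mx
from mlx_lm import load
from mlx_lm.utils import generate_step

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Pad a small batch to a rectangular (bs, seq_len) array (left padding,
# eos reused as pad when none is defined), as generate() does further down.
tok = tokenizer._tokenizer
tok.padding_side = "left"
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
prompts = mx.array(tok(["Hello there.", "Howdy partner."], padding=True)["input_ids"])

collected = []
for (tokens, logprobs), _ in zip(generate_step(prompts, model), range(20)):
    # tokens has shape (2, 1); logprobs has shape (2, vocab_size).
    if (tokens == tokenizer.eos_token_id).all():
        break
    collected.append(tokens)

out = mx.concatenate(collected, axis=1)      # (2, n_generated)
print(tokenizer.batch_decode(out.tolist()))  # may still contain text after eos
```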
@@ -296,9 +302,10 @@ def stream_generate(

     detokenizer.reset()
     for (token, _), n in zip(
-        generate_step(prompt_tokens, model, **kwargs),
+        generate_step(prompt_tokens[None], model, **kwargs),
         range(max_tokens),
     ):
+        token = token.item()
         if token == tokenizer.eos_token_id:
             break
         detokenizer.add_token(token)
@@ -313,19 +320,19 @@
 def generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
-    prompt: str,
+    prompt: Union[str, List[str]],
     max_tokens: int = 100,
     verbose: bool = False,
     formatter: Optional[Callable] = None,
     **kwargs,
-) -> Union[str, Generator[str, None, None]]:
+) -> Union[str, List[str]]:
     """
     Generate a complete response from the model.

     Args:
         model (nn.Module): The language model.
         tokenizer (PreTrainedTokenizer): The tokenizer.
-        prompt (str): The string prompt.
+        prompts (str): The string prompt(s).
         max_tokens (int): The maximum number of tokens. Default: ``100``.
         verbose (bool): If ``True``, print tokens and timing information.
             Default: ``False``.
@@ -334,30 +341,44 @@ def generate(
         kwargs: The remaining options get passed to :func:`generate_step`.
             See :func:`generate_step` for more details.
     """
+    is_batch = isinstance(prompt, list)
     if not isinstance(tokenizer, TokenizerWrapper):
         tokenizer = TokenizerWrapper(tokenizer)

+    if is_batch:
+        tokenizer._tokenizer.padding_side = "left"
+        if tokenizer.pad_token is None:
+            tokenizer._tokenizer.pad_token = tokenizer.eos_token
+            tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id
+        prompt_tokens = mx.array(
+            tokenizer._tokenizer(prompt, padding=True)["input_ids"]
+        )
+        output_toks = []
+    else:
+        prompt_tokens = mx.array(tokenizer.encode(prompt))[None]
+        detokenizer = tokenizer.detokenizer
+        detokenizer.reset()
+
     if verbose:
         print("=" * 10)
         print("Prompt:", prompt)

-    prompt_tokens = mx.array(tokenizer.encode(prompt))
-    detokenizer = tokenizer.detokenizer
-
     tic = time.perf_counter()
-    detokenizer.reset()

-    for (token, logprobs), n in zip(
+    for (tokens, logprobs), n in zip(
         generate_step(prompt_tokens, model, **kwargs),
         range(max_tokens),
     ):
         if n == 0:
             prompt_time = time.perf_counter() - tic
             tic = time.perf_counter()
-        if token == tokenizer.eos_token_id:
+        if (tokens == tokenizer.eos_token_id).all():
             break
-        detokenizer.add_token(token)
+        if is_batch:
+            output_toks.append(tokens)
+        else:
+            token = tokens.item()
+            logprobs = logprobs.squeeze(0)
+            detokenizer.add_token(token)

         if verbose:
             if formatter:
                 # We have to finalize so that the prob corresponds to the last segment
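The batch path above pads on the left so that every row's final prompt token sits in the last column; `generate_step()` samples from `logits[:, -1, :]`, so right padding would make the model predict from pad tokens instead. A small sketch of what the padded batch looks like (token ids invented):

```python
import mlx.core as mx

eos = 2  # pretend this id is both eos and pad

# Two prompts of different lengths, left-padded to the same width.  The last
# column holds each prompt's real final token, which is what
# logits[:, -1, :] conditions on inside generate_step().
batch = mx.array([
    [eos, eos, 17, 42, 99],   # short prompt, padded on the left
    [ 11,  12, 13, 14, 15],   # long prompt, no padding needed
])
print(batch.shape)  # (2, 5)
```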
@@ -366,24 +387,36 @@ def generate(
             else:
                 print(detokenizer.last_segment, end="", flush=True)

-    token_count = n + 1
-    detokenizer.finalize()
+    if is_batch:
+        output_toks = mx.concatenate(output_toks, axis=1)
+        token_count = output_toks.size
+        response = [
+            response.split(tokenizer.eos_token)[0].split(tokenizer.pad_token)[0]
+            for response in tokenizer.batch_decode(output_toks.tolist())
+        ]
+    else:
+        token_count = n
+        detokenizer.finalize()
+        response = detokenizer.text

     if verbose:
         gen_time = time.perf_counter() - tic
-        print(detokenizer.last_segment, flush=True)
-        print("=" * 10)
-        if token_count == 0:
+        if token_count <= 0:
             print("No tokens generated for this prompt")
             return
+        if is_batch:
+            for p, resp in zip(prompt, response):
+                print("=" * 10)
+                print("Prompt:", p)
+                print(resp)
+        else:
+            print(detokenizer.last_segment, flush=True)
         prompt_tps = prompt_tokens.size / prompt_time
-        gen_tps = (token_count - 1) / gen_time
-        print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
-        print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
-        peak_mem = mx.metal.get_peak_memory() / 2**30
-        print(f"Peak memory: {peak_mem:.3f} GB")
+        gen_tps = token_count / gen_time
+        print("=" * 10)
+        print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
+        print(f"Generation: {gen_tps:.3f} tokens-per-sec")

-    return detokenizer.text
+    return response


 def load_config(model_path: Path) -> dict:
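Because shorter sequences keep generating after their eos when batched (the loop only stops once every row has finished), the decoded strings are trimmed at the first eos/pad token. A toy sketch of that clean-up step; the decoded strings are invented stand-ins for what `batch_decode` might return:

```python
# Hypothetical decoded strings with padding tokens trailing the real answer.
eos_token, pad_token = "</s>", "</s>"
decoded = [
    "Paris is the capital of France.</s></s></s>",
    "MLX is an array framework for Apple silicon.</s>",
]

responses = [
    text.split(eos_token)[0].split(pad_token)[0]
    for text in decoded
]
print(responses)
# ['Paris is the capital of France.', 'MLX is an array framework for Apple silicon.']
```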