feat: support batch input in generate()

The `prompt` argument of `generate()` can now be either a `str` or a
`list[str]`.

The change to `generate()` is backwards-compatible.

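The changes to `generate_step()`, `top_p_sampling()`, and
`min_p_sampling()` are backwards-incompatible: to unify shapes, they now
operate on batched arrays (e.g. `generate_step()` takes `prompts` of
shape `(bs, seq_len)` and yields tokens of shape `(bs, 1)`). Backwards
compatibility could be restored by adding a few if-statements, if
preferred.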
This commit is contained in:
L Lllvvuu 2024-08-21 00:46:19 +08:00
parent 1003a8b2dd
commit 280b3784d4
No known key found for this signature in database
GPG Key ID: CFAD5A25056DDD0F
3 changed files with 137 additions and 87 deletions

View File

@ -26,7 +26,10 @@ def min_p_sampling(
0.99-0.8 range.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered. Default: ``1``.
temperature: Temperature parameter for softmax distribution reshaping.
Returns:
token(s) selected based on the min-p criterion.
Shape: same as logits, but with the last dimension removed.
"""
if not (0 <= min_p <= 1.0):
raise ValueError(
@ -39,14 +42,14 @@ def min_p_sampling(
# reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605
# Softmax probabilities
probs = mx.softmax(logits * (1 / temperature), axis=-1)
probs = mx.softmax(logits / temperature, axis=-1)
# Indices sorted in decreasing order
sorted_indices = mx.argsort(-logits).squeeze(0)
sorted_probs = probs[..., sorted_indices]
sorted_indices = mx.argsort(-logits)
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)
# Top probability
top_probs = probs[..., sorted_indices[0]]
top_probs = mx.expand_dims(sorted_probs[..., 0], axis=-1)
# Calculate the min_p threshold
scaled_min_p = min_p * top_probs
@ -58,13 +61,18 @@ def min_p_sampling(
# Create pool of tokens with probability less than scaled min_p
selected_probs = mx.where(tokens_to_remove, 0, sorted_probs)
# Return sampled token
sorted_token = mx.random.categorical(mx.log(selected_probs))
return sorted_indices[sorted_token]
# Return sampled token(s)
sampled_indices = mx.random.categorical(mx.log(selected_probs))
tokens = mx.take_along_axis(
sorted_indices, mx.expand_dims(sampled_indices, axis=-1), axis=-1
)
return tokens.squeeze(-1)
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
def top_p_sampling(
logits: mx.array, top_p: float, temperature: float, axis: int = -1
) -> mx.array:
"""
Apply top-p (nucleus) sampling to logits.
@ -72,29 +80,35 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
logits: The logits from the model's output.
top_p: The cumulative probability threshold for top-p filtering.
temperature: Temperature parameter for softmax distribution reshaping.
axis: The axis along which to apply top-p sampling.
Returns:
token selected based on the top-p criterion.
token(s) selected based on the top-p criterion.
"""
# referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
probs = mx.softmax(logits * (1 / temperature), axis=-1)
# Apply temperature and compute softmax
probs = mx.softmax(logits / temperature, axis=axis)
# sort probs in ascending order
sorted_indices = mx.argsort(probs, axis=-1)
sorted_probs = probs[..., sorted_indices.squeeze(0)]
# Sort probs in descending order
sorted_indices = mx.argsort(-probs, axis=axis)
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=axis)
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
# Compute cumulative probabilities
cumulative_probs = mx.cumsum(sorted_probs, axis=axis)
# select tokens with cumulative probs below threshold
top_probs = mx.where(
cumulative_probs > 1 - top_p,
sorted_probs,
0,
# Create a mask for probs above the threshold
mask = cumulative_probs <= top_p
# Apply the mask to the sorted probabilities
masked_probs = sorted_probs * mask
# Sample from the normalized probabilities
sampled_indices = mx.random.categorical(mx.log(masked_probs), axis=axis)
# Gather the original token indices
tokens = mx.take_along_axis(
sorted_indices, mx.expand_dims(sampled_indices, axis=axis), axis=axis
)
sorted_token = mx.random.categorical(mx.log(top_probs))
token = sorted_indices.squeeze(0)[sorted_token]
return token
return tokens.squeeze(axis)
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)

View File

@ -410,7 +410,7 @@ class APIHandler(BaseHTTPRequestHandler):
top_tokens = []
for (token, logprobs), _ in zip(
generate_step(
prompt=prompt,
prompts=prompt[None],
model=self.model,
temp=self.temperature,
top_p=self.top_p,
@ -420,6 +420,8 @@ class APIHandler(BaseHTTPRequestHandler):
),
range(self.max_tokens),
):
token = token.item()
logprobs = logprobs.squeeze(0)
detokenizer.add_token(token)
logging.debug(detokenizer.text)
tokens.append(token)
@ -497,7 +499,7 @@ class APIHandler(BaseHTTPRequestHandler):
for (token, _), _ in zip(
generate_step(
prompt=prompt,
prompts=prompt[None],
model=self.model,
temp=self.temperature,
top_p=self.top_p,
@ -506,6 +508,7 @@ class APIHandler(BaseHTTPRequestHandler):
),
range(self.max_tokens),
):
token = token.item()
detokenizer.add_token(token)
logging.debug(detokenizer.text)
tokens.append(token)

View File

@ -117,12 +117,12 @@ def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: f
logits (mx.array): Logits with repetition penalty applied to generated tokens.
"""
if len(generated_tokens) > 0:
indices = mx.array([token for token in generated_tokens])
selected_logits = logits[:, indices]
indices = generated_tokens
selected_logits = mx.take_along_axis(logits, indices, axis=-1)
selected_logits = mx.where(
selected_logits < 0, selected_logits * penalty, selected_logits / penalty
)
logits[:, indices] = selected_logits
logits[mx.arange(indices.shape[0])[:, None], indices] = selected_logits
return logits
@ -147,7 +147,7 @@ def make_kv_caches(
def generate_step(
prompt: mx.array,
prompts: mx.array,
model: nn.Module,
temp: float = 0.0,
repetition_penalty: Optional[float] = None,
@ -164,7 +164,7 @@ def generate_step(
A generator producing token ids based on the given prompt from the model.
Args:
prompt (mx.array): The input prompt.
prompts (mx.array): The input prompt(s). Shape: ``(bs, seq_len)``.
model (nn.Module): The model to use for generation.
temp (float): The temperature for sampling, if 0 the argmax is used.
Default: ``0``.
@ -185,27 +185,33 @@ def generate_step(
Yields:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
one token and a vector of log probabilities.
one token and a vector of log probabilities per prompt.
Shapes: ``(bs, 1), (bs, vocab_size)``.
"""
def sample(logits: mx.array) -> Tuple[mx.array, float]:
if prompts.ndim != 2:
raise ValueError(
f"Shape of prompts should be (bs, seq_len), got {prompts.shape}"
)
def sample(logits: mx.array) -> Tuple[mx.array, mx.array]:
if logit_bias:
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))
logits[:, indices] += values
logprobs = logits - mx.logsumexp(logits)
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
if temp == 0:
token = mx.argmax(logits, axis=-1)
tokens = mx.argmax(logits, axis=-1)
else:
if top_p > 0 and top_p < 1.0:
token = top_p_sampling(logits, top_p, temp)
tokens = top_p_sampling(logits, top_p, temp)
elif min_p != 0.0:
token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
tokens = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
else:
token = categorical_sampling(logits, temp)
tokens = categorical_sampling(logits, temp)
return token, logprobs
return mx.expand_dims(tokens, axis=-1), logprobs
if repetition_penalty and (
repetition_penalty < 0 or not isinstance(repetition_penalty, float)
@ -214,7 +220,7 @@ def generate_step(
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
)
y = prompt
y = prompts
# Create the KV cache for generation
cache = make_kv_caches(model, max_kv_size)
@ -229,14 +235,14 @@ def generate_step(
c.update_and_fetch(h[0], h[1])
mx.eval([c.state for c in cache])
repetition_context = prompt.tolist()
repetition_context = prompts
if repetition_context_size:
repetition_context = repetition_context[-repetition_context_size:]
repetition_context = repetition_context[:, -repetition_context_size:]
def _step(y):
nonlocal repetition_context
logits = model(y[None], cache=cache)
logits = model(y, cache=cache)
logits = logits[:, -1, :]
if repetition_penalty:
@ -244,27 +250,27 @@ def generate_step(
logits, repetition_context, repetition_penalty
)
y, logprobs = sample(logits)
repetition_context.append(y.item())
repetition_context = mx.concatenate([repetition_context, y], axis=-1)
else:
y, logprobs = sample(logits)
if repetition_context_size:
if len(repetition_context) > repetition_context_size:
repetition_context = repetition_context[-repetition_context_size:]
return y, logprobs.squeeze(0)
if repetition_context.shape[1] > repetition_context_size:
repetition_context = repetition_context[:, -repetition_context_size:]
return y, logprobs
while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=cache)
while y.shape[1] > prefill_step_size:
model(y[:, :prefill_step_size], cache=cache)
mx.eval([c.state for c in cache])
y = y[prefill_step_size:]
y = y[:, prefill_step_size:]
y, logprobs = _step(y)
mx.async_eval(y)
while True:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y)
yield y.item(), logprobs
mx.eval(y)
yield y, logprobs
y, logprobs = next_y, next_logprobs
@ -296,9 +302,10 @@ def stream_generate(
detokenizer.reset()
for (token, _), n in zip(
generate_step(prompt_tokens, model, **kwargs),
generate_step(prompt_tokens[None], model, **kwargs),
range(max_tokens),
):
token = token.item()
if token == tokenizer.eos_token_id:
break
detokenizer.add_token(token)
@ -313,19 +320,19 @@ def stream_generate(
def generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str,
prompt: Union[str, List[str]],
max_tokens: int = 100,
verbose: bool = False,
formatter: Optional[Callable] = None,
**kwargs,
) -> Union[str, Generator[str, None, None]]:
) -> Union[str, List[str]]:
"""
Generate a complete response from the model.
Args:
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt.
prompt (Union[str, List[str]]): The string prompt(s).
max_tokens (int): The maximum number of tokens. Default: ``100``.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
@ -334,30 +341,44 @@ def generate(
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
"""
is_batch = isinstance(prompt, list)
if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
if is_batch:
tokenizer._tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
tokenizer._tokenizer.pad_token = tokenizer.eos_token
tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id
prompt_tokens = mx.array(
tokenizer._tokenizer(prompt, padding=True)["input_ids"]
)
output_toks = []
else:
prompt_tokens = mx.array(tokenizer.encode(prompt))[None]
detokenizer = tokenizer.detokenizer
detokenizer.reset()
if verbose:
print("=" * 10)
print("Prompt:", prompt)
prompt_tokens = mx.array(tokenizer.encode(prompt))
detokenizer = tokenizer.detokenizer
tic = time.perf_counter()
detokenizer.reset()
for (token, logprobs), n in zip(
for (tokens, logprobs), n in zip(
generate_step(prompt_tokens, model, **kwargs),
range(max_tokens),
):
if n == 0:
prompt_time = time.perf_counter() - tic
tic = time.perf_counter()
if token == tokenizer.eos_token_id:
if (tokens == tokenizer.eos_token_id).all():
break
if is_batch:
output_toks.append(tokens)
else:
token = tokens.item()
logprobs = logprobs.squeeze(0)
detokenizer.add_token(token)
if verbose:
if formatter:
# We have to finalize so that the prob corresponds to the last segment
@ -366,24 +387,36 @@ def generate(
else:
print(detokenizer.last_segment, end="", flush=True)
token_count = n + 1
if is_batch:
output_toks = mx.concatenate(output_toks, axis=1)
token_count = output_toks.size
response = [
response.split(tokenizer.eos_token)[0].split(tokenizer.pad_token)[0]
for response in tokenizer.batch_decode(output_toks.tolist())
]
else:
token_count = n
detokenizer.finalize()
response = detokenizer.text
if verbose:
gen_time = time.perf_counter() - tic
print(detokenizer.last_segment, flush=True)
print("=" * 10)
if token_count == 0:
if token_count <= 0:
print("No tokens generated for this prompt")
return
if is_batch:
for p, resp in zip(prompt, response):
print("=" * 10)
print("Prompt:", p)
print(resp)
else:
print(detokenizer.last_segment, flush=True)
prompt_tps = prompt_tokens.size / prompt_time
gen_tps = (token_count - 1) / gen_time
print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
peak_mem = mx.metal.get_peak_memory() / 2**30
print(f"Peak memory: {peak_mem:.3f} GB")
gen_tps = token_count / gen_time
print("=" * 10)
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
return detokenizer.text
return response
def load_config(model_path: Path) -> dict: