mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-29 03:01:34 +08:00)
rollback OTHER changes (oopsy)
parent 75fbb7ed34
commit 4823f279d5
@@ -263,13 +263,6 @@ class KVCache(_BaseCache):
         n = min(self.offset, n)
         self.offset -= n
         return n
-    def trim_from_behind(self, n):
-        old_size = self.keys.shape[2]
-        self.keys = self.keys[..., -n:, :]
-        self.values = self.values[..., -n:, :]
-        new_size = self.keys.shape[2]
-        trimmed = old_size - new_size
-        self.offset -= trimmed
 
     def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
         quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
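For context, the removed trim_from_behind keeps only the last n cached positions and shrinks the cache offset by however many positions were dropped. Below is a minimal standalone sketch of that bookkeeping, assuming an MLX install and made-up tensor shapes; it is not the mlx_lm cache API itself.

import mlx.core as mx

# Hypothetical cache buffers shaped (batch, n_heads, seq_len, head_dim).
keys = mx.zeros((1, 8, 10, 64))
values = mx.zeros((1, 8, 10, 64))
offset = 10  # positions currently stored

n = 4  # keep only the last 4 positions
old_size = keys.shape[2]
keys = keys[..., -n:, :]
values = values[..., -n:, :]
trimmed = old_size - keys.shape[2]  # 6 positions dropped from the front
offset -= trimmed

print(keys.shape, offset)  # (1, 8, 4, 64) 4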
@@ -404,7 +404,6 @@ def generate(
     prompt: str,
     verbose: bool = False,
     formatter: Optional[Callable] = None,
-    stop_strings: Optional[List[str]] = None,
     **kwargs,
 ) -> str:
     """
@@ -433,8 +432,6 @@ def generate(
         if verbose:
             print(response.text, end="", flush=True)
         text += response.text
-        if stop_strings is not None and any(s in text for s in stop_strings):
-            break
 
     if verbose:
         print()
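The two hunks above roll the stop_strings early exit back out of generate(). The same behavior can be had outside of generate() by wrapping stream_generate. A sketch assuming mlx_lm's public load/stream_generate API; the model name, prompt, and stop strings are placeholders:

from mlx_lm import load, stream_generate

def generate_with_stops(model, tokenizer, prompt, stop_strings, **kwargs):
    # Accumulate streamed text and stop as soon as any stop string appears.
    text = ""
    for response in stream_generate(model, tokenizer, prompt, **kwargs):
        text += response.text
        if any(s in text for s in stop_strings):
            # Cut the returned text at the earliest stop string.
            cut = min(text.find(s) for s in stop_strings if s in text)
            return text[:cut]
    return text

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # placeholder model
print(generate_with_stops(model, tokenizer, "Q: What is 2 + 2?\nA:", ["\nQ:"], max_tokens=64))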
@@ -869,226 +866,3 @@ def convert(
 
     if upload_repo is not None:
         upload_to_hub(mlx_path, upload_repo, hf_path)
-from tqdm import tqdm
-
-def generate_batched_response(
-    model: nn.Module,
-    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
-    prompt: Union[str, mx.array, List[int]],
-    batch_size: int,
-    max_tokens: int = 256,
-    sampler: Optional[Callable[[mx.array], mx.array]] = None,
-    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
-    max_kv_size: Optional[int] = None,
-    prompt_cache: Optional[List[Any]] = None,
-    prefill_step_size: int = 512,
-    kv_bits: Optional[int] = None,
-    kv_group_size: int = 64,
-    quantized_kv_start: int = 0,
-    prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
-    temp: Optional[float] = None,
-    repetition_penalty: Optional[float] = None,
-    repetition_context_size: Optional[int] = None,
-    top_p: Optional[float] = None,
-    min_p: Optional[float] = None,
-    min_tokens_to_keep: Optional[int] = None,
-    verbose: bool = False,
-) -> List[str]:
-    """
-    Generate multiple responses to the same prompt in parallel and return only the generated
-    sequences (excluding the prompt), stopping at the first EOS token.
-
-    Args:
-        model (nn.Module): The language model.
-        tokenizer (PreTrainedTokenizer or TokenizerWrapper): The tokenizer.
-        prompt (Union[str, mx.array, List[int]]): The input prompt.
-        batch_size (int): Number of responses to generate in parallel.
-        max_tokens (int): Maximum number of generated tokens per sequence.
-        sampler (Callable): Sampler function.
-        logits_processors (List[Callable]): List of logits processors.
-        max_kv_size (int): Maximum KV cache size.
-        prompt_cache (List[Any]): Precomputed prompt cache.
-        prefill_step_size (int): Step size for prompt processing.
-        kv_bits (int): Bits for KV cache quantization.
-        kv_group_size (int): Group size for KV quantization.
-        quantized_kv_start (int): Step to begin quantizing KV.
-        prompt_progress_callback (Callable): Callback for prompt progress.
-        temp (float): Temperature for sampling (deprecated, pass to sampler).
-        repetition_penalty (float): Repetition penalty (deprecated, use logits_processors).
-        repetition_context_size (int): Context size for repetition.
-        top_p (float): Top-p sampling (deprecated, pass to sampler).
-        min_p (float): Minimum p sampling (deprecated, pass to sampler).
-        min_tokens_to_keep (int): Minimum number of tokens to keep.
-        verbose (bool): If True, show a progress bar for token generation.
-
-    Returns:
-        List[str]: A list of decoded response strings for each batch element, excluding the prompt
-        and stopping at the first EOS token.
-    """
-    if not isinstance(tokenizer, TokenizerWrapper):
-        tokenizer = TokenizerWrapper(tokenizer)
-
-    # Convert prompt to tokens if necessary
-    if not isinstance(prompt, mx.array):
-        prompt = mx.array(
-            prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
-        )
-
-    # Expand prompt to batch
-    prompt_length = prompt.size
-    prompt = mx.expand_dims(prompt, 0) # (1, prompt_length)
-    prompt = mx.repeat(prompt, batch_size, axis=0) # (B, prompt_length)
-    B = batch_size
-
-    if prompt_progress_callback is None:
-        prompt_progress_callback = lambda *_: None
-
-    if temp is not None or top_p is not None or min_tokens_to_keep is not None:
-        print(
-            "[Warning] Specifying sampling arguments directly is deprecated. "
-            "Pass in a `sampler` if needed."
-        )
-    if repetition_penalty is not None:
-        print(
-            "[Warning] Specifying `repetition_penalty` is deprecated. "
-            "Use `logits_processors` instead."
-        )
-
-    sampler = sampler or make_sampler(
-        temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1
-    )
-    logits_processors = logits_processors or make_logits_processors(
-        None, repetition_penalty, repetition_context_size or 20
-    )
-
-    # Create or verify prompt cache
-    if prompt_cache is None:
-        prompt_cache = cache.make_prompt_cache(model, max_kv_size)
-    elif len(prompt_cache) != len(model.layers):
-        raise ValueError("Wrong number of layers in the prompt cache.")
-
-    # Process the prompt to fill the cache in increments
-    total_prompt_tokens = prompt_length
-    prompt_processed_tokens = 0
-    remaining_prompt = prompt
-    tic = time.perf_counter()
-    with mx.stream(generation_stream):
-        while remaining_prompt.shape[1] > prefill_step_size:
-            model(remaining_prompt[:, :prefill_step_size], cache=prompt_cache)
-            mx.eval([c.state for c in prompt_cache])
-            prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
-            prompt_processed_tokens += prefill_step_size
-            remaining_prompt = remaining_prompt[:, prefill_step_size:]
-            mx.metal.clear_cache()
-
-        # Process any remaining prompt tokens
-        if remaining_prompt.shape[1] > 0:
-            model(remaining_prompt, cache=prompt_cache)
-            mx.eval([c.state for c in prompt_cache])
-            prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
-
-    prompt_time = time.perf_counter() - tic
-    prompt_tps = (total_prompt_tokens * B) / prompt_time
-
-    # Initialization for generation
-    tokens = prompt
-    finished = mx.zeros((B,), dtype=tokens.dtype)
-    generation_count = 0
-    eos_ids = tokenizer.eos_token_ids
-
-    # Setup progress bar if verbose
-    pbar = None
-    if verbose:
-        if max_tokens >= 0:
-            pbar = tqdm(total=max_tokens, desc="Generating tokens", ncols=80)
-        else:
-            # If we don't have a max_tokens limit, no total is known.
-            # We'll just display a progress bar that counts up.
-            pbar = tqdm(desc="Generating tokens", ncols=80)
-
-    tic = time.perf_counter()
-
-    while True:
-        if (max_tokens >= 0) and (generation_count >= max_tokens):
-            break
-
-        # If all sequences finished, break
-        sum_finished = mx.sum(finished)
-        mx.eval(sum_finished)
-        if sum_finished.item() == B:
-            break
-
-        # Prepare last token
-        next_input = tokens[:, -1:] # (B,1)
-        with mx.stream(generation_stream):
-            logits = model(next_input, cache=prompt_cache)
-            # logits: (B, 1, vocab)
-            logits = logits[:, -1, :] # (B, vocab)
-
-            # Apply logits processors
-            if logits_processors:
-                for processor in logits_processors:
-                    logits = processor(tokens, logits)
-
-            maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits)
-
-            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) # (B,vocab)
-            sampled_tokens = sampler(logprobs) # (B,)
-
-        mx.async_eval(sampled_tokens, logprobs)
-
-        # Check EOS
-        is_eos = mx.zeros_like(sampled_tokens).astype(tokens.dtype)
-        for eid in eos_ids:
-            diff = sampled_tokens - eid
-            sq = diff * diff
-            val = 1.0 / (sq + 1.0)
-            mask = val.astype(tokens.dtype)
-            is_eos = is_eos + mask
-
-        ones = mx.ones_like(is_eos)
-        is_eos = mx.minimum(is_eos, ones)
-        finished = mx.maximum(finished, is_eos)
-
-        sampled_tokens = sampled_tokens[:, None] # (B,1)
-        tokens = mx.concatenate([tokens, sampled_tokens], axis=1)
-
-        generation_count += 1
-        if pbar is not None:
-            pbar.update(1)
-
-        if (generation_count % 256) == 0:
-            mx.metal.clear_cache()
-
-    if pbar is not None:
-        pbar.close()
-
-    generation_time = time.perf_counter() - tic
-    generation_tps = (generation_count * B) / generation_time if generation_count > 0 else 0.0
-    peak_memory = mx.metal.get_peak_memory() / 1e9
-
-    results = []
-    for i in range(B):
-        seq = tokens[i][prompt_length:].tolist() # Exclude the prompt
-        # Find the first EOS token
-        eos_pos = None
-        for idx, t in enumerate(seq):
-            if t in eos_ids:
-                eos_pos = idx
-                break
-        # Slice up to EOS if found
-        if eos_pos is not None:
-            seq = seq[:eos_pos]
-        text = tokenizer.decode(seq)
-        results.append(text)
-
-    if verbose:
-        print("=" * 10)
-        print(f"Prompt: {total_prompt_tokens} tokens * {B} sequences, {prompt_tps:.3f} tps")
-        print(
-            f"Generation: {generation_count} tokens * {B} sequences, "
-            f"{generation_tps:.3f} tps"
-        )
-        print(f"Peak memory: {peak_memory:.3f} GB")
-
-    return results
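A note on the EOS check in the generation loop above: 1.0 / ((t - eid)^2 + 1.0) equals 1 exactly when the sampled token t matches an EOS id and is strictly less than 1 otherwise, so summing over the EOS ids and clamping to 1 yields a per-sequence EOS mask using only arithmetic ops. An equivalent, more direct formulation (a sketch, not part of the commit) uses a boolean comparison:

import mlx.core as mx

def eos_mask(sampled_tokens: mx.array, eos_ids) -> mx.array:
    # 1 where the sampled token equals any EOS id, else 0.
    mask = mx.zeros(sampled_tokens.shape, dtype=sampled_tokens.dtype)
    for eid in eos_ids:
        mask = mx.maximum(mask, (sampled_tokens == eid).astype(sampled_tokens.dtype))
    return mask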
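For anyone pinning the parent commit, where generate_batched_response still exists in mlx_lm.utils, a hypothetical usage sketch follows; the model name, prompt, and sampler settings are placeholders, and the import fails on this commit and later since the helper is removed here.

from mlx_lm import load
from mlx_lm.sample_utils import make_sampler
from mlx_lm.utils import generate_batched_response  # removed again by this commit

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # placeholder model

responses = generate_batched_response(
    model,
    tokenizer,
    prompt="Write a haiku about the ocean.",
    batch_size=4,          # four samples of the same prompt, generated in parallel
    max_tokens=128,
    sampler=make_sampler(temp=0.8, top_p=0.95),
    verbose=True,
)
for i, text in enumerate(responses):
    print(f"--- sample {i} ---\n{text}")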