Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-12-15 01:42:31 +08:00)
[MLX LM] Sampler refactor + a few improvements (#1094)
* starting
* refactor sampler/processor and a few improvements
* fix stream
* fix stream generate
* fix eos handling in stream generate
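For orientation, here is a minimal sketch of the two helpers this refactor routes everything through, using only the signatures visible in the diff below (`make_sampler(temp, top_p, min_p, min_tokens_to_keep)` and `make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)`). It is an illustration, not the library internals; the toy logits and token history are made up:

```python
# Illustrative sketch, not library internals: build a sampler and logits
# processors the way the refactored generate_step does, then apply them to
# fake single-step logits.
import mlx.core as mx
from mlx_lm.sample_utils import make_logits_processors, make_sampler

logits = mx.random.normal((1, 8))   # pretend vocabulary of 8 tokens
history = mx.array([1, 3, 3, 5])    # pretend previously generated tokens

# Each processor maps (tokens, logits) -> logits and is applied in order.
processors = make_logits_processors(
    logit_bias={0: -10.0},       # additive bias that discourages token id 0
    repetition_penalty=1.2,      # penalize tokens already in the context window
    repetition_context_size=20,
)
for processor in processors:
    logits = processor(history, logits)

# The sampler maps log-probabilities to a token id, as the new _step does below.
sampler = make_sampler(temp=0.7, top_p=0.9, min_p=0.0, min_tokens_to_keep=1)
logprobs = logits - mx.logsumexp(logits, keepdims=True)
print(sampler(logprobs).item())
```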
@@ -20,7 +20,7 @@ from transformers import PreTrainedTokenizer
 
 # Local imports
 from .models import cache
-from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
+from .sample_utils import make_logits_processors, make_sampler
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
 from .tuner.utils import load_adapters
@@ -34,6 +34,9 @@ MODEL_REMAPPING = {
 
 MAX_FILE_SIZE_GB = 5
 
+# A stream on the default device just for generation
+generation_stream = mx.new_stream(mx.default_device())
+
 
 class ModelNotFoundError(Exception):
     def __init__(self, message):
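The new module-level `generation_stream` gives generation its own MLX stream on the default device. As a rough, self-contained illustration of the pattern (not part of this diff), work can be queued on such a stream with the `mx.stream` context manager, which is how the rewritten `_step` uses it further down:

```python
# Standalone sketch of using a dedicated stream; array sizes are arbitrary.
import mlx.core as mx

generation_stream = mx.new_stream(mx.default_device())

a = mx.random.normal((512, 512))
with mx.stream(generation_stream):
    b = a @ a  # this matmul is scheduled on generation_stream
mx.eval(b)     # force evaluation so the queued work actually runs
```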
@@ -137,29 +140,6 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
     return model_path
 
 
-def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float):
-    """
-    Apply repetition penalty to specific logits based on the given context.
-
-    Paper: https://arxiv.org/abs/1909.05858
-
-    Args:
-        logits (mx.array): The logits produced by the language model.
-        tokens (mx.array): A list of N previous tokens.
-        penalty (float): The repetition penalty factor to be applied.
-
-    Returns:
-        logits (mx.array): Logits with repetition penalty applied to generated tokens.
-    """
-    if len(tokens) > 0:
-        selected_logits = logits[:, tokens]
-        selected_logits = mx.where(
-            selected_logits < 0, selected_logits * penalty, selected_logits / penalty
-        )
-        logits[:, tokens] = selected_logits
-    return logits
-
-
 def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
     if (
         kv_bits is not None
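The deleted `apply_repetition_penalty` implemented the penalty from the paper linked in its docstring; after this change the equivalent processor presumably comes from `make_logits_processors`. A toy re-run of the same scaling rule, kept here only to show what the penalty does to previously seen tokens:

```python
# Toy numbers, same rule as the removed helper: positive logits of seen tokens
# are divided by the penalty, negative ones multiplied by it.
import mlx.core as mx

penalty = 1.5
logits = mx.array([[2.0, -1.0, 0.5, 3.0]])
seen = mx.array([0, 1])  # token ids already generated

selected = logits[:, seen]
selected = mx.where(selected < 0, selected * penalty, selected / penalty)
logits[:, seen] = selected
print(logits)  # roughly [[1.33, -1.5, 0.5, 3.0]]
```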
@@ -185,7 +165,7 @@ def generate_step(
     max_kv_size: Optional[int] = None,
     prompt_cache: Optional[Any] = None,
     logit_bias: Optional[Dict[int, float]] = None,
-    logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
+    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
@@ -214,7 +194,7 @@ def generate_step(
         prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
             provided, the cache will be updated in place.
         logit_bias (dictionary, optional): Additive logit bias.
-        logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
+        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
             A list of functions that take tokens and logits and return the processed
             logits. Default: ``None``.
         kv_bits (int, optional): Number of bits to use for KV cache quantization.
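Per the renamed docstring entry, each element of `logits_processors` is a callable that takes the tokens generated so far and the current logits and returns processed logits. A hypothetical user-defined processor could look like this (the helper name and the blocked-token idea are made up for illustration):

```python
# Hypothetical processor obeying the documented (tokens, logits) -> logits contract.
import mlx.core as mx

def make_block_token_processor(blocked_id: int):
    def processor(tokens: mx.array, logits: mx.array) -> mx.array:
        # Mask out a single vocabulary entry so it can never be sampled.
        mask = mx.arange(logits.shape[-1]) == blocked_id
        return mx.where(mask, float("-inf"), logits)

    return processor

# e.g. generate_step(prompt, model, logits_processors=[make_block_token_processor(13)])
```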
@@ -224,53 +204,9 @@ def generate_step(
             when ``kv_bits`` is non-None. Default: ``0``.
 
     Yields:
-        Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
-        one token and a vector of log probabilities.
+        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
     """
 
-    def sample(logits: mx.array) -> Tuple[mx.array, float]:
-        logprobs = logits - mx.logsumexp(logits)
-
-        if temp == 0:
-            token = mx.argmax(logits, axis=-1)
-        else:
-            if top_p > 0 and top_p < 1.0:
-                token = top_p_sampling(logits, top_p, temp)
-            elif min_p != 0.0:
-                token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
-            else:
-                token = categorical_sampling(logits, temp)
-
-        return token, logprobs
-
-    if repetition_penalty and (
-        repetition_penalty < 0 or not isinstance(repetition_penalty, float)
-    ):
-        raise ValueError(
-            f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
-        )
-
-    logits_processor = logits_processor or []
-
-    if repetition_penalty:
-
-        def repetition_penalty_processor(tokens, logits):
-            return apply_repetition_penalty(
-                logits, tokens[-repetition_context_size:], repetition_penalty
-            )
-
-        logits_processor.append(repetition_penalty_processor)
-
-    if logit_bias:
-        indices = mx.array(list(logit_bias.keys()))
-        values = mx.array(list(logit_bias.values()))
-
-        def logit_bias_processor(_, logits):
-            logits[:, indices] += values
-            return logits
-
-        logits_processor.append(logit_bias_processor)
-
     y = prompt
     tokens = None
 
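The inline `sample` closure removed here dispatched between greedy argmax, top-p, min-p, and plain categorical sampling depending on `temp`, `top_p`, and `min_p`; `make_sampler` presumably packages the same decision behind one callable. A stripped-down, self-contained sketch of just the temperature branch (top-p and min-p omitted):

```python
# Self-contained sketch of the old temperature logic, for reference only.
import mlx.core as mx

def naive_sample(logits: mx.array, temp: float = 0.0) -> mx.array:
    logprobs = logits - mx.logsumexp(logits, keepdims=True)
    if temp == 0:
        return mx.argmax(logprobs, axis=-1)               # greedy decoding
    return mx.random.categorical(logprobs * (1 / temp))   # temperature sampling

print(naive_sample(mx.array([[0.1, 2.0, -1.0]]), temp=0.8).item())
```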
@@ -283,24 +219,31 @@ def generate_step(
     elif len(prompt_cache) != len(model.layers):
         raise ValueError("Wrong number of layers in the prompt cache.")
 
+    sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
+    logits_processors = logits_processors or []
+    logits_processors.extend(
+        make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
+    )
+
     def _step(y):
-        logits = model(y[None], cache=prompt_cache)
-        logits = logits[:, -1, :]
-
-        if logits_processor:
-            nonlocal tokens
-            tokens = mx.concat([tokens, y]) if tokens is not None else y
-
-            for processor in logits_processor:
-                logits = processor(tokens, logits)
-
-        maybe_quantize_kv_cache(
-            prompt_cache, quantized_kv_start, kv_group_size, kv_bits
-        )
-
-        y, logprobs = sample(logits)
-        return y, logprobs.squeeze(0)
+        with mx.stream(generation_stream):
+            logits = model(y[None], cache=prompt_cache)
+            logits = logits[:, -1, :]
+
+            if logits_processors:
+                nonlocal tokens
+                tokens = mx.concat([tokens, y]) if tokens is not None else y
+
+                for processor in logits_processors:
+                    logits = processor(tokens, logits)
+
+            maybe_quantize_kv_cache(
+                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
+            )
+
+            logprobs = logits - mx.logsumexp(logits, keepdims=True)
+            y = sampler(logprobs)
+        return y, logprobs.squeeze(0)
 
     while y.size > prefill_step_size:
         model(y[:prefill_step_size][None], cache=prompt_cache)
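In the rewritten `_step`, the forward pass, processor application, and sampling all run inside the `generation_stream` context, and logits are converted to log-probabilities with `logits - mx.logsumexp(logits, keepdims=True)` before the sampler sees them. A quick sanity check that this is just a log-softmax:

```python
# The normalized values exponentiate to a proper probability distribution.
import mlx.core as mx

logits = mx.array([[1.0, 2.0, 3.0]])
logprobs = logits - mx.logsumexp(logits, keepdims=True)
print(mx.exp(logprobs).sum().item())  # ~1.0
```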
@@ -325,43 +268,51 @@ def generate_step(
 def stream_generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
-    prompt: str,
+    prompt: Union[str, List[int]],
     max_tokens: int = 100,
     **kwargs,
-) -> Union[str, Generator[str, None, None]]:
+) -> Generator[Tuple[str, int, mx.array], None, None]:
     """
     A generator producing text based on the given prompt from the model.
 
     Args:
-        prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
-        max_tokens (int): The ma
+        tokenizer (PreTrainedTokenizer): The tokenizer.
+        prompt (Union[str, List[int]]): The input prompt string or integer tokens.
+        max_tokens (int): The maximum number of tokens. Default: ``100``.
         kwargs: The remaining options get passed to :func:`generate_step`.
           See :func:`generate_step` for more details.
 
     Yields:
-        Generator[Tuple[mx.array, mx.array]]: A generator producing text.
+        Tuple[str, int, mx.array]:
+            The next text segment, token, and vector of log probabilities.
     """
     if not isinstance(tokenizer, TokenizerWrapper):
         tokenizer = TokenizerWrapper(tokenizer)
 
-    prompt_tokens = mx.array(tokenizer.encode(prompt))
+    prompt_tokens = mx.array(
+        prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
+    )
     detokenizer = tokenizer.detokenizer
 
-    detokenizer.reset()
-    for n, (token, _) in zip(
-        range(max_tokens),
-        generate_step(prompt_tokens, model, **kwargs),
-    ):
-        if token == tokenizer.eos_token_id:
-            break
-        detokenizer.add_token(token)
-
-        # Yield the last segment if streaming
-        yield detokenizer.last_segment
-
-    detokenizer.finalize()
-    yield detokenizer.last_segment
+    with wired_limit(model, [generation_stream]):
+        detokenizer.reset()
+        for n, (token, logits) in zip(
+            range(max_tokens),
+            generate_step(prompt_tokens, model, **kwargs),
+        ):
+            if token == tokenizer.eos_token_id:
+                break
+
+            detokenizer.add_token(token)
+
+            if n == (max_tokens - 1):
+                break
+
+            yield detokenizer.last_segment, token, logits
+
+        detokenizer.finalize()
+        yield detokenizer.last_segment, token, logits
 
 
 def generate(
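With these changes `stream_generate` accepts either a prompt string or a list of token ids and yields `(segment, token, logprobs)` triples from inside a `wired_limit` context. A usage sketch, assuming the usual top-level imports from `mlx_lm`; the model repo id is only an example:

```python
# Usage sketch for the new yield signature (repo id is an arbitrary example).
from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")

for segment, token, logprobs in stream_generate(
    model, tokenizer, "Write a haiku about rivers.", max_tokens=64, temp=0.7
):
    print(segment, end="", flush=True)
print()
```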
@@ -372,7 +323,7 @@ def generate(
     verbose: bool = False,
     formatter: Optional[Callable] = None,
     **kwargs,
-) -> Union[str, Generator[str, None, None]]:
+) -> str:
     """
     Generate a complete response from the model.
 
@@ -398,7 +349,7 @@ def generate(
     prompt_tokens = mx.array(tokenizer.encode(prompt))
     detokenizer = tokenizer.detokenizer
 
-    with wired_limit(model):
+    with wired_limit(model, [generation_stream]):
         tic = time.perf_counter()
         detokenizer.reset()
         for n, (token, logprobs) in zip(
@@ -416,8 +367,7 @@ def generate(
             if formatter:
                 # We have to finalize so that the prob corresponds to the last segment
                 detokenizer.finalize()
-                with mx.stream(mx.cpu):
-                    prob = mx.exp(logprobs[token]).item()
+                prob = mx.exp(logprobs[token]).item()
                 formatter(detokenizer.last_segment, prob)
             else:
                 print(detokenizer.last_segment, end="", flush=True)
@@ -438,7 +388,7 @@ def generate(
                 f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
             )
             print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
-            peak_mem = mx.metal.get_peak_memory() / 2**30
+            peak_mem = mx.metal.get_peak_memory() / 1e9
             print(f"Peak memory: {peak_mem:.3f} GB")
 
     return detokenizer.text
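The peak-memory line switches the divisor from `2**30` (binary GiB) to `1e9` (decimal GB), so the reported number grows by roughly 7% for the same byte count; a quick check:

```python
peak_bytes = 8 * 2**30        # say the true peak was 8 GiB
print(peak_bytes / 2**30)     # 8.0   (old divisor)
print(peak_bytes / 1e9)       # ~8.59 (new divisor)
```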
@@ -623,7 +573,9 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
         f"""
 # {upload_repo}
 
-The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**.
+The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
+converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path})
+using mlx-lm version **{__version__}**.
 
 ## Use with mlx
 