[MLX LM] Sampler refactor + a few improvements (#1094)

* starting

* refactor sampler/processor and a few improvements

* fix stream

* fix stream generate

* fix eos handling in stream generate
This commit is contained in:
Awni Hannun
2024-11-07 16:15:24 -08:00
committed by GitHub
parent ed9e81dd58
commit 657b4cc0aa
10 changed files with 259 additions and 239 deletions

View File

@@ -20,7 +20,7 @@ from transformers import PreTrainedTokenizer
# Local imports
from .models import cache
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
from .sample_utils import make_logits_processors, make_sampler
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import dequantize as dequantize_model
from .tuner.utils import load_adapters
@@ -34,6 +34,9 @@ MODEL_REMAPPING = {
MAX_FILE_SIZE_GB = 5
# A stream on the default device just for generation
generation_stream = mx.new_stream(mx.default_device())
class ModelNotFoundError(Exception):
def __init__(self, message):
@@ -137,29 +140,6 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
return model_path
def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float):
"""
Apply repetition penalty to specific logits based on the given context.
Paper: https://arxiv.org/abs/1909.05858
Args:
logits (mx.array): The logits produced by the language model.
tokens (mx.array): A list of N previous tokens.
penalty (float): The repetition penalty factor to be applied.
Returns:
logits (mx.array): Logits with repetition penalty applied to generated tokens.
"""
if len(tokens) > 0:
selected_logits = logits[:, tokens]
selected_logits = mx.where(
selected_logits < 0, selected_logits * penalty, selected_logits / penalty
)
logits[:, tokens] = selected_logits
return logits
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
if (
kv_bits is not None
@@ -185,7 +165,7 @@ def generate_step(
max_kv_size: Optional[int] = None,
prompt_cache: Optional[Any] = None,
logit_bias: Optional[Dict[int, float]] = None,
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
@@ -214,7 +194,7 @@ def generate_step(
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place.
logit_bias (dictionary, optional): Additive logit bias.
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
kv_bits (int, optional): Number of bits to use for KV cache quantization.
@@ -224,53 +204,9 @@ def generate_step(
when ``kv_bits`` is non-None. Default: ``0``.
Yields:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
one token and a vector of log probabilities.
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
"""
def sample(logits: mx.array) -> Tuple[mx.array, float]:
logprobs = logits - mx.logsumexp(logits)
if temp == 0:
token = mx.argmax(logits, axis=-1)
else:
if top_p > 0 and top_p < 1.0:
token = top_p_sampling(logits, top_p, temp)
elif min_p != 0.0:
token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
else:
token = categorical_sampling(logits, temp)
return token, logprobs
if repetition_penalty and (
repetition_penalty < 0 or not isinstance(repetition_penalty, float)
):
raise ValueError(
f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
)
logits_processor = logits_processor or []
if repetition_penalty:
def repetition_penalty_processor(tokens, logits):
return apply_repetition_penalty(
logits, tokens[-repetition_context_size:], repetition_penalty
)
logits_processor.append(repetition_penalty_processor)
if logit_bias:
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))
def logit_bias_processor(_, logits):
logits[:, indices] += values
return logits
logits_processor.append(logit_bias_processor)
y = prompt
tokens = None
@@ -283,24 +219,31 @@ def generate_step(
elif len(prompt_cache) != len(model.layers):
raise ValueError("Wrong number of layers in the prompt cache.")
sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
logits_processors = logits_processors or []
logits_processors.extend(
make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
)
def _step(y):
with mx.stream(generation_stream):
logits = model(y[None], cache=prompt_cache)
logits = logits[:, -1, :]
logits = model(y[None], cache=prompt_cache)
logits = logits[:, -1, :]
if logits_processors:
nonlocal tokens
tokens = mx.concat([tokens, y]) if tokens is not None else y
if logits_processor:
nonlocal tokens
tokens = mx.concat([tokens, y]) if tokens is not None else y
for processor in logits_processors:
logits = processor(tokens, logits)
for processor in logits_processor:
logits = processor(tokens, logits)
maybe_quantize_kv_cache(
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
)
maybe_quantize_kv_cache(
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
)
y, logprobs = sample(logits)
return y, logprobs.squeeze(0)
logprobs = logits - mx.logsumexp(logits, keepdims=True)
y = sampler(logprobs)
return y, logprobs.squeeze(0)
while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=prompt_cache)
@@ -325,43 +268,51 @@ def generate_step(
def stream_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str,
prompt: Union[str, List[int]],
max_tokens: int = 100,
**kwargs,
) -> Union[str, Generator[str, None, None]]:
) -> Generator[Tuple[str, int, mx.array], None, None]:
"""
A generator producing text based on the given prompt from the model.
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
max_tokens (int): The ma
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (Union[str, List[int]]): The input prompt string or integer tokens.
max_tokens (int): The maximum number of tokens. Default: ``100``.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
Yields:
Generator[Tuple[mx.array, mx.array]]: A generator producing text.
Tuple[str, int, mx.array]:
The next text segment, token, and vector of log probabilities.
"""
if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
prompt_tokens = mx.array(tokenizer.encode(prompt))
prompt_tokens = mx.array(
prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
)
detokenizer = tokenizer.detokenizer
detokenizer.reset()
for n, (token, _) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
):
if token == tokenizer.eos_token_id:
break
detokenizer.add_token(token)
with wired_limit(model, [generation_stream]):
detokenizer.reset()
for n, (token, logits) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
):
if token == tokenizer.eos_token_id:
break
# Yield the last segment if streaming
yield detokenizer.last_segment
detokenizer.add_token(token)
detokenizer.finalize()
yield detokenizer.last_segment
if n == (max_tokens - 1):
break
yield detokenizer.last_segment, token, logits
detokenizer.finalize()
yield detokenizer.last_segment, token, logits
def generate(
@@ -372,7 +323,7 @@ def generate(
verbose: bool = False,
formatter: Optional[Callable] = None,
**kwargs,
) -> Union[str, Generator[str, None, None]]:
) -> str:
"""
Generate a complete response from the model.
@@ -398,7 +349,7 @@ def generate(
prompt_tokens = mx.array(tokenizer.encode(prompt))
detokenizer = tokenizer.detokenizer
with wired_limit(model):
with wired_limit(model, [generation_stream]):
tic = time.perf_counter()
detokenizer.reset()
for n, (token, logprobs) in zip(
@@ -416,8 +367,7 @@ def generate(
if formatter:
# We have to finalize so that the prob corresponds to the last segment
detokenizer.finalize()
with mx.stream(mx.cpu):
prob = mx.exp(logprobs[token]).item()
prob = mx.exp(logprobs[token]).item()
formatter(detokenizer.last_segment, prob)
else:
print(detokenizer.last_segment, end="", flush=True)
@@ -438,7 +388,7 @@ def generate(
f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
)
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
peak_mem = mx.metal.get_peak_memory() / 2**30
peak_mem = mx.metal.get_peak_memory() / 1e9
print(f"Peak memory: {peak_mem:.3f} GB")
return detokenizer.text
@@ -623,7 +573,9 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
f"""
# {upload_repo}
The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**.
The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path})
using mlx-lm version **{__version__}**.
## Use with mlx