From 3783156072eda739f0f0063f634bfdf23ebecd65 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 27 Oct 2024 10:02:45 -0700 Subject: [PATCH] starting --- llms/mlx_lm/generate.py | 5 +- llms/mlx_lm/sample_utils.py | 90 ++++++++++++++++++++++++ llms/mlx_lm/server.py | 20 ++++-- llms/mlx_lm/utils.py | 133 ++++++++---------------------------- 4 files changed, 138 insertions(+), 110 deletions(-) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 29976da2..1820dd36 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -238,6 +238,8 @@ def main(): raise ValueError("Cannot use --colorize with --verbose=False") formatter = colorprint_by_t0 if args.colorize else None + sampler = make_sampler( + args.temp, args.top_p, args.min_p, args.min_tokens_to_keep) response = generate( model, tokenizer, @@ -245,8 +247,7 @@ def main(): args.max_tokens, verbose=args.verbose, formatter=formatter, - temp=args.temp, - top_p=args.top_p, + sampler=sampler, max_kv_size=args.max_kv_size, prompt_cache=prompt_cache if using_cache else None, kv_bits=args.kv_bits, diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py index 20b008fa..f1a5c1bb 100644 --- a/llms/mlx_lm/sample_utils.py +++ b/llms/mlx_lm/sample_utils.py @@ -5,6 +5,63 @@ from functools import partial import mlx.core as mx +def make_sampler( + temp: float = 0.0, + top_p: float = 0.0, + min_p: float = 0.0, + min_tokens_to_keep: int = 1, +) -> Callable[mx.array, mx.array]: + """ + Make a sampler function for use with ``generate_step``. + + Args: + temp (float): The temperature for sampling, if 0 the argmax is used. + Default: ``0``. + top_p (float, optional): Nulceus sampling, higher means model considers + more less likely words. + min_p (float, optional): The minimum value (scaled by the top token's + probability) that a token probability must have to be considered. + min_tokens_to_keep (int, optional): Minimum number of tokens that cannot + be filtered by min_p sampling. + + Returns: + Callabel[mx.array, mx.array]: + A sampler which takes log-probabilities and returns tokens. + """ + if temp == 0: + return lambda x: mx.argmax(x, axis=-1) + elif top_p > 0 and top_p < 1.0: + return lambda x: top_p_sampling(x, top_p, temp) + elif min_p != 0.0: + return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp) + else: + return lambda x: categorical_sampling(x, temp) + + +def make_logits_processors(): + """ + Make logits processors for use with ``generate_step``. + + Args: + repetition_penalty (float, optional): The penalty factor for repeating + tokens. + repetition_context_size (int, optional): The number of tokens to + consider for repetition penalty. Default: ``20``. + logit_bias (dictionary, optional): Additive logit bias. + """ + + logits_processors = [] + if logit_bias: + indices = mx.array(list(logit_bias.keys())) + values = mx.array(list(logit_bias.values())) + + def logit_bias_processor(_, logits): + logits[:, indices] += values + return logits + + logits_processors.append(logit_bias_processor) + + @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def min_p_sampling( logits: mx.array, @@ -100,3 +157,36 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def categorical_sampling(logits, temp): return mx.random.categorical(logits * (1 / temp)) + + +def repetition_penalty(penalty: float, context_size: int = 20): + """ + Make repetition penalty processor. + + Paper: https://arxiv.org/abs/1909.05858 + + Args: + penalty (float): The repetition penalty factor to be applied. + context_size (int): The number of previous tokens to use. + Default: ``20``. + + Returns: + Callable[[mx.array, List[int]], mx.array]: + The repetition penalty processor. + """ + if penalty < 0 or not isinstance(penalty, float): + raise ValueError(f"penalty must be a non-negative float, got {penalty}") + + def repetition_penalty_processor(logits, tokens): + if len(tokens) > 0: + tokens = tokens[-context_size:] + selected_logits = logits[:, tokens] + selected_logits = mx.where( + selected_logits < 0, + selected_logits * penalty, + selected_logits / penalty, + ) + logits[:, tokens] = selected_logits + return logits + + return repetition_penalty_processor diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index ec659969..e0d0921c 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -461,12 +461,13 @@ class APIHandler(BaseHTTPRequestHandler): token_logprobs = [] top_tokens = [] - prompt = self.get_prompt_cache(prompt) + prompt = mx.array(self.get_prompt_cache(prompt)) + tic = time.perf_counter() for _, (token, logprobs) in zip( range(self.max_tokens), generate_step( - prompt=mx.array(prompt), + prompt=prompt, model=self.model, temp=self.temperature, top_p=self.top_p, @@ -476,6 +477,10 @@ class APIHandler(BaseHTTPRequestHandler): prompt_cache=self.prompt_cache.cache, ), ): + if n == 0: + prompt_time = time.perf_counter() - tic + tic = time.perf_counter() + detokenizer.add_token(token) logging.debug(detokenizer.text) tokens.append(token) @@ -507,6 +512,10 @@ class APIHandler(BaseHTTPRequestHandler): if stop_sequence_suffix is None else detokenizer.text[: -len(stop_sequence_suffix)] ) + gen_time = time.perf_counter() - tic + prompt_tps = len(prompt) / prompt_time + gen_tps = len(tokens) / gen_time + peak_mem = mx.metal.get_peak_memory() / 1e9 response = self.generate_response( text, finish_reason, @@ -517,6 +526,9 @@ class APIHandler(BaseHTTPRequestHandler): tokens=tokens, ) + logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec") + logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec") + logging.debug(f"Peak memory: {peak_mem:.3f} GB") response_json = json.dumps(response).encode() indent = "\t" # Backslashes can't be inside of f-strings logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}") @@ -552,12 +564,12 @@ class APIHandler(BaseHTTPRequestHandler): stop_sequence_suffix = None logging.debug(f"Starting stream:") - prompt = self.get_prompt_cache(prompt) + prompt = mx.array(self.get_prompt_cache(prompt)) for _, (token, _) in zip( range(self.max_tokens), generate_step( - prompt=mx.array(prompt), + prompt=prompt, model=self.model, temp=self.temperature, top_p=self.top_p, diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 7b440db6..1f1da440 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -34,6 +34,9 @@ MODEL_REMAPPING = { MAX_FILE_SIZE_GB = 5 +# A stream on the default device just for generation +generation_stream = mx.new_stream(mx.default_device()) + class ModelNotFoundError(Exception): def __init__(self, message): @@ -137,29 +140,6 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path return model_path -def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float): - """ - Apply repetition penalty to specific logits based on the given context. - - Paper: https://arxiv.org/abs/1909.05858 - - Args: - logits (mx.array): The logits produced by the language model. - tokens (mx.array): A list of N previous tokens. - penalty (float): The repetition penalty factor to be applied. - - Returns: - logits (mx.array): Logits with repetition penalty applied to generated tokens. - """ - if len(tokens) > 0: - selected_logits = logits[:, tokens] - selected_logits = mx.where( - selected_logits < 0, selected_logits * penalty, selected_logits / penalty - ) - logits[:, tokens] = selected_logits - return logits - - def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits): if ( kv_bits is not None @@ -175,17 +155,11 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_ def generate_step( prompt: mx.array, model: nn.Module, - temp: float = 0.0, - repetition_penalty: Optional[float] = None, - repetition_context_size: Optional[int] = 20, - top_p: float = 1.0, - min_p: float = 0.0, - min_tokens_to_keep: int = 1, prefill_step_size: int = 512, max_kv_size: Optional[int] = None, prompt_cache: Optional[Any] = None, - logit_bias: Optional[Dict[int, float]] = None, - logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, + sampler: Optional[Callable[mx.array, mx.array]] = None, + logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, kv_bits: Optional[int] = None, kv_group_size: int = 64, quantized_kv_start: int = 0, @@ -196,25 +170,15 @@ def generate_step( Args: prompt (mx.array): The input prompt. model (nn.Module): The model to use for generation. - temp (float): The temperature for sampling, if 0 the argmax is used. - Default: ``0``. - repetition_penalty (float, optional): The penalty factor for repeating - tokens. - repetition_context_size (int, optional): The number of tokens to - consider for repetition penalty. Default: ``20``. - top_p (float, optional): Nulceus sampling, higher means model considers - more less likely words. - min_p (float, optional): The minimum value (scaled by the top token's - probability) that a token probability must have to be considered. - min_tokens_to_keep (int, optional): Minimum number of tokens that cannot - be filtered by min_p sampling. prefill_step_size (int): Step size for processing the prompt. max_kv_size (int, optional): Maximum size of the key-value cache. Old entries (except the first 4 tokens) will be overwritten. prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if provided, the cache will be updated in place. - logit_bias (dictionary, optional): Additive logit bias. - logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional): + sampler (Callable[mx.array, mx.array], optional). A function which + takes log probabilities and returns tokens. If ``None`` then the + argmax is used. Default: ``None``. + logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): A list of functions that take tokens and logits and return the processed logits. Default: ``None``. kv_bits (int, optional): Number of bits to use for KV cache quantization. @@ -228,49 +192,6 @@ def generate_step( one token and a vector of log probabilities. """ - def sample(logits: mx.array) -> Tuple[mx.array, float]: - logprobs = logits - mx.logsumexp(logits) - - if temp == 0: - token = mx.argmax(logits, axis=-1) - else: - if top_p > 0 and top_p < 1.0: - token = top_p_sampling(logits, top_p, temp) - elif min_p != 0.0: - token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp) - else: - token = categorical_sampling(logits, temp) - - return token, logprobs - - if repetition_penalty and ( - repetition_penalty < 0 or not isinstance(repetition_penalty, float) - ): - raise ValueError( - f"repetition_penalty must be a non-negative float, got {repetition_penalty}" - ) - - logits_processor = logits_processor or [] - - if repetition_penalty: - - def repetition_penalty_processor(tokens, logits): - return apply_repetition_penalty( - logits, tokens[-repetition_context_size:], repetition_penalty - ) - - logits_processor.append(repetition_penalty_processor) - - if logit_bias: - indices = mx.array(list(logit_bias.keys())) - values = mx.array(list(logit_bias.values())) - - def logit_bias_processor(_, logits): - logits[:, indices] += values - return logits - - logits_processor.append(logit_bias_processor) - y = prompt tokens = None @@ -283,24 +204,27 @@ def generate_step( elif len(prompt_cache) != len(model.layers): raise ValueError("Wrong number of layers in the prompt cache.") + sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) + def _step(y): + with mx.stream(generation_stream): + logits = model(y[None], cache=prompt_cache) + logits = logits[:, -1, :] - logits = model(y[None], cache=prompt_cache) - logits = logits[:, -1, :] + if logits_processors: + nonlocal tokens + tokens = mx.concat([tokens, y]) if tokens is not None else y - if logits_processor: - nonlocal tokens - tokens = mx.concat([tokens, y]) if tokens is not None else y + for processor in logits_processors: + logits = processor(tokens, logits) - for processor in logits_processor: - logits = processor(tokens, logits) + maybe_quantize_kv_cache( + prompt_cache, quantized_kv_start, kv_group_size, kv_bits + ) - maybe_quantize_kv_cache( - prompt_cache, quantized_kv_start, kv_group_size, kv_bits - ) - - y, logprobs = sample(logits) - return y, logprobs.squeeze(0) + logprobs = logits - mx.logsumexp(logits) + y = sampler(logprobs) + return y, logprobs.squeeze(0) while y.size > prefill_step_size: model(y[:prefill_step_size][None], cache=prompt_cache) @@ -416,8 +340,7 @@ def generate( if formatter: # We have to finalize so that the prob corresponds to the last segment detokenizer.finalize() - with mx.stream(mx.cpu): - prob = mx.exp(logprobs[token]).item() + prob = mx.exp(logprobs[token]).item() formatter(detokenizer.last_segment, prob) else: print(detokenizer.last_segment, end="", flush=True) @@ -623,7 +546,9 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str): f""" # {upload_repo} - The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**. + The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was + converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) + using mlx-lm version **{__version__}**. ## Use with mlx