# Copyright © 2023-2024 Apple Inc.

import copy
import glob
import importlib
import json
import logging
import shutil
import time
from pathlib import Path
from textwrap import dedent
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union

import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizer

# Local imports
from .models import base, cache
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import dequantize as dequantize_model
from .tuner.utils import load_adapters

# Constants
MODEL_REMAPPING = {
    "mistral": "llama",  # mistral is compatible with llama
    "phi-msft": "phixtral",
}

MAX_FILE_SIZE_GB = 5


class ModelNotFoundError(Exception):
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)


def _get_classes(config: dict):
    """
    Retrieve the model and model args classes based on the configuration.

    Args:
        config (dict): The model configuration.

    Returns:
        A tuple containing the Model class and the ModelArgs class.
    """
    model_type = config["model_type"]
    model_type = MODEL_REMAPPING.get(model_type, model_type)
    try:
        arch = importlib.import_module(f"mlx_lm.models.{model_type}")
    except ImportError:
        msg = f"Model type {model_type} not supported."
        logging.error(msg)
        raise ValueError(msg)

    return arch.Model, arch.ModelArgs


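# Illustrative sketch (comments only, not executed): a config with model_type
# "mistral" is remapped to "llama" via MODEL_REMAPPING, so the returned classes
# come from mlx_lm.models.llama:
#
#     model_class, model_args_class = _get_classes({"model_type": "mistral"})

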
def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
    """
    Ensures the model is available locally. If the path does not exist locally,
    it is downloaded from the Hugging Face Hub.

    Args:
        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
        revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.

    Returns:
        Path: The path to the model.
    """
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        try:
            model_path = Path(
                snapshot_download(
                    repo_id=path_or_hf_repo,
                    revision=revision,
                    allow_patterns=[
                        "*.json",
                        "*.safetensors",
                        "*.py",
                        "tokenizer.model",
                        "*.tiktoken",
                        "*.txt",
                    ],
                )
            )
        except:
            raise ModelNotFoundError(
                f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
                "Please make sure you specified the local path or Hugging Face"
                " repo id correctly.\nIf you are trying to access a private or"
                " gated Hugging Face repo, make sure you are authenticated:\n"
                "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
            ) from None
    return model_path


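# Illustrative sketch (comments only; the repo id below is a placeholder). A
# local directory is returned as-is; anything else is treated as a Hub repo id
# and downloaded:
#
#     path = get_model_path("mlx-community/My-Model-4bit")
#     config = load_config(path)

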
def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float):
    """
    Apply repetition penalty to specific logits based on the given context.

    Paper: https://arxiv.org/abs/1909.05858

    Args:
        logits (mx.array): The logits produced by the language model.
        tokens (mx.array): A list of N previous tokens.
        penalty (float): The repetition penalty factor to be applied.

    Returns:
        logits (mx.array): Logits with repetition penalty applied to generated tokens.
    """
    if len(tokens) > 0:
        selected_logits = logits[:, tokens]
        selected_logits = mx.where(
            selected_logits < 0, selected_logits * penalty, selected_logits / penalty
        )
        logits[:, tokens] = selected_logits
    return logits


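# Worked example (made-up numbers): with penalty=1.2, a previously generated
# token whose logit is 2.0 becomes 2.0 / 1.2 ≈ 1.67, while one whose logit is
# -1.0 becomes -1.0 * 1.2 = -1.2. Either way the token becomes less likely,
# which is why the sign is checked before dividing or multiplying.

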
def generate_step(
    prompt: mx.array,
    model: nn.Module,
    temp: float = 0.0,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = 20,
    top_p: float = 1.0,
    min_p: float = 0.0,
    min_tokens_to_keep: int = 1,
    prefill_step_size: int = 512,
    max_kv_size: Optional[int] = None,
    prompt_cache: Optional[Any] = None,
    logit_bias: Optional[Dict[int, float]] = None,
    logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """
    A generator producing token ids based on the given prompt from the model.

    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        temp (float): The temperature for sampling, if 0 the argmax is used.
          Default: ``0``.
        repetition_penalty (float, optional): The penalty factor for repeating
          tokens.
        repetition_context_size (int, optional): The number of tokens to
          consider for repetition penalty. Default: ``20``.
        top_p (float, optional): Nucleus sampling, higher means model considers
          more less likely words.
        min_p (float, optional): The minimum value (scaled by the top token's
          probability) that a token probability must have to be considered.
        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
          be filtered by min_p sampling.
        prefill_step_size (int): Step size for processing the prompt.
        max_kv_size (int, optional): Maximum size of the key-value cache. Old
          entries (except the first 4 tokens) will be overwritten.
        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
          provided, the cache will be updated in place.
        logit_bias (dictionary, optional): Additive logit bias.
        logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
          A list of functions that take tokens and logits and return the processed
          logits. Default: ``None``.

    Yields:
        Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
          one token and a vector of log probabilities.
    """

    def sample(logits: mx.array) -> Tuple[mx.array, mx.array]:
        logprobs = logits - mx.logsumexp(logits)

        if temp == 0:
            token = mx.argmax(logits, axis=-1)
        else:
            if top_p > 0 and top_p < 1.0:
                token = top_p_sampling(logits, top_p, temp)
            elif min_p != 0.0:
                token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
            else:
                token = categorical_sampling(logits, temp)

        return token, logprobs

    if repetition_penalty and (
        repetition_penalty < 0 or not isinstance(repetition_penalty, float)
    ):
        raise ValueError(
            f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
        )

    logits_processor = logits_processor or []

    if repetition_penalty:

        def repetition_penalty_processor(tokens, logits):
            return apply_repetition_penalty(
                logits, tokens[-repetition_context_size:], repetition_penalty
            )

        logits_processor.append(repetition_penalty_processor)

    if logit_bias:
        indices = mx.array(list(logit_bias.keys()))
        values = mx.array(list(logit_bias.values()))

        def logit_bias_processor(_, logits):
            logits[:, indices] += values
            return logits

        logits_processor.append(logit_bias_processor)

    y = prompt
    tokens = None

    # Create the KV cache for generation
    if prompt_cache is None:
        prompt_cache = cache.make_prompt_cache(model, max_kv_size)
    elif len(prompt_cache) != len(model.layers):
        raise ValueError("Wrong number of layers in the prompt cache.")

    def _step(y):
        logits = model(y[None], cache=prompt_cache)
        logits = logits[:, -1, :]

        if logits_processor:
            nonlocal tokens
            tokens = mx.concat([tokens, y]) if tokens is not None else y

            for processor in logits_processor:
                logits = processor(tokens, logits)

        y, logprobs = sample(logits)
        return y, logprobs.squeeze(0)

    # Prefill: process the prompt in chunks, filling the prompt cache before sampling
    while y.size > prefill_step_size:
        model(y[:prefill_step_size][None], cache=prompt_cache)
        mx.eval([c.state for c in prompt_cache])
        y = y[prefill_step_size:]

    y, logprobs = _step(y)

    mx.async_eval(y)
    while True:
        next_y, next_logprobs = _step(y)
        mx.async_eval(next_y)
        yield y.item(), logprobs
        y, logprobs = next_y, next_logprobs


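# Illustrative sketch (comments only; the repo id is a placeholder): driving
# generate_step directly instead of going through generate()/stream_generate().
# The sketch stops after 64 tokens; callers normally also check for
# tokenizer.eos_token_id.
#
#     model, tokenizer = load("mlx-community/My-Model-4bit")
#     prompt = mx.array(tokenizer.encode("hello"))
#     for (token, logprobs), _ in zip(generate_step(prompt, model, temp=0.8), range(64)):
#         print(tokenizer.decode([token]), end="", flush=True)

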
def stream_generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: str,
    max_tokens: int = 100,
    **kwargs,
) -> Generator[str, None, None]:
    """
    A generator producing text based on the given prompt from the model.

    Args:
        model (nn.Module): The model to use for generation.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (str): The input prompt.
        max_tokens (int): The maximum number of tokens. Default: ``100``.
        kwargs: The remaining options get passed to :func:`generate_step`.
          See :func:`generate_step` for more details.

    Yields:
        Generator[str, None, None]: A generator producing text.
    """
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    prompt_tokens = mx.array(tokenizer.encode(prompt))
    detokenizer = tokenizer.detokenizer

    detokenizer.reset()
    for n, (token, _) in zip(
        range(max_tokens),
        generate_step(prompt_tokens, model, **kwargs),
    ):
        if token == tokenizer.eos_token_id:
            break
        detokenizer.add_token(token)

        # Yield the last segment if streaming
        yield detokenizer.last_segment

    detokenizer.finalize()
    yield detokenizer.last_segment


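# Illustrative sketch (comments only; the repo id is a placeholder): streaming
# text to stdout as it is produced.
#
#     model, tokenizer = load("mlx-community/My-Model-4bit")
#     for segment in stream_generate(model, tokenizer, "Write a haiku.", max_tokens=128):
#         print(segment, end="", flush=True)
#     print()

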
def generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: str,
    max_tokens: int = 100,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    **kwargs,
) -> str:
    """
    Generate a complete response from the model.

    Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer): The tokenizer.
        prompt (str): The string prompt.
        max_tokens (int): The maximum number of tokens. Default: ``100``.
        verbose (bool): If ``True``, print tokens and timing information.
          Default: ``False``.
        formatter (Optional[Callable]): A function which takes a token and a
          probability and displays it.
        kwargs: The remaining options get passed to :func:`generate_step`.
          See :func:`generate_step` for more details.

    Returns:
        str: The generated text.
    """
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    if verbose:
        print("=" * 10)
        print("Prompt:", prompt)

    prompt_tokens = mx.array(tokenizer.encode(prompt))
    detokenizer = tokenizer.detokenizer

    tic = time.perf_counter()
    detokenizer.reset()

    for n, (token, logprobs) in zip(
        range(max_tokens),
        generate_step(prompt_tokens, model, **kwargs),
    ):
        if n == 0:
            prompt_time = time.perf_counter() - tic
            tic = time.perf_counter()
        if token == tokenizer.eos_token_id:
            break
        detokenizer.add_token(token)

        if verbose:
            if formatter:
                # We have to finalize so that the prob corresponds to the last segment
                detokenizer.finalize()
                formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
            else:
                print(detokenizer.last_segment, end="", flush=True)

    token_count = n + 1
    detokenizer.finalize()

    if verbose:
        gen_time = time.perf_counter() - tic
        print(detokenizer.last_segment, flush=True)
        print("=" * 10)
        if token_count == 0:
            print("No tokens generated for this prompt")
            return
        prompt_tps = prompt_tokens.size / prompt_time
        gen_tps = (token_count - 1) / gen_time
        print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
        print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
        peak_mem = mx.metal.get_peak_memory() / 2**30
        print(f"Peak memory: {peak_mem:.3f} GB")

    return detokenizer.text


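# Illustrative sketch (comments only; the repo id is a placeholder): the usual
# high-level entry point, mirroring the snippet written into the model card by
# upload_to_hub below.
#
#     model, tokenizer = load("mlx-community/My-Model-4bit")
#     text = generate(model, tokenizer, prompt="hello", max_tokens=100, verbose=True)

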
def load_config(model_path: Path) -> dict:
    try:
        with open(model_path / "config.json", "r") as f:
            config = json.load(f)
    except FileNotFoundError:
        logging.error(f"Config file not found in {model_path}")
        raise
    return config


def load_model(
    model_path: Path,
    lazy: bool = False,
    model_config: dict = {},
    get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
) -> nn.Module:
    """
    Load and initialize the model from a given path.

    Args:
        model_path (Path): The path to load the model from.
        lazy (bool): If ``False``, evaluate the model parameters to make sure they
            are loaded in memory before returning; otherwise they will be loaded
            when needed. Default: ``False``
        model_config (dict, optional): Configuration parameters for the model.
            Defaults to an empty dictionary.
        get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
            A function that returns the model class and model args class given a config.
            Defaults to the _get_classes function.

    Returns:
        nn.Module: The loaded and initialized model.

    Raises:
        FileNotFoundError: If the weight files (.safetensors) are not found.
        ValueError: If the model class or args class are not found or cannot be instantiated.
    """
    config = load_config(model_path)
    config.update(model_config)

    weight_files = glob.glob(str(model_path / "model*.safetensors"))

    if not weight_files:
        # Try weight for back-compat
        weight_files = glob.glob(str(model_path / "weight*.safetensors"))

    if not weight_files:
        logging.error(f"No safetensors found in {model_path}")
        raise FileNotFoundError(f"No safetensors found in {model_path}")

    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf))

    model_class, model_args_class = get_model_classes(config=config)

    model_args = model_args_class.from_dict(config)
    model = model_class(model_args)

    if hasattr(model, "sanitize"):
        weights = model.sanitize(weights)

    if (quantization := config.get("quantization", None)) is not None:
        # Handle legacy models which may not have everything quantized
        def class_predicate(p, m):
            if not hasattr(m, "to_quantized"):
                return False
            return f"{p}.scales" in weights

        nn.quantize(
            model,
            **quantization,
            class_predicate=class_predicate,
        )

    model.load_weights(list(weights.items()))

    if not lazy:
        mx.eval(model.parameters())

    model.eval()
    return model


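# Illustrative sketch (comments only): ``model_config`` is merged into the
# on-disk config.json before the model is built, so individual entries can be
# overridden at load time. The key below is a hypothetical override, not one
# every architecture accepts:
#
#     model = load_model(Path("mlx_model"), model_config={"rope_theta": 1000000})

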
def load(
    path_or_hf_repo: str,
    tokenizer_config={},
    model_config={},
    adapter_path: Optional[str] = None,
    lazy: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
    """
    Load the model and tokenizer from a given path or a huggingface repository.

    Args:
        path_or_hf_repo (str): The path or the huggingface repository to load the model from.
        tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
            Defaults to an empty dictionary.
        model_config (dict, optional): Configuration parameters specifically for the model.
            Defaults to an empty dictionary.
        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
            to the model. Default: ``None``.
        lazy (bool): If ``False``, evaluate the model parameters to make sure they
            are loaded in memory before returning; otherwise they will be loaded
            when needed. Default: ``False``

    Returns:
        Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.

    Raises:
        FileNotFoundError: If config file or safetensors are not found.
        ValueError: If model class or args class are not found.
    """
    model_path = get_model_path(path_or_hf_repo)

    model = load_model(model_path, lazy, model_config)
    if adapter_path is not None:
        model = load_adapters(model, adapter_path)
        model.eval()
    tokenizer = load_tokenizer(model_path, tokenizer_config)

    return model, tokenizer


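# Illustrative sketch (comments only; the repo id and adapter directory are
# placeholders): loading a model together with trained LoRA adapters.
#
#     model, tokenizer = load(
#         "mlx-community/My-Model-4bit",
#         adapter_path="adapters",
#         tokenizer_config={"trust_remote_code": True},
#     )

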
def fetch_from_hub(
    model_path: Path, lazy: bool = False
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
    model = load_model(model_path, lazy)
    config = load_config(model_path)
    tokenizer = load_tokenizer(model_path)
    return model, config, tokenizer


def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
    """
    Splits the weights into smaller shards.

    Args:
        weights (dict): Model weights.
        max_file_size_gb (int): Maximum size of each shard in gigabytes.

    Returns:
        list: List of weight shards.
    """
    max_file_size_bytes = max_file_size_gb << 30
    shards = []
    shard, shard_size = {}, 0
    for k, v in weights.items():
        if shard_size + v.nbytes > max_file_size_bytes:
            shards.append(shard)
            shard, shard_size = {}, 0
        shard[k] = v
        shard_size += v.nbytes
    shards.append(shard)
    return shards


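# Worked example (made-up sizes): with the default 5 GB limit, a weight dict
# whose tensors total roughly 13 GB ends up in three shards, since a new shard
# is started whenever adding the next tensor would push the current one past
# max_file_size_gb << 30 bytes.

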
def upload_to_hub(path: str, upload_repo: str, hf_path: str):
    """
    Uploads the model to Hugging Face hub.

    Args:
        path (str): Local path to the model.
        upload_repo (str): Name of the HF repo to upload to.
        hf_path (str): Path to the original Hugging Face model.
    """
    import os

    from huggingface_hub import HfApi, ModelCard, logging

    from . import __version__

    card = ModelCard.load(hf_path)
    card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
    card.data.base_model = hf_path
    card.text = dedent(
        f"""
        # {upload_repo}

        The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**.

        ## Use with mlx

        ```bash
        pip install mlx-lm
        ```

        ```python
        from mlx_lm import load, generate

        model, tokenizer = load("{upload_repo}")

        prompt="hello"

        if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
            messages = [{{"role": "user", "content": prompt}}]
            prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

        response = generate(model, tokenizer, prompt=prompt, verbose=True)
        ```
        """
    )
    card.save(os.path.join(path, "README.md"))

    logging.set_verbosity_info()

    api = HfApi()
    api.create_repo(repo_id=upload_repo, exist_ok=True)
    api.upload_folder(
        folder_path=path,
        repo_id=upload_repo,
        repo_type="model",
        multi_commits=True,
        multi_commits_verbose=True,
    )
    print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")


def save_weights(
    save_path: Union[str, Path],
    weights: Dict[str, Any],
    *,
    donate_weights: bool = False,
) -> None:
    """Save model weights into specified directory."""
    if isinstance(save_path, str):
        save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)

    shards = make_shards(weights)
    shards_count = len(shards)
    shard_file_format = (
        "model-{:05d}-of-{:05d}.safetensors"
        if shards_count > 1
        else "model.safetensors"
    )

    total_size = sum(v.nbytes for v in weights.values())
    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}

    # Write the weights and make sure no references are kept other than the
    # necessary ones
    if donate_weights:
        weights.clear()
        del weights

    for i in range(len(shards)):
        shard = shards[i]
        shards[i] = None
        shard_name = shard_file_format.format(i + 1, shards_count)
        shard_path = save_path / shard_name

        mx.save_safetensors(str(shard_path), shard, metadata={"format": "mlx"})

        for weight_name in shard.keys():
            index_data["weight_map"][weight_name] = shard_name
        del shard

    index_data["weight_map"] = {
        k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
    }

    with open(save_path / "model.safetensors.index.json", "w") as f:
        json.dump(
            index_data,
            f,
            indent=4,
        )


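# Illustrative sketch of the resulting layout (file names follow
# shard_file_format above; two shards assumed for the example):
#
#     mlx_model/
#         model-00001-of-00002.safetensors
#         model-00002-of-00002.safetensors
#         model.safetensors.index.json   # maps each weight name to its shard

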
def quantize_model(
    model: nn.Module, config: dict, q_group_size: int, q_bits: int
) -> Tuple:
    """
    Applies quantization to the model weights.

    Args:
        model (nn.Module): The model to be quantized.
        config (dict): Model configuration.
        q_group_size (int): Group size for quantization.
        q_bits (int): Bits per weight for quantization.

    Returns:
        Tuple: Tuple containing quantized weights and config.
    """
    quantized_config = copy.deepcopy(config)
    nn.quantize(model, q_group_size, q_bits)
    quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
    # support hf model tree #957
    quantized_config["quantization_config"] = quantized_config["quantization"]
    quantized_weights = dict(tree_flatten(model.parameters()))

    return quantized_weights, quantized_config


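# Illustrative sketch (comments only): after quantizing to 4 bits with group
# size 64, the returned config carries the parameters needed to re-quantize the
# modules at load time (see the ``quantization`` handling in load_model):
#
#     weights, config = quantize_model(model, config, q_group_size=64, q_bits=4)
#     # config["quantization"] == {"group_size": 64, "bits": 4}

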
def save_config(
    config: dict,
    config_path: Union[str, Path],
) -> None:
    """Save the model configuration to the ``config_path``.

    The final configuration will be sorted before saving for better readability.

    Args:
        config (dict): The model configuration.
        config_path (Union[str, Path]): Model configuration file path.
    """
    # Clean unused keys
    config.pop("_name_or_path", None)

    # sort the config for better readability
    config = dict(sorted(config.items()))

    # write the updated config to the config_path (if provided)
    with open(config_path, "w") as fid:
        json.dump(config, fid, indent=4)


def convert(
    hf_path: str,
    mlx_path: str = "mlx_model",
    quantize: bool = False,
    q_group_size: int = 64,
    q_bits: int = 4,
    dtype: str = "float16",
    upload_repo: Optional[str] = None,
    revision: Optional[str] = None,
    dequantize: bool = False,
):
    # Check the save path is empty
    if isinstance(mlx_path, str):
        mlx_path = Path(mlx_path)

    if mlx_path.exists():
        raise ValueError(
            f"Cannot save to the path {mlx_path} as it already exists."
            " Please delete the file/directory or specify a new path to save to."
        )

    print("[INFO] Loading")
    model_path = get_model_path(hf_path, revision=revision)
    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)

    weights = dict(tree_flatten(model.parameters()))
    dtype = mx.float16 if quantize else getattr(mx, dtype)
    weights = {k: v.astype(dtype) for k, v in weights.items()}

    if quantize and dequantize:
        raise ValueError("Choose either quantize or dequantize, not both.")

    if quantize:
        print("[INFO] Quantizing")
        model.load_weights(list(weights.items()))
        weights, config = quantize_model(model, config, q_group_size, q_bits)

    if dequantize:
        print("[INFO] Dequantizing")
        model = dequantize_model(model)
        weights = dict(tree_flatten(model.parameters()))

    del model
    save_weights(mlx_path, weights, donate_weights=True)

    py_files = glob.glob(str(model_path / "*.py"))
    for file in py_files:
        shutil.copy(file, mlx_path)

    tokenizer.save_pretrained(mlx_path)

    save_config(config, config_path=mlx_path / "config.json")

    if upload_repo is not None:
        upload_to_hub(mlx_path, upload_repo, hf_path)


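# Illustrative sketch (comments only; repo ids are placeholders): converting a
# Hub model to 4-bit MLX format, optionally uploading the result.
#
#     convert(
#         "some-org/Some-Model",
#         mlx_path="mlx_model",
#         quantize=True,
#         upload_repo=None,  # e.g. "your-user/Some-Model-4bit" to upload
#     )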