# Copyright © 2023-2024 Apple Inc.

import copy
import glob
import importlib
import json
import logging
import shutil
import time
from pathlib import Path
from textwrap import dedent
from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
from transformers import AutoTokenizer, PreTrainedTokenizer

# Local imports
from .sample_utils import top_p_sampling
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import apply_lora_layers
from .tuner.utils import dequantize as dequantize_model

# Constants
MODEL_REMAPPING = {
    "mistral": "llama",  # mistral is compatible with llama
    "phi-msft": "phixtral",
}

MAX_FILE_SIZE_GB = 5


def _get_classes(config: dict):
    """
    Retrieve the model and model args classes based on the configuration.

    Args:
        config (dict): The model configuration.

    Returns:
        A tuple containing the Model class and the ModelArgs class.
    """
    model_type = config["model_type"]
    model_type = MODEL_REMAPPING.get(model_type, model_type)
    try:
        arch = importlib.import_module(f"mlx_lm.models.{model_type}")
    except ImportError:
        msg = f"Model type {model_type} not supported."
        logging.error(msg)
        raise ValueError(msg)

    return arch.Model, arch.ModelArgs


def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
    """
    Ensures the model is available locally. If the path does not exist locally,
    it is downloaded from the Hugging Face Hub.

    Args:
        path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
        revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.

    Returns:
        Path: The path to the model.
    """
    model_path = Path(path_or_hf_repo)
    if not model_path.exists():
        model_path = Path(
            snapshot_download(
                repo_id=path_or_hf_repo,
                revision=revision,
                allow_patterns=[
                    "*.json",
                    "*.safetensors",
                    "*.py",
                    "tokenizer.model",
                    "*.tiktoken",
                    "*.txt",
                ],
            )
        )
    return model_path


def apply_repetition_penalty(logits: mx.array, generated_tokens: Any, penalty: float):
    """
    Apply repetition penalty to specific logits based on the given context.

    Paper: https://arxiv.org/abs/1909.05858

    Args:
        logits (mx.array): The logits produced by the language model.
        generated_tokens (Any): A list of N previous tokens.
        penalty (float): The repetition penalty factor to be applied.

    Returns:
        logits (mx.array): Logits with repetition penalty applied to generated tokens.
    """
    if len(generated_tokens) > 0:
        indices = mx.array([token for token in generated_tokens])
        selected_logits = logits[:, indices]
        selected_logits = mx.where(
            selected_logits < 0, selected_logits * penalty, selected_logits / penalty
        )
        logits[:, indices] = selected_logits
    return logits
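

# Illustrative sketch (not part of the original module and never called by it):
# shows how apply_repetition_penalty() rescales the logits of previously
# generated tokens. The numbers below are made up.
def _repetition_penalty_example() -> mx.array:
    logits = mx.array([[2.0, -1.0, 0.5]])
    # Tokens 0 and 1 were generated before: the positive logit is divided by the
    # penalty and the negative one is multiplied by it, making both less likely.
    return apply_repetition_penalty(logits, [0, 1], penalty=1.3)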


def generate_step(
    prompt: mx.array,
    model: nn.Module,
    temp: float = 0.0,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = 20,
    top_p: float = 1.0,
    logit_bias: Optional[Dict[int, float]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """
    A generator producing text based on the given prompt from the model.

    Args:
        prompt (mx.array): The input prompt.
        model (nn.Module): The model to use for generation.
        temp (float): The temperature for sampling, if 0 the argmax is used.
        repetition_penalty (float, optional): The penalty factor for repeating tokens.
        repetition_context_size (int, optional): The number of tokens to consider
            for repetition penalty (default 20).
        top_p (float, optional): Nucleus sampling: higher values mean the model
            considers more of the less likely words.
        logit_bias (Dict[int, float], optional): Additive bias applied to the
            logits of the given token ids before sampling.

    Yields:
        Generator[Tuple[mx.array, mx.array]]: A generator producing
            one token and a probability per call.
    """

    def sample(logits: mx.array) -> Tuple[mx.array, float]:
        if logit_bias:
            indices = mx.array(list(logit_bias.keys()))
            values = mx.array(list(logit_bias.values()))
            logits[:, indices] += values
        softmax_logits = mx.softmax(logits)

        if temp == 0:
            token = mx.argmax(logits, axis=-1)
        else:
            if top_p > 0 and top_p < 1.0:
                token = top_p_sampling(logits, top_p, temp)
            else:
                token = mx.random.categorical(logits * (1 / temp))

        prob = softmax_logits[0, token]
        return token, prob

    if repetition_penalty and (
        repetition_penalty < 0 or not isinstance(repetition_penalty, float)
    ):
        raise ValueError(
            f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
        )

    y = prompt
    cache = None

    repetition_context = prompt.tolist()

    if repetition_context_size:
        repetition_context = repetition_context[-repetition_context_size:]

    def _step(y):
        nonlocal cache, repetition_context
        logits, cache = model(y[None], cache=cache)
        logits = logits[:, -1, :]

        if repetition_penalty:
            logits = apply_repetition_penalty(
                logits, repetition_context, repetition_penalty
            )
            y, prob = sample(logits)
            repetition_context.append(y.item())
        else:
            y, prob = sample(logits)

        if repetition_context_size:
            if len(repetition_context) > repetition_context_size:
                repetition_context = repetition_context[-repetition_context_size:]
        return y, prob

    y, p = _step(y)

    mx.async_eval(y)
    while True:
        next_y, next_p = _step(y)
        mx.async_eval(next_y)
        yield y.item(), p
        y, p = next_y, next_p
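

# Illustrative sketch (not part of the original module and never called by it):
# stream tokens one at a time with generate_step(). The repo id is hypothetical
# and only shows the call pattern; load() is defined later in this file.
def _generate_step_example(max_tokens: int = 50) -> None:
    model, tokenizer = load("mlx-community/some-model-4bit")  # hypothetical repo
    prompt = mx.array(tokenizer.encode("Write a haiku about the sea."))
    for (token, _prob), _n in zip(
        generate_step(prompt, model, temp=0.7), range(max_tokens)
    ):
        if token == tokenizer.eos_token_id:
            break
        print(tokenizer.decode([token]), end="", flush=True)
    print()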


def generate(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: str,
    temp: float = 0.0,
    max_tokens: int = 100,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = None,
    top_p: float = 1.0,
    logit_bias: Optional[Dict[int, float]] = None,
) -> str:
    """
    Generate text from the model.

    Args:
        model (nn.Module): The language model.
        tokenizer (PreTrainedTokenizer or TokenizerWrapper): The tokenizer.
        prompt (str): The string prompt.
        temp (float): The temperature for sampling (default 0).
        max_tokens (int): The maximum number of tokens (default 100).
        verbose (bool): If ``True``, print tokens and timing information
            (default ``False``).
        formatter (Optional[Callable]): A function which takes a token and a
            probability and displays it.
        repetition_penalty (float, optional): The penalty factor for repeating tokens.
        repetition_context_size (int, optional): The number of tokens to consider
            for repetition penalty.
        top_p (float, optional): Nucleus sampling parameter (default 1.0).
        logit_bias (Dict[int, float], optional): Additive bias applied to the
            logits of the given token ids before sampling.
    """
    if not isinstance(tokenizer, TokenizerWrapper):
        tokenizer = TokenizerWrapper(tokenizer)

    if verbose:
        print("=" * 10)
        print("Prompt:", prompt)

    prompt_tokens = mx.array(tokenizer.encode(prompt))
    detokenizer = tokenizer.detokenizer

    tic = time.perf_counter()
    detokenizer.reset()

    for (token, prob), n in zip(
        generate_step(
            prompt_tokens,
            model,
            temp,
            repetition_penalty,
            repetition_context_size,
            top_p,
            logit_bias,
        ),
        range(max_tokens),
    ):
        if n == 0:
            prompt_time = time.perf_counter() - tic
            tic = time.perf_counter()
        if token == tokenizer.eos_token_id:
            break
        detokenizer.add_token(token)

        if verbose:
            if formatter:
                # We have to finalize so that the prob corresponds to the last segment
                detokenizer.finalize()
                formatter(detokenizer.last_segment, prob.item())
            else:
                print(detokenizer.last_segment, end="", flush=True)

    token_count = n + 1
    detokenizer.finalize()

    if verbose:
        gen_time = time.perf_counter() - tic
        print(detokenizer.last_segment, flush=True)
        print("=" * 10)
        if token_count == 0:
            print("No tokens generated for this prompt")
            return
        prompt_tps = prompt_tokens.size / prompt_time
        gen_tps = (token_count - 1) / gen_time
        print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
        print(f"Generation: {gen_tps:.3f} tokens-per-sec")

    return detokenizer.text


def load_config(model_path: Path) -> dict:
    try:
        with open(model_path / "config.json", "r") as f:
            config = json.load(f)
    except FileNotFoundError:
        logging.error(f"Config file not found in {model_path}")
        raise
    return config


def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
    """
    Load and initialize the model from a given path.

    Args:
        model_path (Path): The path to load the model from.
        lazy (bool): If ``False``, evaluate the model parameters to make sure they
            are loaded in memory before returning; otherwise they will be loaded
            when needed. Default: ``False``

    Returns:
        nn.Module: The loaded and initialized model.

    Raises:
        FileNotFoundError: If the weight files (.safetensors) are not found.
        ValueError: If the model class or args class are not found or cannot be instantiated.
    """

    config = load_config(model_path)

    weight_files = glob.glob(str(model_path / "model*.safetensors"))

    if not weight_files:
        # Try weight for back-compat
        weight_files = glob.glob(str(model_path / "weight*.safetensors"))

    if not weight_files:
        logging.error(f"No safetensors found in {model_path}")
        raise FileNotFoundError(f"No safetensors found in {model_path}")

    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf))

    model_class, model_args_class = _get_classes(config=config)

    model_args = model_args_class.from_dict(config)
    model = model_class(model_args)

    if hasattr(model, "sanitize"):
        weights = model.sanitize(weights)

    if (quantization := config.get("quantization", None)) is not None:
        # Handle legacy models which may not have everything quantized
        class_predicate = (
            lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
            and f"{p}.scales" in weights
        )
        nn.quantize(
            model,
            **quantization,
            class_predicate=class_predicate,
        )

    model.load_weights(list(weights.items()))

    if not lazy:
        mx.eval(model.parameters())

    model.eval()
    return model


def load(
    path_or_hf_repo: str,
    tokenizer_config={},
    adapter_path: Optional[str] = None,
    lazy: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
    """
    Load the model and tokenizer from a given path or a huggingface repository.

    Args:
        path_or_hf_repo (str): The path or the huggingface repository to load the model from.
        tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
            Defaults to an empty dictionary.
        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
            to the model. Default: ``None``.
        lazy (bool): If ``False``, evaluate the model parameters to make sure they
            are loaded in memory before returning; otherwise they will be loaded
            when needed. Default: ``False``

    Returns:
        Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.

    Raises:
        FileNotFoundError: If config file or safetensors are not found.
        ValueError: If model class or args class are not found.
    """
    model_path = get_model_path(path_or_hf_repo)

    model = load_model(model_path, lazy)
    if adapter_path is not None:
        model = apply_lora_layers(model, adapter_path)
        model.eval()
    tokenizer = load_tokenizer(model_path, tokenizer_config)

    return model, tokenizer
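

# Illustrative sketch (not part of the original module and never called by it):
# the same pattern as the README snippet emitted by upload_to_hub() below.
# The repo id is hypothetical.
def _load_and_generate_example() -> str:
    model, tokenizer = load("mlx-community/some-model-4bit")  # hypothetical repo
    return generate(model, tokenizer, prompt="hello", verbose=True)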


def fetch_from_hub(
    model_path: Path, lazy: bool = False
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
    model = load_model(model_path, lazy)
    config = load_config(model_path)
    tokenizer = load_tokenizer(model_path)
    return model, config, tokenizer


def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
    """
    Splits the weights into smaller shards.

    Args:
        weights (dict): Model weights.
        max_file_size_gb (int): Maximum size of each shard in gigabytes.

    Returns:
        list: List of weight shards.
    """
    max_file_size_bytes = max_file_size_gb << 30
    shards = []
    shard, shard_size = {}, 0
    for k, v in weights.items():
        if shard_size + v.nbytes > max_file_size_bytes:
            shards.append(shard)
            shard, shard_size = {}, 0
        shard[k] = v
        shard_size += v.nbytes
    shards.append(shard)
    return shards
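

# Illustrative sketch (not part of the original module and never called by it):
# make_shards() keeps each returned dict below the size limit; these made-up
# tensors are small, so they all fit in a single shard.
def _make_shards_example() -> list:
    weights = {f"layers.{i}.weight": mx.zeros((1024, 1024)) for i in range(4)}
    shards = make_shards(weights, max_file_size_gb=1)
    assert len(shards) == 1  # 4 x 4 MiB is well under the 1 GB limit
    return shards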


def upload_to_hub(path: str, upload_repo: str, hf_path: str):
    """
    Uploads the model to Hugging Face hub.

    Args:
        path (str): Local path to the model.
        upload_repo (str): Name of the HF repo to upload to.
        hf_path (str): Path to the original Hugging Face model.
    """
    import os

    from huggingface_hub import HfApi, ModelCard, logging

    from . import __version__

    card = ModelCard.load(hf_path)
    card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
    card.text = dedent(
        f"""
        # {upload_repo}

        The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**.

        ## Use with mlx

        ```bash
        pip install mlx-lm
        ```

        ```python
        from mlx_lm import load, generate

        model, tokenizer = load("{upload_repo}")
        response = generate(model, tokenizer, prompt="hello", verbose=True)
        ```
        """
    )
    card.save(os.path.join(path, "README.md"))

    logging.set_verbosity_info()

    api = HfApi()
    api.create_repo(repo_id=upload_repo, exist_ok=True)
    api.upload_folder(
        folder_path=path,
        repo_id=upload_repo,
        repo_type="model",
        multi_commits=True,
        multi_commits_verbose=True,
    )
    print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")


def save_weights(
    save_path: Union[str, Path],
    weights: Dict[str, Any],
    *,
    donate_weights: bool = False,
) -> None:
    """Save model weights into specified directory."""
    if isinstance(save_path, str):
        save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)

    shards = make_shards(weights)
    shards_count = len(shards)
    shard_file_format = (
        "model-{:05d}-of-{:05d}.safetensors"
        if shards_count > 1
        else "model.safetensors"
    )

    total_size = sum(v.nbytes for v in weights.values())
    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}

    # Write the weights and make sure no references are kept other than the
    # necessary ones
    if donate_weights:
        weights.clear()
        del weights

    for i in range(len(shards)):
        shard = shards[i]
        shards[i] = None
        shard_name = shard_file_format.format(i + 1, shards_count)
        shard_path = save_path / shard_name

        mx.save_safetensors(str(shard_path), shard, metadata={"format": "mlx"})

        for weight_name in shard.keys():
            index_data["weight_map"][weight_name] = shard_name
        del shard

    index_data["weight_map"] = {
        k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
    }

    with open(save_path / "model.safetensors.index.json", "w") as f:
        json.dump(
            index_data,
            f,
            indent=4,
        )
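

# Illustrative sketch (not part of the original module and never called by it):
# save_weights() writes safetensors shard(s) plus an index file into the target
# directory. The path and weights below are made up.
def _save_weights_example(out_dir: str = "/tmp/mlx_example_weights") -> None:
    save_weights(out_dir, {"w": mx.zeros((8, 8))})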


def quantize_model(
    model: nn.Module, config: dict, q_group_size: int, q_bits: int
) -> Tuple:
    """
    Applies quantization to the model weights.

    Args:
        model (nn.Module): The model to be quantized.
        config (dict): Model configuration.
        q_group_size (int): Group size for quantization.
        q_bits (int): Bits per weight for quantization.

    Returns:
        Tuple: Tuple containing quantized weights and config.
    """
    quantized_config = copy.deepcopy(config)
    nn.quantize(model, q_group_size, q_bits)
    quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
    quantized_weights = dict(tree_flatten(model.parameters()))

    return quantized_weights, quantized_config


def save_config(
    config: dict,
    config_path: Union[str, Path],
) -> None:
    """Save the model configuration to the ``config_path``.

    The final configuration will be sorted before saving for better readability.

    Args:
        config (dict): The model configuration.
        config_path (Union[str, Path]): Model configuration file path.
    """
    # Clean unused keys
    config.pop("_name_or_path", None)

    # Sort the config for better readability
    config = dict(sorted(config.items()))

    # Write the updated config to the config_path
    with open(config_path, "w") as fid:
        json.dump(config, fid, indent=4)


def convert(
    hf_path: str,
    mlx_path: str = "mlx_model",
    quantize: bool = False,
    q_group_size: int = 64,
    q_bits: int = 4,
    dtype: str = "float16",
    upload_repo: Optional[str] = None,
    revision: Optional[str] = None,
    dequantize: bool = False,
):
    print("[INFO] Loading")
    model_path = get_model_path(hf_path, revision=revision)
    model, config, tokenizer = fetch_from_hub(model_path, lazy=True)

    weights = dict(tree_flatten(model.parameters()))
    dtype = mx.float16 if quantize else getattr(mx, dtype)
    weights = {k: v.astype(dtype) for k, v in weights.items()}

    if quantize and dequantize:
        raise ValueError("Choose either quantize or dequantize, not both.")

    if quantize:
        print("[INFO] Quantizing")
        model.load_weights(list(weights.items()))
        weights, config = quantize_model(model, config, q_group_size, q_bits)

    if dequantize:
        print("[INFO] Dequantizing")
        model = dequantize_model(model)
        weights = dict(tree_flatten(model.parameters()))

    if isinstance(mlx_path, str):
        mlx_path = Path(mlx_path)

    del model
    save_weights(mlx_path, weights, donate_weights=True)

    py_files = glob.glob(str(model_path / "*.py"))
    for file in py_files:
        shutil.copy(file, mlx_path)

    tokenizer.save_pretrained(mlx_path)

    save_config(config, config_path=mlx_path / "config.json")

    if upload_repo is not None:
        upload_to_hub(mlx_path, upload_repo, hf_path)
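

# Illustrative sketch (not part of the original module and never called by it):
# convert a Hugging Face checkpoint to a 4-bit MLX model. The repo id is
# hypothetical; pass upload_repo to also publish the result.
def _convert_example() -> None:
    convert(
        hf_path="some-org/some-hf-model",  # hypothetical source repo
        mlx_path="mlx_model",
        quantize=True,
        q_group_size=64,
        q_bits=4,
    )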