Merge branch 'ml-explore:main' into adding-support-for-mamba

Gökdeniz Gülmez
2024-09-04 20:23:19 +02:00
committed by GitHub
8 changed files with 33 additions and 8 deletions

View File

@@ -1,4 +1,4 @@
 # Copyright © 2023-2024 Apple Inc.
+from ._version import __version__
 from .utils import convert, generate, load, stream_generate
-from .version import __version__

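The hunk above reflects renaming the package's version module from version.py to _version.py; the public __version__ is now re-exported from the underscored module (the setup.py hunks further down make the matching change). A minimal, self-contained sketch of the re-export pattern, using an in-memory stand-in module and a placeholder version string:

    # Sketch only: fakes _version in memory so the snippet runs on its own;
    # in the repository this is a real file named _version.py next to __init__.py.
    import sys
    import types

    fake = types.ModuleType("_version")
    fake.__version__ = "0.0.0"  # placeholder, not the actual release number
    sys.modules["_version"] = fake

    from _version import __version__  # mirrors: from ._version import __version__

    print(__version__)  # -> 0.0.0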
View File

@@ -2,6 +2,7 @@
 import argparse
 import json
+import sys
 import mlx.core as mx
@@ -14,6 +15,10 @@ DEFAULT_TOP_P = 1.0
 DEFAULT_SEED = 0
+def str2bool(string):
+    return string.lower() not in ["false", "f"]
 def setup_arg_parser():
     """Set up and return the argument parser."""
     parser = argparse.ArgumentParser(description="LLM inference script")
@@ -39,7 +44,9 @@ def setup_arg_parser():
         help="End of sequence token for tokenizer",
     )
     parser.add_argument(
-        "--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model"
+        "--prompt",
+        default=DEFAULT_PROMPT,
+        help="Message to be processed by the model ('-' reads from stdin)",
     )
     parser.add_argument(
         "--max-tokens",
@@ -65,6 +72,12 @@ def setup_arg_parser():
         action="store_true",
         help="Use the default chat template",
     )
+    parser.add_argument(
+        "--verbose",
+        type=str2bool,
+        default=True,
+        help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
+    )
     parser.add_argument(
         "--colorize",
         action="store_true",
@@ -178,7 +191,12 @@ def main():
         hasattr(tokenizer, "apply_chat_template")
         and tokenizer.chat_template is not None
     ):
-        messages = [{"role": "user", "content": args.prompt}]
+        messages = [
+            {
+                "role": "user",
+                "content": sys.stdin.read() if args.prompt == "-" else args.prompt,
+            }
+        ]
         prompt = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
@@ -195,6 +213,8 @@
     else:
         prompt = args.prompt
+    if args.colorize and not args.verbose:
+        raise ValueError("Cannot use --colorize with --verbose=False")
     formatter = colorprint_by_t0 if args.colorize else None
     # Determine the max kv size from the kv cache or passed arguments
@@ -203,18 +223,20 @@
         max_kv_size = metadata["max_kv_size"]
         max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None
-    generate(
+    response = generate(
         model,
         tokenizer,
         prompt,
         args.max_tokens,
-        verbose=True,
+        verbose=args.verbose,
         formatter=formatter,
         temp=args.temp,
         top_p=args.top_p,
         max_kv_size=max_kv_size,
         cache_history=cache_history,
     )
+    if not args.verbose:
+        print(response)
 if __name__ == "__main__":

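Taken together, the generate.py hunks above add three behaviors: passing --prompt - makes the script read the prompt from stdin, the new --verbose flag (parsed by str2bool) decides whether generate() logs as it runs or the script only prints the returned response, and --colorize is rejected when verbose output is off. Below is a minimal sketch of that argument handling, lifted from the hunks; the parsed values and the stand-in stdin text are illustrative only:

    # Sketch of the new prompt/verbose handling; argument values are examples.
    import argparse
    import io
    import sys

    def str2bool(string):
        # Only "false"/"f" (case-insensitive) map to False; anything else is True.
        return string.lower() not in ["false", "f"]

    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", default="hello")
    parser.add_argument("--verbose", type=str2bool, default=True)
    parser.add_argument("--colorize", action="store_true")
    args = parser.parse_args(["--prompt", "-", "--verbose", "F"])

    sys.stdin = io.StringIO("Write a haiku about the ocean.")  # stand-in for piped input
    prompt = sys.stdin.read() if args.prompt == "-" else args.prompt

    if args.colorize and not args.verbose:
        raise ValueError("Cannot use --colorize with --verbose=False")

    print(prompt)        # Write a haiku about the ocean.
    print(args.verbose)  # False: the caller prints the returned response itself

In shell terms this corresponds to piping text into the script with --prompt - and, when only the raw completion is wanted, --verbose False.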
View File

@@ -560,6 +560,7 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
     card = ModelCard.load(hf_path)
     card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
+    card.data.base_model = hf_path
     card.text = dedent(
         f"""
         # {upload_repo}
@@ -666,6 +667,8 @@ def quantize_model(
     quantized_config = copy.deepcopy(config)
     nn.quantize(model, q_group_size, q_bits)
     quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
+    # support hf model tree #957
+    quantized_config["quantization_config"] = quantized_config["quantization"]
     quantized_weights = dict(tree_flatten(model.parameters()))
     return quantized_weights, quantized_config

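The utils.py hunks record the base model on the generated model card (card.data.base_model = hf_path) and duplicate the quantization settings under a quantization_config key (the "# support hf model tree #957" comment in the diff ties this to Hugging Face Hub model-tree support). A small sketch of the resulting config dictionary, with made-up base-config contents and quantization settings:

    # Sketch of the duplicated quantization metadata; all values are examples.
    import copy
    import json

    config = {"model_type": "llama", "hidden_size": 4096}  # illustrative base config
    q_group_size, q_bits = 64, 4                           # illustrative settings

    quantized_config = copy.deepcopy(config)
    quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
    # Same block again under "quantization_config" for Hub compatibility.
    quantized_config["quantization_config"] = quantized_config["quantization"]

    print(json.dumps(quantized_config, indent=2))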
View File

@@ -10,7 +10,7 @@ with open(package_dir / "requirements.txt") as fid:
     requirements = [l.strip() for l in fid.readlines()]
 sys.path.append(str(package_dir))
-from version import __version__
+from _version import __version__
 setup(
     name="mlx-lm",

View File

@@ -1,5 +1,5 @@
 # Copyright © 2023-2024 Apple Inc.
 from . import audio, decoding, load_models
+from ._version import __version__
 from .transcribe import transcribe
-from .version import __version__

View File

@@ -12,7 +12,7 @@ with open(package_dir / "requirements.txt") as fid:
 sys.path.append(str(package_dir))
-from version import __version__
+from _version import __version__
 setup(
     name="mlx-whisper",