Merge branch 'ml-explore:main' into adding-support-for-mamba
llms/mlx_lm/__init__.py
@@ -1,4 +1,4 @@
 # Copyright © 2023-2024 Apple Inc.
 
+from ._version import __version__
 from .utils import convert, generate, load, stream_generate
-from .version import __version__
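Note: the rename from .version to ._version keeps the public attribute intact; a quick sanity check, assuming an installed mlx-lm:

    # The version module is private now, but the package-level attribute is unchanged.
    import mlx_lm
    print(mlx_lm.__version__)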
llms/mlx_lm/generate.py
@@ -2,6 +2,7 @@
 
 import argparse
 import json
+import sys
 
 import mlx.core as mx
 
@@ -14,6 +15,10 @@ DEFAULT_TOP_P = 1.0
 DEFAULT_SEED = 0
 
 
+def str2bool(string):
+    return string.lower() not in ["false", "f"]
+
+
 def setup_arg_parser():
     """Set up and return the argument parser."""
     parser = argparse.ArgumentParser(description="LLM inference script")
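One caveat on str2bool as written: any value other than "false"/"f" (case-insensitive) parses as True, so a typo like "--verbose Flase" still enables verbose output. Illustrative behavior:

    # Only "false"/"f" (any case) map to False; everything else is True.
    str2bool("False")  # -> False
    str2bool("f")      # -> False
    str2bool("True")   # -> True
    str2bool("no")     # -> True, perhaps surprisingly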
@@ -39,7 +44,9 @@ def setup_arg_parser():
         help="End of sequence token for tokenizer",
     )
     parser.add_argument(
-        "--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model"
+        "--prompt",
+        default=DEFAULT_PROMPT,
+        help="Message to be processed by the model ('-' reads from stdin)",
     )
     parser.add_argument(
         "--max-tokens",
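With the new "-" sentinel the prompt can be piped in from another process. A hypothetical invocation, assuming the usual python -m entry point and a placeholder model repo:

    # Read the prompt from stdin instead of the command line:
    #   echo "Explain KV caching" | python -m mlx_lm.generate --model <repo> --prompt -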
@@ -65,6 +72,12 @@ def setup_arg_parser():
         action="store_true",
         help="Use the default chat template",
     )
+    parser.add_argument(
+        "--verbose",
+        type=str2bool,
+        default=True,
+        help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
+    )
     parser.add_argument(
         "--colorize",
         action="store_true",
@@ -178,7 +191,12 @@ def main():
         hasattr(tokenizer, "apply_chat_template")
         and tokenizer.chat_template is not None
     ):
-        messages = [{"role": "user", "content": args.prompt}]
+        messages = [
+            {
+                "role": "user",
+                "content": sys.stdin.read() if args.prompt == "-" else args.prompt,
+            }
+        ]
         prompt = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
@@ -195,6 +213,8 @@ def main():
     else:
         prompt = args.prompt
 
+    if args.colorize and not args.verbose:
+        raise ValueError("Cannot use --colorize with --verbose=False")
     formatter = colorprint_by_t0 if args.colorize else None
 
     # Determine the max kv size from the kv cache or passed arguments
@@ -203,18 +223,20 @@ def main():
         max_kv_size = metadata["max_kv_size"]
         max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None
 
-    generate(
+    response = generate(
         model,
         tokenizer,
         prompt,
         args.max_tokens,
-        verbose=True,
+        verbose=args.verbose,
         formatter=formatter,
         temp=args.temp,
         top_p=args.top_p,
         max_kv_size=max_kv_size,
         cache_history=cache_history,
     )
+    if not args.verbose:
+        print(response)
 
 
 if __name__ == "__main__":
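Together with the --verbose flag this makes generate.py usable as a shell filter: tokens stream to the console by default, while quiet mode emits only the final response. A hypothetical invocation with a placeholder model repo:

    # Quiet mode: stdout carries just the response, convenient for piping.
    #   python -m mlx_lm.generate --model <repo> --prompt "Hi" --verbose False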
llms/mlx_lm/utils.py
@@ -560,6 +560,7 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
 
     card = ModelCard.load(hf_path)
     card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
+    card.data.base_model = hf_path
     card.text = dedent(
         f"""
         # {upload_repo}
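Setting base_model lets the Hugging Face Hub link the converted repo back to its source. The uploaded card's front matter would then look roughly like this (values illustrative):

    # Sketch of the resulting model-card YAML header; the base_model value is
    # whatever hf_path was passed to upload_to_hub:
    #   ---
    #   base_model: mistralai/Mistral-7B-v0.1
    #   tags:
    #   - mlx
    #   ---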
@@ -666,6 +667,8 @@ def quantize_model(
     quantized_config = copy.deepcopy(config)
     nn.quantize(model, q_group_size, q_bits)
     quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
+    # support hf model tree #957
+    quantized_config["quantization_config"] = quantized_config["quantization"]
     quantized_weights = dict(tree_flatten(model.parameters()))
 
     return quantized_weights, quantized_config
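Duplicating the block under quantization_config is what lets the Hub render the quantized repo in the base model's tree (see #957). Both keys end up holding the same dict; with the default conversion settings (assumed here) that would be:

    # Values shown are mlx-lm's defaults, assumed for illustration:
    #   "quantization":        {"group_size": 64, "bits": 4}
    #   "quantization_config": {"group_size": 64, "bits": 4}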
llms/setup.py
@@ -10,7 +10,7 @@ with open(package_dir / "requirements.txt") as fid:
     requirements = [l.strip() for l in fid.readlines()]
 
 sys.path.append(str(package_dir))
-from version import __version__
+from _version import __version__
 
 setup(
     name="mlx-lm",
whisper/mlx_whisper/__init__.py
@@ -1,5 +1,5 @@
 # Copyright © 2023-2024 Apple Inc.
 
 from . import audio, decoding, load_models
+from ._version import __version__
 from .transcribe import transcribe
-from .version import __version__
whisper/setup.py
@@ -12,7 +12,7 @@ with open(package_dir / "requirements.txt") as fid:
 
 sys.path.append(str(package_dir))
 
-from version import __version__
+from _version import __version__
 
 setup(
     name="mlx-whisper",