2024-10-08 11:45:51 +08:00
|
|
|
# Copyright © 2023-2024 Apple Inc.
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import json
|
|
|
|
|
|
|
|
import mlx.core as mx
|
|
|
|
|
2024-11-24 03:47:06 +08:00
|
|
|
from .models.cache import make_prompt_cache
|
|
|
|
from .sample_utils import make_sampler
|
2024-10-08 11:45:51 +08:00
|
|
|
from .utils import load, stream_generate
|
|
|
|
|
|
|
|
# Default values for the chat command-line options.
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"  # --model
DEFAULT_TEMP = 0.0  # --temp
DEFAULT_TOP_P = 1.0  # --top-p
DEFAULT_SEED = None  # --seed (None: the PRNG is not explicitly seeded)
DEFAULT_MAX_TOKENS = 256  # --max-tokens / -m
|
|
|
|
|
|
|
|
|
|
|
|
def setup_arg_parser():
    """Build the command-line parser for the chat CLI.

    Returns:
        argparse.ArgumentParser: a parser accepting --model, --adapter-path,
        --temp, --top-p, --seed, --max-kv-size and --max-tokens/-m, with
        defaults taken from the module-level DEFAULT_* constants.
    """
    # Each entry is (option strings, keyword arguments for add_argument);
    # the tuple order is the order options appear in --help.
    option_specs = (
        (
            ("--model",),
            dict(
                type=str,
                help="The path to the local model directory or Hugging Face repo.",
                default=DEFAULT_MODEL,
            ),
        ),
        (
            ("--adapter-path",),
            dict(
                type=str,
                help="Optional path for the trained adapter weights and config.",
            ),
        ),
        (
            ("--temp",),
            dict(type=float, default=DEFAULT_TEMP, help="Sampling temperature"),
        ),
        (
            ("--top-p",),
            dict(type=float, default=DEFAULT_TOP_P, help="Sampling top-p"),
        ),
        (
            ("--seed",),
            dict(type=int, default=DEFAULT_SEED, help="PRNG seed"),
        ),
        (
            ("--max-kv-size",),
            dict(
                type=int,
                help="Set the maximum key-value cache size",
                default=None,
            ),
        ),
        (
            ("--max-tokens", "-m"),
            dict(
                type=int,
                default=DEFAULT_MAX_TOKENS,
                help="Maximum number of tokens to generate",
            ),
        ),
    )

    arg_parser = argparse.ArgumentParser(description="Chat with an LLM")
    for flags, options in option_specs:
        arg_parser.add_argument(*flags, **options)
    return arg_parser
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Run an interactive command-line chat session with a model.

    Parses the command-line options, seeds MLX's PRNG when --seed is given,
    loads the model and tokenizer (plus an optional adapter), then repeatedly
    reads a line from stdin and streams the model's reply to stdout.

    Commands typed at the prompt:
        'q' - exit (end of input, e.g. Ctrl-D or exhausted piped stdin,
              also exits)
        'r' - reset the chat by starting a fresh prompt cache
        'h' - print the command list
    Any other line is sent to the model as a user message.
    """
    parser = setup_arg_parser()
    args = parser.parse_args()

    if args.seed is not None:
        mx.random.seed(args.seed)

    # NOTE(review): trust_remote_code=True is presumably forwarded to the
    # Hugging Face tokenizer loader, where it permits running code shipped
    # with the model repo — only point --model at sources you trust.
    model, tokenizer = load(
        args.model,
        adapter_path=args.adapter_path,
        tokenizer_config={"trust_remote_code": True},
    )

    def print_help():
        """Print the list of interactive commands."""
        print("The command list:")
        print("- 'q' to exit")
        print("- 'r' to reset the chat")
        print("- 'h' to display these commands")

    print(f"[INFO] Starting chat session with {args.model}.")
    print_help()

    # The sampling settings are fixed for the whole session, so build the
    # sampler once rather than on every turn.
    sampler = make_sampler(args.temp, args.top_p)

    # Only the latest user message is templated each turn; 'r' ("reset the
    # chat") works by replacing this cache, so the cache is what carries the
    # conversation from one turn to the next.
    prompt_cache = make_prompt_cache(model, args.max_kv_size)
    while True:
        try:
            query = input(">> ")
        except EOFError:
            # stdin is exhausted (Ctrl-D or the end of piped input): end the
            # session like 'q' instead of exiting with a traceback.
            print()
            break
        if query == "q":
            break
        if query == "r":
            prompt_cache = make_prompt_cache(model, args.max_kv_size)
            continue
        if query == "h":
            print_help()
            continue
        messages = [{"role": "user", "content": query}]
        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        for response in stream_generate(
            model,
            tokenizer,
            prompt,
            max_tokens=args.max_tokens,
            sampler=sampler,
            prompt_cache=prompt_cache,
        ):
            # Emit each chunk as soon as it is produced.
            print(response.text, flush=True, end="")
        print()
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the interactive chat only when run as a script, not on import.
    main()
|