mlx-examples/llms/mlx_lm/chat.py

# Copyright © 2023-2024 Apple Inc.

import argparse
import json

import mlx.core as mx

from .models.cache import make_prompt_cache
from .sample_utils import make_sampler
from .utils import load, stream_generate

DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_MAX_TOKENS = 256
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"


def share_message(world, prompt):
    """Broadcast the tokenized prompt from rank 0 to all other ranks."""
    if world.size() == 1:
        return prompt

    # First share the prompt length so every rank can allocate a buffer of
    # the right size.
    if world.rank() == 0:
        size = mx.array([len(prompt)])
    else:
        size = mx.array([0])
    size = mx.distributed.all_sum(size, stream=mx.cpu).item()
    if size == 0:
        return []

    # Then share the tokens themselves: non-zero ranks contribute zeros, so
    # the all_sum reduces to rank 0's prompt on every node.
    if world.rank() == 0:
        prompt = mx.array(prompt)
    else:
        prompt = mx.array([0] * size)
    return mx.distributed.all_sum(prompt, stream=mx.cpu).tolist()


def setup_arg_parser():
    """Set up and return the argument parser."""
    parser = argparse.ArgumentParser(description="Chat with an LLM")
    parser.add_argument(
        "--model",
        type=str,
        help="The path to the local model directory or Hugging Face repo.",
        default=DEFAULT_MODEL,
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config.",
    )
    parser.add_argument(
        "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
    )
    parser.add_argument(
        "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
    )
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
    parser.add_argument(
        "--max-kv-size",
        type=int,
        help="Set the maximum key-value cache size",
        default=None,
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=DEFAULT_MAX_TOKENS,
        help="Maximum number of tokens to generate",
    )
    return parser


def main():
    world = mx.distributed.init()
    parser = setup_arg_parser()
    args = parser.parse_args()
    mx.random.seed(args.seed)

    model, tokenizer = load(
        args.model,
        adapter_path=args.adapter_path,
        tokenizer_config={"trust_remote_code": True},
        sequential_load=world.size() > 1,
    )
    print(f"Node {world.rank()} of {world.size()}", flush=True)
    print(
        f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.",
        flush=True,
    )
    world.barrier()

    prompt_cache = make_prompt_cache(model, args.max_kv_size)
    while True:
        if world.rank() == 0:
            query = input(">> ")
            if query == "q":
                prompt = []
            else:
                messages = [{"role": "user", "content": query}]
                prompt = tokenizer.apply_chat_template(
                    messages, add_generation_prompt=True
                )
        else:
            # Non-zero ranks receive the prompt from rank 0 via share_message.
            prompt = []
        prompt = share_message(world, prompt)
        if len(prompt) == 0:
            break

        for response in stream_generate(
            model,
            tokenizer,
            prompt,
            max_tokens=args.max_tokens,
            sampler=make_sampler(args.temp, args.top_p),
            prompt_cache=prompt_cache,
        ):
            # Only rank 0 prints the streamed text.
            if world.rank() == 0:
                print(response.text, flush=True, end="")
        if world.rank() == 0:
            print()


if __name__ == "__main__":
    main()
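
# Usage sketch (assumed invocations, not part of the original file; verify the
# entry point and launcher flags against your installed mlx / mlx-lm versions):
#
#   # Single node:
#   python -m mlx_lm.chat --model mlx-community/Llama-3.2-3B-Instruct-4bit
#
#   # Multi-node, using MLX's distributed launcher so mx.distributed.init()
#   # sees more than one rank (hosts.json is a hypothetical host file):
#   mlx.launch --hostfile hosts.json chat.py --max-tokens 512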