mlx-examples/llms/mlx_lm/examples/chat.py
Awni Hannun 0f135396ae
Generation refactor: part 2 (#1099)
* unify with stream_generate

* fixes

* nit

* some cleanup, warnings, tests

* fix test + faster min p + test

* version
2024-11-23 11:47:06 -08:00

53 lines
1.2 KiB
Python

# Copyright © 2024 Apple Inc.
"""
An example of a multi-turn chat with prompt caching.
"""
from mlx_lm import generate, load
from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
# Load the model and its tokenizer.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Make the initial prompt cache for the model. This one cache object is
# passed to every generate() call below so it is shared across turns.
prompt_cache = make_prompt_cache(model)

# User turn
prompt = "Hi my name is <Name>."
messages = [{"role": "user", "content": prompt}]
# tokenize=False returns the templated prompt as text rather than token ids.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# Assistant response
# Both turns rely on generate()'s default sampling. No per-call `temp`
# keyword is passed: it is a legacy sampling argument (newer mlx_lm releases
# expect a `sampler=` instead), and passing it on only one of the two turns
# made the calls inconsistent.
response = generate(
    model,
    tokenizer,
    prompt=prompt,
    verbose=True,
    prompt_cache=prompt_cache,
)

# User turn
# Only the new user message is templated here; the first exchange is not
# re-sent and is expected to be carried by prompt_cache.
prompt = "What's my name?"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# Assistant response
response = generate(
    model,
    tokenizer,
    prompt=prompt,
    verbose=True,
    prompt_cache=prompt_cache,
)

# Save the prompt cache to disk to reuse it at a later time
save_prompt_cache("mistral_prompt.safetensors", prompt_cache)

# Load the prompt cache from disk
prompt_cache = load_prompt_cache("mistral_prompt.safetensors")