Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)
Tweaks to run dspy-produced calls to the server, with gemma template. (#810)
* Tweaks to run dspy-produced calls to the server, with gemma template.

  Following the comment at https://github.com/stanfordnlp/dspy/issues/385#issuecomment-1998939936, you can try it out with:

  ```sh
  python -m server --model mlx-community/gemma-1.1-7b-it-4bit --port 1143
  ```

  modulo patching the relative imports in server.py:

  ```diff
  -from .tokenizer_utils import TokenizerWrapper
  -from .utils import generate_step, load
  +from mlx_lm.tokenizer_utils import TokenizerWrapper
  +from mlx_lm.utils import generate_step, load
  ```

  and then, on the dspy side:

  ```python
  import dspy

  lm = dspy.OpenAI(
      model_type="chat",
      api_base="http://localhost:1143/v1/",
      api_key="not_needed",
      max_tokens=250,
  )
  lm("hello")
  ```

* Simpler way to validate float or int.

* Remove the logic that works around incompatible templates; it was too Gemma-specific.

* Tweak messages for a common denominator.

* Use the generate.py workaround for DBRX.

* Put it behind a flag.

* oops

* Solution to the chat template issue: pass in a custom template! The template should likely adhere to the OpenAI chat model. Here is such a template for Gemma:

  ```
  --chat-template "{{ bos_token }}{% set extra_system = '' %}{% for message in messages %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{% if role == 'system' %}{% set extra_system = extra_system + message['content'] %}{% else %}{% if role == 'user' and extra_system %}{% set message_system = 'System: ' + extra_system %}{% else %}{% set message_system = '' %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message_system + message['content'] | trim + '<end_of_turn>\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"
  ```

* Remove the convoluted solution.

* Tweak for when None is provided explicitly and must also be set to []. The outlines library, for example, provides None explicitly.

* style

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Parent: 6da07fb1b0
Commit: 3cc58e17fb
server.py:

```diff
@@ -140,7 +140,8 @@ class APIHandler(BaseHTTPRequestHandler):
         self.validate_model_parameters()
 
         # Get stop id sequences, if provided
-        stop_words = self.body.get("stop", [])
+        stop_words = self.body.get("stop")
+        stop_words = stop_words or []
         stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
         stop_id_sequences = [
             self.tokenizer.encode(stop_word, add_special_tokens=False)
```
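This first hunk changes how `stop` is read from the request body: clients such as outlines send `"stop": null` explicitly, so the old `.get("stop", [])` default never fired and `None` leaked through. A minimal standalone sketch of the same normalization (the `normalize_stop` helper is hypothetical; server.py performs these steps inline):

```python
# Hypothetical helper mirroring the hunk above; not a function in server.py.
def normalize_stop(body: dict) -> list:
    stop_words = body.get("stop")   # None when absent *or* when explicitly null
    stop_words = stop_words or []   # treat an explicit None as "no stop words"
    return [stop_words] if isinstance(stop_words, str) else stop_words

assert normalize_stop({}) == []
assert normalize_stop({"stop": None}) == []          # e.g. what outlines sends
assert normalize_stop({"stop": "###"}) == ["###"]
assert normalize_stop({"stop": ["a", "b"]}) == ["a", "b"]
```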
```diff
@@ -171,14 +172,14 @@ class APIHandler(BaseHTTPRequestHandler):
         if not isinstance(self.max_tokens, int) or self.max_tokens < 0:
             raise ValueError("max_tokens must be a non-negative integer")
 
-        if not isinstance(self.temperature, float) or self.temperature < 0:
+        if not isinstance(self.temperature, (float, int)) or self.temperature < 0:
             raise ValueError("temperature must be a non-negative float")
 
-        if not isinstance(self.top_p, float) or self.top_p < 0 or self.top_p > 1:
+        if not isinstance(self.top_p, (float, int)) or self.top_p < 0 or self.top_p > 1:
             raise ValueError("top_p must be a float between 0 and 1")
 
         if (
-            not isinstance(self.repetition_penalty, float)
+            not isinstance(self.repetition_penalty, (float, int))
             or self.repetition_penalty < 0
         ):
             raise ValueError("repetition_penalty must be a non-negative float")
```
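The widened `(float, int)` checks matter because JSON has a single number type: `json.loads` turns a whole number such as `0` into a Python `int`, which the old float-only `isinstance` tests rejected even though `0` is a valid temperature. A quick illustration:

```python
import json

body = json.loads('{"temperature": 0, "top_p": 1}')

# json deserializes whole numbers to int, not float
assert isinstance(body["temperature"], int)
assert not isinstance(body["temperature"], float)     # old check would reject it
assert isinstance(body["temperature"], (float, int))  # widened check accepts it
```

One caveat of the widened check: Python's `bool` is a subclass of `int`, so a JSON `true` would also pass `isinstance(..., (float, int))`.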
```diff
@@ -527,6 +528,18 @@ def main():
         help="Set the MLX cache limit in GB",
         required=False,
     )
+    parser.add_argument(
+        "--chat-template",
+        type=str,
+        default="",
+        help="Specify a chat template for the tokenizer",
+        required=False,
+    )
+    parser.add_argument(
+        "--use-default-chat-template",
+        action="store_true",
+        help="Use the default chat template",
+    )
     args = parser.parse_args()
 
     logging.basicConfig(
```
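For reference, the two new flags might be invoked like this, following the style of the commit message (the `gemma_template.jinja` file name is hypothetical, standing in for the Gemma template quoted above):

```sh
# Override the tokenizer's chat template with a custom one
python -m server --model mlx-community/gemma-1.1-7b-it-4bit \
    --chat-template "$(cat gemma_template.jinja)"

# Or fall back to the tokenizer's default template when the model ships without one
python -m server --model mlx-community/gemma-1.1-7b-it-4bit \
    --use-default-chat-template
```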
```diff
@@ -540,10 +553,17 @@ def main():
 
     # Building tokenizer_config
     tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
+    if args.chat_template:
+        tokenizer_config["chat_template"] = args.chat_template
+
     model, tokenizer = load(
         args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
     )
 
+    if args.use_default_chat_template:
+        if tokenizer.chat_template is None:
+            tokenizer.chat_template = tokenizer.default_chat_template
+
     run(args.host, args.port, model, tokenizer)
```
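On the tokenizer side, this wiring amounts to overriding `tokenizer.chat_template`, which `apply_chat_template` then uses to format incoming chat requests. A minimal sketch assuming the Hugging Face `transformers` tokenizer API that mlx-lm wraps (the toy template is illustrative only, and `default_chat_template` exists only on older `transformers` releases):

```python
from transformers import AutoTokenizer

# Toy template standing in for args.chat_template
TEMPLATE = "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"

tokenizer = AutoTokenizer.from_pretrained("mlx-community/gemma-1.1-7b-it-4bit")
tokenizer.chat_template = TEMPLATE  # what tokenizer_config["chat_template"] feeds through

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "hello"}],
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)  # -> "user: hello\n"
```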