Mirror of https://github.com/ml-explore/mlx-examples.git
Merge branch 'main' into adding-dpo-training
@@ -20,6 +20,15 @@ training_mode: normal
# reference_model_path: "mlx_model"
# train_bias_only: False

# The Optimizer with its possible inputs
optimizer: adamw
# optimizer_config:
#   adamw:
#     betas: [0.9, 0.98]
#     eps: 1e-6
#     weight_decay: 0.05
#     bias_correction: true

# Directory with {train, valid, test}.jsonl files
data: "/path/to/training/data"
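For context, a minimal sketch of how an optimizer section like the one above could be mapped to an mlx optimizer instance. This is illustrative only and not the repository's actual config loader: the optimizer_classes mapping and the learning_rate value are assumptions, and only keyword arguments that mlx.optimizers.AdamW is known to accept (betas, eps, weight_decay) are passed through.

import mlx.optimizers as optim

# Parsed form of the (uncommented) YAML snippet above.
config = {
    "optimizer": "adamw",
    "optimizer_config": {
        "adamw": {"betas": [0.9, 0.98], "eps": 1e-6, "weight_decay": 0.05},
    },
}

# Hypothetical name-to-class mapping; not taken from the repository.
optimizer_classes = {"adam": optim.Adam, "adamw": optim.AdamW, "sgd": optim.SGD}

name = config["optimizer"]
kwargs = config.get("optimizer_config", {}).get(name, {})

# The learning rate here is illustrative; in the example configs it is
# typically set via a separate learning_rate entry.
optimizer = optimizer_classes[name](learning_rate=1e-5, **kwargs)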
llms/mlx_lm/examples/tool_use.py (new file, 73 lines)
@@ -0,0 +1,73 @@
# Copyright © 2025 Apple Inc.

import json

from mlx_lm import generate, load
from mlx_lm.models.cache import make_prompt_cache

# Specify the checkpoint
checkpoint = "mlx-community/Qwen2.5-32B-Instruct-4bit"

# Load the corresponding model and tokenizer
model, tokenizer = load(path_or_hf_repo=checkpoint)


# An example tool, make sure to include a docstring and type hints
def multiply(a: float, b: float):
    """
    A function that multiplies two numbers

    Args:
        a: The first number to multiply
        b: The second number to multiply
    """
    return a * b


tools = {"multiply": multiply}

# Specify the prompt and conversation history
prompt = "Multiply 12234585 and 48838483920."
messages = [{"role": "user", "content": prompt}]

prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tools=list(tools.values())
)

prompt_cache = make_prompt_cache(model)

# Generate the initial tool call:
response = generate(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_tokens=2048,
    verbose=True,
    prompt_cache=prompt_cache,
)

# Parse the tool call:
# (Note, the tool call format is model specific)
tool_open = "<tool_call>"
tool_close = "</tool_call>"
start_tool = response.find(tool_open) + len(tool_open)
end_tool = response.find(tool_close)
tool_call = json.loads(response[start_tool:end_tool].strip())
tool_result = tools[tool_call["name"]](**tool_call["arguments"])

# Put the tool result in the prompt
messages = [{"role": "tool", "name": tool_call["name"], "content": tool_result}]
prompt = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
)

# Generate the final response:
response = generate(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_tokens=2048,
    verbose=True,
    prompt_cache=prompt_cache,
)
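For reference, the parsing step above assumes the model replies with a <tool_call> block wrapping a JSON object that has "name" and "arguments" keys, roughly like the following. This is an illustrative shape, not actual model output, and as the comment in the script notes the exact format is model specific:

<tool_call>
{"name": "multiply", "arguments": {"a": 12234585, "b": 48838483920}}
</tool_call>

Assuming mlx-lm is installed and the 4-bit Qwen2.5-32B checkpoint fits in memory, the example can be run directly with: python llms/mlx_lm/examples/tool_use.py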