Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-08-29 03:01:34 +08:00
Merge branch 'ml-explore:main' into adding-support-for-OLMoE
commit 499a9f0758

llms/mlx_lm/examples/tool_use.py (new file, 73 lines)
@@ -0,0 +1,73 @@
+# Copyright © 2025 Apple Inc.
+
+import json
+
+from mlx_lm import generate, load
+from mlx_lm.models.cache import make_prompt_cache
+
+# Specify the checkpoint
+checkpoint = "mlx-community/Qwen2.5-32B-Instruct-4bit"
+
+# Load the corresponding model and tokenizer
+model, tokenizer = load(path_or_hf_repo=checkpoint)
+
+
+# An example tool, make sure to include a docstring and type hints
+def multiply(a: float, b: float):
+    """
+    A function that multiplies two numbers
+
+    Args:
+        a: The first number to multiply
+        b: The second number to multiply
+    """
+    return a * b
+
+
+tools = {"multiply": multiply}
+
+# Specify the prompt and conversation history
+prompt = "Multiply 12234585 and 48838483920."
+messages = [{"role": "user", "content": prompt}]
+
+prompt = tokenizer.apply_chat_template(
+    messages, add_generation_prompt=True, tools=list(tools.values())
+)
+
+prompt_cache = make_prompt_cache(model)
+
+# Generate the initial tool call:
+response = generate(
+    model=model,
+    tokenizer=tokenizer,
+    prompt=prompt,
+    max_tokens=2048,
+    verbose=True,
+    prompt_cache=prompt_cache,
+)
+
+# Parse the tool call:
+# (Note, the tool call format is model specific)
+tool_open = "<tool_call>"
+tool_close = "</tool_call>"
+start_tool = response.find(tool_open) + len(tool_open)
+end_tool = response.find(tool_close)
+tool_call = json.loads(response[start_tool:end_tool].strip())
+tool_result = tools[tool_call["name"]](**tool_call["arguments"])
+
+# Put the tool result in the prompt
+messages = [{"role": "tool", "name": tool_call["name"], "content": tool_result}]
+prompt = tokenizer.apply_chat_template(
+    messages,
+    add_generation_prompt=True,
+)
+
+# Generate the final response:
+response = generate(
+    model=model,
+    tokenizer=tokenizer,
+    prompt=prompt,
+    max_tokens=2048,
+    verbose=True,
+    prompt_cache=prompt_cache,
+)
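The parsing step in the new file assumes a Qwen-style tool call wrapped in <tool_call> tags. The JSON payload that json.loads expects would look roughly like the following (an assumed illustration; the exact text is model output and, as the comment in the file notes, model specific):

<tool_call>
{"name": "multiply", "arguments": {"a": 12234585, "b": 48838483920}}
</tool_call>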
@@ -152,7 +152,7 @@ def setup_arg_parser():
         "--num-draft-tokens",
         type=int,
         help="Number of tokens to draft when using speculative decoding.",
-        default=2,
+        default=3,
     )
     return parser
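The --num-draft-tokens flag sets how many tokens the draft model proposes per verification step in speculative decoding; the default moves from 2 to 3. A minimal sketch of how it pairs with a draft model, assuming the CLI's companion --draft-model option and that the Python generate call accepts matching draft_model and num_draft_tokens keywords (both assumptions; verify against the installed mlx_lm):

from mlx_lm import generate, load

# The large model verifies tokens that a small draft model proposes.
model, tokenizer = load("mlx-community/Qwen2.5-32B-Instruct-4bit")
draft_model, _ = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")  # assumed draft checkpoint

response = generate(
    model=model,
    tokenizer=tokenizer,
    prompt="Hello",
    draft_model=draft_model,  # assumed keyword mirroring --draft-model
    num_draft_tokens=3,       # matches the new CLI default
    verbose=True,
)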
@@ -33,13 +33,13 @@ def create_causal_mask(
     linds = mx.arange(offset, offset + N) if offset else rinds
     linds = linds[:, None]
     rinds = rinds[None]
-    mask = linds < rinds
+    mask = linds >= rinds
     if window_size is not None:
-        mask = mask | (linds > rinds + window_size)
+        mask = mask & (linds <= rinds + window_size)
     if lengths is not None:
         lengths = lengths[:, None, None, None]
-        mask = mask | (rinds >= lengths)
-    return mask * -1e9
+        mask = mask & (rinds < lengths)
+    return mask


 def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
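After this hunk the function returns a boolean mask, True where a query may attend, instead of the old additive 0 / -1e9 float mask; the window and length constraints are now intersected with & rather than unioned in as exclusions. A small sketch of the new semantics (illustration only, not part of the diff):

import mlx.core as mx

# 4-token sequence with no offset: queries and keys share positions 0..3.
N = 4
rinds = mx.arange(N)[None]     # key positions, as a row
linds = mx.arange(N)[:, None]  # query positions, as a column

# Causal: a query may attend to itself and to earlier keys.
mask = linds >= rinds

# Sliding window of 2: additionally require keys within the last 2 positions.
windowed = mask & (linds <= rinds + 2)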
@@ -55,7 +55,6 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
             else:
                 offset = c.offset
         mask = create_causal_mask(T, offset, window_size=window_size)
-        mask = mask.astype(h.dtype)
     else:
         mask = None
     return mask
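Dropping the mask.astype(h.dtype) cast follows from the previous hunk: the mask is now boolean, so there is no float tensor to match against the activation dtype. A caller that still needs an additive mask could convert it along these lines (assumed usage, not part of the diff):

import mlx.core as mx

def boolean_to_additive(scores: mx.array, mask: mx.array) -> mx.array:
    # Keep scores where the mask allows attention; push everything else
    # to a large negative value so softmax assigns it ~0 probability.
    return mx.where(mask, scores, -1e9)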