Merge branch 'ml-explore:main' into adding-support-for-OLMoE

Gökdeniz Gülmez 2025-03-05 18:38:43 +01:00 committed by GitHub
commit 499a9f0758
3 changed files with 78 additions and 6 deletions

View File

@@ -0,0 +1,73 @@
# Copyright © 2025 Apple Inc.

import json

from mlx_lm import generate, load
from mlx_lm.models.cache import make_prompt_cache

# Specify the checkpoint
checkpoint = "mlx-community/Qwen2.5-32B-Instruct-4bit"

# Load the corresponding model and tokenizer
model, tokenizer = load(path_or_hf_repo=checkpoint)


# An example tool, make sure to include a docstring and type hints
def multiply(a: float, b: float):
    """
    A function that multiplies two numbers

    Args:
        a: The first number to multiply
        b: The second number to multiply
    """
    return a * b


tools = {"multiply": multiply}

# Specify the prompt and conversation history
prompt = "Multiply 12234585 and 48838483920."
messages = [{"role": "user", "content": prompt}]

prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tools=list(tools.values())
)

prompt_cache = make_prompt_cache(model)

# Generate the initial tool call:
response = generate(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_tokens=2048,
    verbose=True,
    prompt_cache=prompt_cache,
)

# Parse the tool call:
# (Note, the tool call format is model specific)
tool_open = "<tool_call>"
tool_close = "</tool_call>"
start_tool = response.find(tool_open) + len(tool_open)
end_tool = response.find(tool_close)
tool_call = json.loads(response[start_tool:end_tool].strip())
tool_result = tools[tool_call["name"]](**tool_call["arguments"])

# Put the tool result in the prompt
messages = [{"role": "tool", "name": tool_call["name"], "content": tool_result}]

prompt = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
)

# Generate the final response:
response = generate(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    max_tokens=2048,
    verbose=True,
    prompt_cache=prompt_cache,
)
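The parsing step in this example assumes the model wraps its tool call in `<tool_call>` tags. With a Qwen2.5-style chat template the first generation typically ends with something like the following (illustrative only; the exact format is model specific):

<tool_call>
{"name": "multiply", "arguments": {"a": 12234585, "b": 48838483920}}
</tool_call>

Reusing the same `prompt_cache` in both `generate` calls keeps the first exchange in the KV cache, which is why the second `apply_chat_template` call only needs to template the tool-result message.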

View File

@@ -152,7 +152,7 @@ def setup_arg_parser():
         "--num-draft-tokens",
         type=int,
         help="Number of tokens to draft when using speculative decoding.",
-        default=2,
+        default=3,
     )
     return parser

View File

@@ -33,13 +33,13 @@ def create_causal_mask(
     linds = mx.arange(offset, offset + N) if offset else rinds
     linds = linds[:, None]
     rinds = rinds[None]
-    mask = linds < rinds
+    mask = linds >= rinds
     if window_size is not None:
-        mask = mask | (linds > rinds + window_size)
+        mask = mask & (linds <= rinds + window_size)
     if lengths is not None:
         lengths = lengths[:, None, None, None]
-        mask = mask | (rinds >= lengths)
-    return mask * -1e9
+        mask = mask & (rinds < lengths)
+    return mask


 def create_attention_mask(h: mx.array, cache: Optional[Any] = None):

@@ -55,7 +55,6 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
             else:
                 offset = c.offset
         mask = create_causal_mask(T, offset, window_size=window_size)
-        mask = mask.astype(h.dtype)
     else:
         mask = None
     return mask
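The net effect of these two hunks: `create_causal_mask` now returns a boolean mask in which True marks positions a query token may attend to (window and length constraints are combined with `&`), instead of an additive float mask of 0 / -1e9, so `create_attention_mask` no longer casts the mask to `h.dtype`. A minimal sketch of the new semantics, written against `mlx.core` purely for illustration (not part of the diff):

import mlx.core as mx

N, offset = 3, 0
rinds = mx.arange(offset + N)
linds = mx.arange(offset, offset + N) if offset else rinds

# Boolean causal mask: row i is True for columns j <= i,
# i.e. each token may attend to itself and to earlier tokens.
mask = linds[:, None] >= rinds[None]

# For N = 3 the lower triangle (including the diagonal) is allowed: 3 + 2 + 1 positions.
assert mask.sum().item() == 6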