From 65aa2ec84918d4438a73d7504bae2f8e9f0d396b Mon Sep 17 00:00:00 2001
From: Awni Hannun <awni@apple.com>
Date: Tue, 4 Mar 2025 12:47:32 -0800
Subject: [PATCH 1/2] use a bool mask for attention (#1319)

---
 llms/mlx_lm/generate.py    | 2 +-
 llms/mlx_lm/models/base.py | 9 ++++-----
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index e40332dd..bd11dcf0 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -152,7 +152,7 @@ def setup_arg_parser():
         "--num-draft-tokens",
         type=int,
         help="Number of tokens to draft when using speculative decoding.",
-        default=2,
+        default=3,
     )
     return parser
 
diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index ad7a4a65..8b40effb 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -33,13 +33,13 @@ def create_causal_mask(
     linds = mx.arange(offset, offset + N) if offset else rinds
     linds = linds[:, None]
     rinds = rinds[None]
-    mask = linds < rinds
+    mask = linds >= rinds
     if window_size is not None:
-        mask = mask | (linds > rinds + window_size)
+        mask = mask & (linds <= rinds + window_size)
     if lengths is not None:
         lengths = lengths[:, None, None, None]
-        mask = mask | (rinds >= lengths)
-    return mask * -1e9
+        mask = mask & (rinds < lengths)
+    return mask
 
 
 def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
@@ -55,7 +55,6 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
         else:
             offset = c.offset
         mask = create_causal_mask(T, offset, window_size=window_size)
-        mask = mask.astype(h.dtype)
     else:
         mask = None
     return mask

From f621218ff5284306c0f78ea4a34cd22c033e4b9d Mon Sep 17 00:00:00 2001
From: Awni Hannun <awni@apple.com>
Date: Tue, 4 Mar 2025 13:53:20 -0800
Subject: [PATCH 2/2] Tool use example (#1316)

* tool use example

* nits
---
 llms/mlx_lm/examples/tool_use.py | 73 ++++++++++++++++++++++++++++++++
 1 file changed, 73 insertions(+)
 create mode 100644 llms/mlx_lm/examples/tool_use.py

diff --git a/llms/mlx_lm/examples/tool_use.py b/llms/mlx_lm/examples/tool_use.py
new file mode 100644
index 00000000..624b9e5b
--- /dev/null
+++ b/llms/mlx_lm/examples/tool_use.py
@@ -0,0 +1,73 @@
+# Copyright © 2025 Apple Inc.
+
+import json
+
+from mlx_lm import generate, load
+from mlx_lm.models.cache import make_prompt_cache
+
+# Specify the checkpoint
+checkpoint = "mlx-community/Qwen2.5-32B-Instruct-4bit"
+
+# Load the corresponding model and tokenizer
+model, tokenizer = load(path_or_hf_repo=checkpoint)
+
+
+# An example tool, make sure to include a docstring and type hints
+def multiply(a: float, b: float):
+    """
+    A function that multiplies two numbers
+
+    Args:
+        a: The first number to multiply
+        b: The second number to multiply
+    """
+    return a * b
+
+
+tools = {"multiply": multiply}
+
+# Specify the prompt and conversation history
+prompt = "Multiply 12234585 and 48838483920."
+messages = [{"role": "user", "content": prompt}] + +prompt = tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tools=list(tools.values()) +) + +prompt_cache = make_prompt_cache(model) + +# Generate the initial tool call: +response = generate( + model=model, + tokenizer=tokenizer, + prompt=prompt, + max_tokens=2048, + verbose=True, + prompt_cache=prompt_cache, +) + +# Parse the tool call: +# (Note, the tool call format is model specific) +tool_open = "" +tool_close = "" +start_tool = response.find(tool_open) + len(tool_open) +end_tool = response.find(tool_close) +tool_call = json.loads(response[start_tool:end_tool].strip()) +tool_result = tools[tool_call["name"]](**tool_call["arguments"]) + +# Put the tool result in the prompt +messages = [{"role": "tool", "name": tool_call["name"], "content": tool_result}] +prompt = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, +) + +# Generate the final response: +response = generate( + model=model, + tokenizer=tokenizer, + prompt=prompt, + max_tokens=2048, + verbose=True, + prompt_cache=prompt_cache, +)