From b7f742ef56ee2b7b127e2c1390a4ab625dc044e4 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 26 Feb 2025 20:32:36 +0100 Subject: [PATCH 01/18] Mixed quant recipes (#1300) * Mixed 3/6 and 2/6 recipes based on Alex Barron's * format / nits --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/convert.py | 27 ++++++++++++++++++++++++++- llms/mlx_lm/utils.py | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/llms/mlx_lm/convert.py b/llms/mlx_lm/convert.py index 9bac77a5..86a96447 100644 --- a/llms/mlx_lm/convert.py +++ b/llms/mlx_lm/convert.py @@ -1,8 +1,27 @@ # Copyright © 2023-2024 Apple Inc. import argparse +from enum import Enum -from .utils import convert +from .utils import convert, mixed_2_6, mixed_3_6 + + +class MixedQuants(Enum): + mixed_3_6 = "mixed_3_6" + mixed_2_6 = "mixed_2_6" + + @classmethod + def recipe_names(cls): + return [member.name for member in cls] + + +def quant_args(arg): + try: + return MixedQuants[arg].value + except KeyError: + raise argparse.ArgumentTypeError( + f"Invalid q-recipe {arg!r}. Choose from: {MixedQuants.recipe_names()}" + ) def configure_parser() -> argparse.ArgumentParser: @@ -29,6 +48,12 @@ def configure_parser() -> argparse.ArgumentParser: parser.add_argument( "--q-bits", help="Bits per weight for quantization.", type=int, default=4 ) + parser.add_argument( + "--quant-predicate", + help=f"Mixed-bit quantization recipe. Choices: {MixedQuants.recipe_names()}", + type=quant_args, + required=False, + ) parser.add_argument( "--dtype", help="Type to save the non-quantized parameters.", diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 2d760743..7dff0ee3 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -1015,6 +1015,46 @@ def save_config( json.dump(config, fid, indent=4) +def mixed_quant_predicate_builder( + low_bits: int = 4, high_bits: int = 4, group_size: int = 64 +) -> Callable[[str, nn.Module, dict], Union[bool, dict]]: + def mixed_quant_predicate( + path: str, + module: nn.Module, + config: dict, + ) -> Union[bool, dict]: + """Implements mixed quantization predicates with similar choices to, for example, llama.cpp's Q4_K_M. 
+ Ref: https://github.com/ggerganov/llama.cpp/blob/917786f43d0f29b7c77a0c56767c0fa4df68b1c5/src/llama.cpp#L5265 + By Alex Barron: https://gist.github.com/barronalex/84addb8078be21969f1690c1454855f3 + """ + + if not hasattr(module, "to_quantized"): + return False + + index = int(path.split(".")[2]) if len(path.split(".")) > 2 else 0 + + num_layers = config["num_hidden_layers"] + use_more_bits = ( + index < num_layers // 8 + or index >= 7 * num_layers // 8 + or (index - num_layers // 8) % 3 == 2 + ) + if "v_proj" in path and use_more_bits: + return {"group_size": group_size, "bits": high_bits} + if "down_proj" in path and use_more_bits: + return {"group_size": group_size, "bits": high_bits} + if "lm_head" in path: + return {"group_size": group_size, "bits": high_bits} + + return {"group_size": group_size, "bits": low_bits} + + return mixed_quant_predicate + + +mixed_3_6 = mixed_quant_predicate_builder(low_bits=3) +mixed_2_6 = mixed_quant_predicate_builder(low_bits=2) + + def convert( hf_path: str, mlx_path: str = "mlx_model", From 56e60ad5a6692d4cee38701237de1bb37e19e994 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 26 Feb 2025 15:44:57 -0800 Subject: [PATCH 02/18] fix manage for new transformers (#1304) --- llms/mlx_lm/manage.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/llms/mlx_lm/manage.py b/llms/mlx_lm/manage.py index 9827f3dc..c06de6b3 100644 --- a/llms/mlx_lm/manage.py +++ b/llms/mlx_lm/manage.py @@ -2,7 +2,22 @@ import argparse from typing import List, Union from huggingface_hub import scan_cache_dir -from transformers.commands.user import tabulate + + +def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str: + """ + Inspired by: + - stackoverflow.com/a/8356620/593036 + - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data + """ + col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)] + row_format = ("{{:{}}} " * len(headers)).format(*col_widths) + lines = [] + lines.append(row_format.format(*headers)) + lines.append(row_format.format(*["-" * w for w in col_widths])) + for row in rows: + lines.append(row_format.format(*row)) + return "\n".join(lines) def ask_for_confirmation(message: str) -> bool: From 0f240a4c7e37332361d5d595499535d6ba7cb73b Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 26 Feb 2025 15:46:16 -0800 Subject: [PATCH 03/18] Use max tokens from options in mlx_lm evaluate (#1302) --- llms/mlx_lm/evaluate.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/llms/mlx_lm/evaluate.py b/llms/mlx_lm/evaluate.py index 2f35ade2..cd6de7ec 100644 --- a/llms/mlx_lm/evaluate.py +++ b/llms/mlx_lm/evaluate.py @@ -289,17 +289,15 @@ class MLXLM(LM): contexts, options = zip(*[req.args for req in requests]) # contrary to the doc the second element of the tuple contains # {'do_sample': False, 'until': ['\n\n'], 'temperature': 0} - keys = list(options[0].keys()) - assert "until" in keys - untils = [x["until"] for x in options] completions = [] - for context, until in tqdm(zip(contexts, untils), total=len(contexts)): + for context, opt in tqdm(zip(contexts, options), total=len(contexts)): + until = opt["until"] context = self.tokenizer.encode( context, add_special_tokens=not self.use_chat_template ) max_tokens = min( - self._max_tokens, + opt.get("max_gen_tokens", self._max_tokens), self.tokenizer.model_max_length - len(context), ) text = "" @@ -334,9 +332,9 @@ def main(): ) parser.add_argument( "--limit", - default=1.0, + default=100, help="Limit the 
number of examples per task.", - type=float, + type=int, ) parser.add_argument("--seed", type=int, default=123, help="Random seed.") parser.add_argument( From 00a73790702991075b80f9facf219ae397e1eb15 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 26 Feb 2025 16:21:54 -0800 Subject: [PATCH 04/18] Fixes for phi4 mini (#1305) --- llms/mlx_lm/models/phi3.py | 16 ++++++++++++---- llms/mlx_lm/models/su_rope.py | 6 ++++-- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py index d1c21e25..63e985de 100644 --- a/llms/mlx_lm/models/phi3.py +++ b/llms/mlx_lm/models/phi3.py @@ -23,8 +23,10 @@ class ModelArgs(BaseModelArgs): rope_theta: float = 10000 rope_traditional: bool = False rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None + partial_rotary_factor: float = 1.0 max_position_embeddings: int = 131072 original_max_position_embeddings: int = 4096 + tie_word_embeddings: bool = False def __post_init__(self): if self.num_key_value_heads is None: @@ -59,9 +61,10 @@ class Attention(nn.Module): self.qkv_proj = nn.Linear(dim, op_size, bias=False) self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + rope_dim = int(head_dim * args.partial_rotary_factor) if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]: self.rope = SuScaledRotaryEmbedding( - head_dim, + rope_dim, base=args.rope_theta, max_position_embeddings=args.max_position_embeddings, original_max_position_embeddings=args.original_max_position_embeddings, @@ -74,7 +77,7 @@ class Attention(nn.Module): assert isinstance(args.rope_scaling["factor"], float) rope_scale = 1 / args.rope_scaling["factor"] self.rope = nn.RoPE( - head_dim, + rope_dim, traditional=args.rope_traditional, base=args.rope_theta, scale=rope_scale, @@ -190,7 +193,8 @@ class Model(nn.Module): super().__init__() self.model_type = args.model_type self.model = Phi3Model(args) - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) self.args = args def __call__( @@ -200,7 +204,11 @@ class Model(nn.Module): cache=None, ): out = self.model(inputs, mask, cache) - return self.lm_head(out) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out @property def layers(self): diff --git a/llms/mlx_lm/models/su_rope.py b/llms/mlx_lm/models/su_rope.py index 9c414afd..6340c77b 100644 --- a/llms/mlx_lm/models/su_rope.py +++ b/llms/mlx_lm/models/su_rope.py @@ -51,11 +51,13 @@ class SuScaledRotaryEmbedding(nn.Module): + math.log(max_position_embeddings / original_max_position_embeddings) / math.log(original_max_position_embeddings) ) + self.dim = dims def __call__(self, x, offset: int = 0): + x[..., : self.dim] = self.scale * x[..., : self.dim] return mx.fast.rope( - self.scale * x, - x.shape[-1], + x, + self.dim, traditional=False, base=None, scale=1.0, From eb7354963119df3b533cb63ace61a92c2dcb532d Mon Sep 17 00:00:00 2001 From: madroid Date: Thu, 27 Feb 2025 23:44:00 +0800 Subject: [PATCH 05/18] Generate: Support Prefill Response (#1299) * Generate: Support Prefill Prompt python -m mlx_lm.generate \ --model mlx-community/DeepSeek-R1-Distill-Qwen-1.5B-4bit \ --prompt "hello" \ --prefill-prompt "\n" * Generate: rename prefill-prompt to prefill-response * nits --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/generate.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 
deletions(-) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index d8f97e5e..e40332dd 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -60,6 +60,11 @@ def setup_arg_parser(): default=DEFAULT_PROMPT, help="Message to be processed by the model ('-' reads from stdin)", ) + parser.add_argument( + "--prefill-response", + default=None, + help="Prefill response to be used for the chat template", + ) parser.add_argument( "--max-tokens", "-m", @@ -219,10 +224,14 @@ def main(): messages = [] messages.append({"role": "user", "content": prompt}) + has_prefill = args.prefill_response is not None + if has_prefill: + messages.append({"role": "assistant", "content": args.prefill_response}) prompt = tokenizer.apply_chat_template( messages, tokenize=False, - add_generation_prompt=True, + continue_final_message=has_prefill, + add_generation_prompt=not has_prefill, **template_kwargs, ) @@ -233,7 +242,8 @@ def main(): test_prompt = tokenizer.apply_chat_template( messages, tokenize=False, - add_generation_prompt=True, + continue_final_message=has_prefill, + add_generation_prompt=not has_prefill, ) prompt = prompt[test_prompt.index("") :] prompt = tokenizer.encode(prompt, add_special_tokens=False) From b2108a0de694fb753ec46f91590f0f1e59494d63 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 28 Feb 2025 11:33:04 -0800 Subject: [PATCH 06/18] Allow mask prompt in config (#1314) --- llms/mlx_lm/lora.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index def3b6dd..d32bfe6d 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -62,6 +62,7 @@ CONFIG_DEFAULTS = { "grad_checkpoint": False, "lr_schedule": None, "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}, + "mask_prompt": False, } @@ -99,7 +100,7 @@ def build_parser(): "--mask-prompt", action="store_true", help="Mask the prompt in the loss when training", - default=False, + default=None, ) parser.add_argument( From 845cd8c01e91fb50b7a5467db3efc8de4d2c29f6 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 28 Feb 2025 11:33:18 -0800 Subject: [PATCH 07/18] support kimi + more options in chat mode (#1312) --- llms/mlx_lm/_version.py | 2 +- llms/mlx_lm/chat.py | 15 +++++++- llms/mlx_lm/models/deepseek_v3.py | 61 ++++++++++++++++++------------- llms/mlx_lm/utils.py | 1 + 4 files changed, 52 insertions(+), 27 deletions(-) diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py index 89e6cd00..839089b6 100644 --- a/llms/mlx_lm/_version.py +++ b/llms/mlx_lm/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. -__version__ = "0.21.5" +__version__ = "0.21.6" diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py index e52ad10d..5c0b78db 100644 --- a/llms/mlx_lm/chat.py +++ b/llms/mlx_lm/chat.py @@ -65,12 +65,25 @@ def main(): tokenizer_config={"trust_remote_code": True}, ) - print(f"[INFO] Starting chat session with {args.model}. 
To exit, enter 'q'.") + def print_help(): + print("The command list:") + print("- 'q' to exit") + print("- 'r' to reset the chat") + print("- 'h' to display these commands") + + print(f"[INFO] Starting chat session with {args.model}.") + print_help() prompt_cache = make_prompt_cache(model, args.max_kv_size) while True: query = input(">> ") if query == "q": break + if query == "r": + prompt_cache = make_prompt_cache(model, args.max_kv_size) + continue + if query == "h": + print_help() + continue messages = [{"role": "user", "content": query}] prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) for response in stream_generate( diff --git a/llms/mlx_lm/models/deepseek_v3.py b/llms/mlx_lm/models/deepseek_v3.py index 47e17236..5cd40a0d 100644 --- a/llms/mlx_lm/models/deepseek_v3.py +++ b/llms/mlx_lm/models/deepseek_v3.py @@ -181,30 +181,37 @@ class DeepseekV3Attention(nn.Module): bias=config.attention_bias, ) - mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) - scaling_factor = self.config.rope_scaling["factor"] - if mscale_all_dim: - mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) - self.scale = self.scale * mscale * mscale + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scale = self.scale * mscale * mscale - rope_kwargs = { - key: self.config.rope_scaling[key] - for key in [ - "original_max_position_embeddings", - "beta_fast", - "beta_slow", - "mscale", - "mscale_all_dim", - ] - if key in self.config.rope_scaling - } - self.rope = DeepseekV3YarnRotaryEmbedding( - dim=self.qk_rope_head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - **rope_kwargs, - ) + rope_kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rope = DeepseekV3YarnRotaryEmbedding( + dim=self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **rope_kwargs, + ) + else: + self.rope = nn.RoPE( + dims=self.qk_rope_head_dim, + base=self.rope_theta, + traditional=True, + ) def __call__( self, @@ -487,8 +494,12 @@ class Model(nn.Module): ] weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join) - # Remove multi-token prediction layer - return {k: v for k, v in weights.items() if not k.startswith("model.layers.61")} + # Remove multi-token prediction layer and any unused precomputed rotary freqs + return { + k: v + for k, v in weights.items() + if not k.startswith("model.layers.61") and "rotary_emb.inv_freq" not in k + } @property def layers(self): diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 7dff0ee3..05fac92f 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -191,6 +191,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path "*.py", "tokenizer.model", "*.tiktoken", + "tiktoken.model", "*.txt", "*.jsonl", ], From 269faa5fa4cf8cdb2e1f4df3200db5a551318b43 Mon Sep 17 00:00:00 2001 From: Shunta Saito Date: Mon, 3 Mar 2025 23:12:02 +0900 Subject: [PATCH 08/18] Fix plamo2 model to use rms_norm (#1308) * Fix plamo2 model to use rms_norm and enable sliding window attention * Fix missing variable * Remove sliding 
window attention impl. cause it should be done by using RotatingKVCache * Remove unused imports --- llms/mlx_lm/models/plamo2.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py index 1d8215dd..657fa02e 100644 --- a/llms/mlx_lm/models/plamo2.py +++ b/llms/mlx_lm/models/plamo2.py @@ -2,7 +2,7 @@ import math from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any, Optional import mlx.core as mx import mlx.nn as nn @@ -32,7 +32,6 @@ class ModelArgs(BaseModelArgs): mamba_enabled: bool = True intermediate_size: int = 13312 vocab_size: int = 32000 - max_position_embeddings: int = 10 * 1024 * 1024 class RMSNorm(nn.Module): @@ -53,6 +52,16 @@ class RMSNorm(nn.Module): ) +def _rms_norm(hidden_states: mx.array, eps: float) -> mx.array: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.astype(mx.float32) + variance = mx.power(hidden_states, 2).mean(-1, keepdims=True) + hidden_states = hidden_states * mx.rsqrt(variance + eps) + hidden_states = hidden_states.astype(input_dtype) + + return hidden_states + + def get_initial_dt_bias(num_heads: int) -> mx.array: dt_min = 0.001 dt_max = 0.1 @@ -220,8 +229,7 @@ def ssd_chunk_scan_combined( def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]: - batch, seqlen, dim = x.shape - width = weight.shape[1] + _, seqlen, dim = x.shape state_len = conv_state.shape[-2] x = mx.concatenate([conv_state, x], axis=-2) conv_state = x[:, -state_len:] @@ -392,8 +400,8 @@ class Attention(nn.Module): k = k.reshape(B, T, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3) v = v.reshape(B, T, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3) - q = mx.fast.layer_norm(q, None, None, 1e-6) * self.q_weight[:, None] - k = mx.fast.layer_norm(k, None, None, 1e-6) * self.k_weight[:, None] + q = _rms_norm(q, 1e-6) * self.q_weight[:, None] + k = _rms_norm(k, 1e-6) * self.k_weight[:, None] if cache is not None: q = self.rope(q, offset=cache.offset) @@ -556,7 +564,6 @@ class PlamoModel(nn.Module): class Model(nn.Module): - def __init__(self, config: ModelArgs) -> None: super().__init__() self.config = config @@ -567,7 +574,7 @@ class Model(nn.Module): if not config.tie_word_embeddings: self.lm_head: nn.Module = nn.Linear( - config.hidden_size, vocab_size, bias=False + config.hidden_size, self.vocab_size, bias=False ) def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]: From 1bc3476a46d8d8e6bf7f2cb17570c2b7a26eafd0 Mon Sep 17 00:00:00 2001 From: Pierre-Louis <78484833+PierreLouisLetoquart@users.noreply.github.com> Date: Mon, 3 Mar 2025 09:12:33 -0500 Subject: [PATCH 09/18] chore(lora): Add real-time log buffering fix for nohup execution (#1311) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * chore(lora): Add real-time log buffering fix for nohup execution Disable Python stdout buffering to ensure logs appear in nohup.out in real-time instead of only after script completion. 
* chore(lora): remove python 3.7+ check * chore(lora): running pre-commit hook --------- Co-authored-by: Pierre-Louis Létoquart --- lora/lora.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lora/lora.py b/lora/lora.py index 723e783d..6f91ccca 100644 --- a/lora/lora.py +++ b/lora/lora.py @@ -3,6 +3,7 @@ import argparse import json import math +import sys import time from pathlib import Path @@ -14,6 +15,9 @@ import utils as lora_utils from mlx.utils import tree_flatten from models import LoRALinear +# Disable output buffering to see print statements in real-time +sys.stdout.reconfigure(line_buffering=True) + def build_parser(): parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.") From 65aa2ec84918d4438a73d7504bae2f8e9f0d396b Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 4 Mar 2025 12:47:32 -0800 Subject: [PATCH 10/18] use a bool mask for attention (#1319) --- llms/mlx_lm/generate.py | 2 +- llms/mlx_lm/models/base.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index e40332dd..bd11dcf0 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -152,7 +152,7 @@ def setup_arg_parser(): "--num-draft-tokens", type=int, help="Number of tokens to draft when using speculative decoding.", - default=2, + default=3, ) return parser diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index ad7a4a65..8b40effb 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -33,13 +33,13 @@ def create_causal_mask( linds = mx.arange(offset, offset + N) if offset else rinds linds = linds[:, None] rinds = rinds[None] - mask = linds < rinds + mask = linds >= rinds if window_size is not None: - mask = mask | (linds > rinds + window_size) + mask = mask & (linds <= rinds + window_size) if lengths is not None: lengths = lengths[:, None, None, None] - mask = mask | (rinds >= lengths) - return mask * -1e9 + mask = mask & (rinds < lengths) + return mask def create_attention_mask(h: mx.array, cache: Optional[Any] = None): @@ -55,7 +55,6 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None): else: offset = c.offset mask = create_causal_mask(T, offset, window_size=window_size) - mask = mask.astype(h.dtype) else: mask = None return mask From f621218ff5284306c0f78ea4a34cd22c033e4b9d Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 4 Mar 2025 13:53:20 -0800 Subject: [PATCH 11/18] Tool use example (#1316) * tool use example * nits --- llms/mlx_lm/examples/tool_use.py | 73 ++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 llms/mlx_lm/examples/tool_use.py diff --git a/llms/mlx_lm/examples/tool_use.py b/llms/mlx_lm/examples/tool_use.py new file mode 100644 index 00000000..624b9e5b --- /dev/null +++ b/llms/mlx_lm/examples/tool_use.py @@ -0,0 +1,73 @@ +# Copyright © 2025 Apple Inc. 
+ +import json + +from mlx_lm import generate, load +from mlx_lm.models.cache import make_prompt_cache + +# Specify the checkpoint +checkpoint = "mlx-community/Qwen2.5-32B-Instruct-4bit" + +# Load the corresponding model and tokenizer +model, tokenizer = load(path_or_hf_repo=checkpoint) + + +# An example tool, make sure to include a docstring and type hints +def multiply(a: float, b: float): + """ + A function that multiplies two numbers + + Args: + a: The first number to multiply + b: The second number to multiply + """ + return a * b + + +tools = {"multiply": multiply} + +# Specify the prompt and conversation history +prompt = "Multiply 12234585 and 48838483920." +messages = [{"role": "user", "content": prompt}] + +prompt = tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tools=list(tools.values()) +) + +prompt_cache = make_prompt_cache(model) + +# Generate the initial tool call: +response = generate( + model=model, + tokenizer=tokenizer, + prompt=prompt, + max_tokens=2048, + verbose=True, + prompt_cache=prompt_cache, +) + +# Parse the tool call: +# (Note, the tool call format is model specific) +tool_open = "" +tool_close = "" +start_tool = response.find(tool_open) + len(tool_open) +end_tool = response.find(tool_close) +tool_call = json.loads(response[start_tool:end_tool].strip()) +tool_result = tools[tool_call["name"]](**tool_call["arguments"]) + +# Put the tool result in the prompt +messages = [{"role": "tool", "name": tool_call["name"], "content": tool_result}] +prompt = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, +) + +# Generate the final response: +response = generate( + model=model, + tokenizer=tokenizer, + prompt=prompt, + max_tokens=2048, + verbose=True, + prompt_cache=prompt_cache, +) From e7267d30f83bc3f22ff6f0f8132ca0bcd9c38115 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 5 Mar 2025 13:33:15 -0800 Subject: [PATCH 12/18] Distributed support cifar (#1301) --- cifar/README.md | 14 ++++++++ cifar/dataset.py | 11 +++++- cifar/main.py | 91 +++++++++++++++++++++++++++++++----------------- 3 files changed, 84 insertions(+), 32 deletions(-) diff --git a/cifar/README.md b/cifar/README.md index 763e641d..2016200d 100644 --- a/cifar/README.md +++ b/cifar/README.md @@ -48,3 +48,17 @@ Note this was run on an M1 Macbook Pro with 16GB RAM. At the time of writing, `mlx` doesn't have built-in learning rate schedules. We intend to update this example once these features are added. + +## Distributed training + +The example also supports distributed data parallel training. 
You can launch a +distributed training as follows: + +```shell +$ cat >hostfile.json +[ + {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]}, + {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]} +] +$ mlx.launch --verbose --hostfile hostfile.json main.py --batch 256 --epochs 5 --arch resnet20 +``` diff --git a/cifar/dataset.py b/cifar/dataset.py index 22b229f8..8967591e 100644 --- a/cifar/dataset.py +++ b/cifar/dataset.py @@ -1,3 +1,4 @@ +import mlx.core as mx import numpy as np from mlx.data.datasets import load_cifar10 @@ -12,8 +13,11 @@ def get_cifar10(batch_size, root=None): x = x.astype("float32") / 255.0 return (x - mean) / std + group = mx.distributed.init() + tr_iter = ( tr.shuffle() + .partition_if(group.size() > 1, group.size(), group.rank()) .to_stream() .image_random_h_flip("image", prob=0.5) .pad("image", 0, 4, 4, 0.0) @@ -25,6 +29,11 @@ def get_cifar10(batch_size, root=None): ) test = load_cifar10(root=root, train=False) - test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size) + test_iter = ( + test.to_stream() + .partition_if(group.size() > 1, group.size(), group.rank()) + .key_transform("image", normalize) + .batch(batch_size) + ) return tr_iter, test_iter diff --git a/cifar/main.py b/cifar/main.py index 378bc424..ac010636 100644 --- a/cifar/main.py +++ b/cifar/main.py @@ -23,6 +23,13 @@ parser.add_argument("--seed", type=int, default=0, help="random seed") parser.add_argument("--cpu", action="store_true", help="use cpu only") +def print_zero(group, *args, **kwargs): + if group.rank() != 0: + return + flush = kwargs.pop("flush", True) + print(*args, **kwargs, flush=flush) + + def eval_fn(model, inp, tgt): return mx.mean(mx.argmax(model(inp), axis=1) == tgt) @@ -34,9 +41,20 @@ def train_epoch(model, train_iter, optimizer, epoch): acc = mx.mean(mx.argmax(output, axis=1) == tgt) return loss, acc - losses = [] - accs = [] - samples_per_sec = [] + world = mx.distributed.init() + losses = 0 + accuracies = 0 + samples_per_sec = 0 + count = 0 + + def average_stats(stats, count): + if world.size() == 1: + return [s / count for s in stats] + + with mx.stream(mx.cpu): + stats = mx.distributed.all_sum(mx.array(stats)) + count = mx.distributed.all_sum(count) + return (stats / count).tolist() state = [model.state, optimizer.state] @@ -44,6 +62,7 @@ def train_epoch(model, train_iter, optimizer, epoch): def step(inp, tgt): train_step_fn = nn.value_and_grad(model, train_step) (loss, acc), grads = train_step_fn(model, inp, tgt) + grads = nn.utils.average_gradients(grads) optimizer.update(model, grads) return loss, acc @@ -52,69 +71,79 @@ def train_epoch(model, train_iter, optimizer, epoch): y = mx.array(batch["label"]) tic = time.perf_counter() loss, acc = step(x, y) - mx.eval(state) + mx.eval(loss, acc, state) toc = time.perf_counter() - loss = loss.item() - acc = acc.item() - losses.append(loss) - accs.append(acc) - throughput = x.shape[0] / (toc - tic) - samples_per_sec.append(throughput) + losses += loss.item() + accuracies += acc.item() + samples_per_sec += x.shape[0] / (toc - tic) + count += 1 if batch_counter % 10 == 0: - print( + l, a, s = average_stats( + [losses, accuracies, world.size() * samples_per_sec], + count, + ) + print_zero( + world, " | ".join( ( f"Epoch {epoch:02d} [{batch_counter:03d}]", - f"Train loss {loss:.3f}", - f"Train acc {acc:.3f}", - f"Throughput: {throughput:.2f} images/second", + f"Train loss {l:.3f}", + f"Train acc {a:.3f}", + f"Throughput: {s:.2f} images/second", ) - ) + ), ) - mean_tr_loss = mx.mean(mx.array(losses)) - 
mean_tr_acc = mx.mean(mx.array(accs)) - samples_per_sec = mx.mean(mx.array(samples_per_sec)) - return mean_tr_loss, mean_tr_acc, samples_per_sec + return average_stats([losses, accuracies, world.size() * samples_per_sec], count) def test_epoch(model, test_iter, epoch): - accs = [] + accuracies = 0 + count = 0 for batch_counter, batch in enumerate(test_iter): x = mx.array(batch["image"]) y = mx.array(batch["label"]) acc = eval_fn(model, x, y) - acc_value = acc.item() - accs.append(acc_value) - mean_acc = mx.mean(mx.array(accs)) - return mean_acc + accuracies += acc.item() + count += 1 + + with mx.stream(mx.cpu): + accuracies = mx.distributed.all_sum(accuracies) + count = mx.distributed.all_sum(count) + return (accuracies / count).item() def main(args): mx.random.seed(args.seed) + # Initialize the distributed group and report the nodes that showed up + world = mx.distributed.init() + if world.size() > 1: + print(f"Starting rank {world.rank()} of {world.size()}", flush=True) + model = getattr(resnet, args.arch)() - print("Number of params: {:0.04f} M".format(model.num_params() / 1e6)) + print_zero(world, f"Number of params: {model.num_params() / 1e6:0.04f} M") optimizer = optim.Adam(learning_rate=args.lr) train_data, test_data = get_cifar10(args.batch_size) for epoch in range(args.epochs): tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch) - print( + print_zero( + world, " | ".join( ( f"Epoch: {epoch}", - f"avg. Train loss {tr_loss.item():.3f}", - f"avg. Train acc {tr_acc.item():.3f}", - f"Throughput: {throughput.item():.2f} images/sec", + f"avg. Train loss {tr_loss:.3f}", + f"avg. Train acc {tr_acc:.3f}", + f"Throughput: {throughput:.2f} images/sec", ) - ) + ), ) test_acc = test_epoch(model, test_data, epoch) - print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}") + print_zero(world, f"Epoch: {epoch} | Test acc {test_acc:.3f}") train_data.reset() test_data.reset() From 56d2db23e1348f046fc91d8c8c7794722e9fbe43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com> Date: Wed, 5 Mar 2025 22:46:06 +0100 Subject: [PATCH 13/18] adding OLMoE architecture (#1321) * initial commit * udpate ACKNOWLEDGMENTS.md * adding olmoe to training * clean up * faster generation * remove sanitize method * more clean ups * adding SwitchGLU * clean up * a little faster and adding norm_topk_prob * formated --- ACKNOWLEDGMENTS.md | 2 +- llms/mlx_lm/models/olmoe.py | 217 ++++++++++++++++++++++++++++++++++++ llms/mlx_lm/tuner/utils.py | 3 + 3 files changed, 221 insertions(+), 1 deletion(-) create mode 100644 llms/mlx_lm/models/olmoe.py diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 851c995c..c6853710 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals: - Markus Enzweiler: Added the `cvae` examples. - Prince Canuma: Helped add support for `Starcoder2` models. - Shiyu Li: Added the `Segment Anything Model`. -- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1` and support for `full-fine-tuning`. \ No newline at end of file +- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1`, `OLMoE` archtectures and support for `full-fine-tuning`. 
\ No newline at end of file diff --git a/llms/mlx_lm/models/olmoe.py b/llms/mlx_lm/models/olmoe.py new file mode 100644 index 00000000..b9c0fc69 --- /dev/null +++ b/llms/mlx_lm/models/olmoe.py @@ -0,0 +1,217 @@ +# Copyright © 2023-2024 Apple Inc. + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .rope_utils import initialize_rope +from .switch_layers import SwitchGLU + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + vocab_size: int + num_experts: int + num_experts_per_tok: int + norm_topk_prob: bool = False + head_dim: Optional[int] = None + max_position_embeddings: Optional[int] = None + num_key_value_heads: Optional[int] = None + attention_bias: bool = False + mlp_bias: bool = False + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = True + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads + + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias) + + self.rope = initialize_rope( + self.head_dim, + args.rope_theta, + args.rope_traditional, + args.rope_scaling, + args.max_position_embeddings, + ) + + self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps) + self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + queries = self.q_norm(queries) + keys = self.k_norm(keys) + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class OlmoeSparseMoeBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_experts = args.num_experts + self.top_k = args.num_experts_per_tok + self.norm_topk_prob = args.norm_topk_prob + + self.gate = nn.Linear(args.hidden_size, self.num_experts, bias=False) + self.switch_mlp = SwitchGLU( + args.hidden_size, + args.intermediate_size, + 
self.num_experts, + bias=args.mlp_bias, + ) + + def __call__(self, x: mx.array) -> mx.array: + B, L, D = x.shape + x_flat = x.reshape(-1, D) + router_logits = self.gate(x_flat) + routing_weights = mx.softmax(router_logits, axis=1, precise=True) + k = self.top_k + indices = mx.stop_gradient( + mx.argpartition(-routing_weights, kth=k - 1, axis=-1)[..., :k] + ) + scores = mx.take_along_axis(routing_weights, indices, axis=-1) + if self.norm_topk_prob: + scores = scores / scores.sum(axis=-1, keepdims=True) + y = self.switch_mlp(x_flat, indices) + y = (y * scores[..., None]).sum(axis=-2) + return y.reshape(B, L, D) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.self_attn = Attention(args) + self.mlp = OlmoeSparseMoeBlock(args) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + x = x + self.self_attn(self.input_layernorm(x), mask, cache) + x = x + self.mlp(self.post_attention_layernorm(x)) + return x + + +class OlmoeModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + mask=None, + ): + h = self.embed_tokens(inputs) + if mask is None: + mask = create_attention_mask(h, cache) + if cache is None: + cache = [None] * len(self.layers) + for layer, c in zip(self.layers, cache): + h = layer(h, mask, cache=c) + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = OlmoeModel(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + mask=None, + ): + out = self.model(inputs, cache, mask) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out + + def sanitize(self, weights): + if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: + return weights + for l in range(self.args.num_hidden_layers): + prefix = f"model.layers.{l}" + for n in ["up_proj", "down_proj", "gate_proj"]: + for k in ["weight", "scales", "biases"]: + if f"{prefix}.mlp.experts.0.{n}.{k}" in weights: + to_join = [ + weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}") + for e in range(self.args.num_experts) + ] + weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join) + return weights + + @property + def layers(self): + return self.model.layers diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index f5df11e3..cc7c6c20 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -98,6 +98,7 @@ def linear_to_lora_layers( "minicpm", "deepseek", "olmo2", + "olmoe", "internlm3", ]: keys = set(["self_attn.q_proj", "self_attn.v_proj"]) @@ -106,6 +107,8 @@ def linear_to_lora_layers( if model.model_type == "qwen2_moe": keys.add("mlp.gate") 
keys.add("mlp.shared_expert_gate") + if model.model_type == "olmoe": + keys.add("mlp.gate") elif model.model_type == "gpt_bigcode": keys = set(["attn.c_attn"]) From e15062109568571aec0e2f099533ad580f0fcaf5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com> Date: Wed, 5 Mar 2025 22:54:54 +0100 Subject: [PATCH 14/18] Adding multiple optimizers to mlx lm (#1315) * initial commmit * adding more customized YAML configuartion * update YAML example file * Changed the switch to set opt_class * removing muon * using default arguments * udpate --- llms/mlx_lm/examples/lora_config.yaml | 9 +++++++ llms/mlx_lm/lora.py | 34 +++++++++++++++++++++------ 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/llms/mlx_lm/examples/lora_config.yaml b/llms/mlx_lm/examples/lora_config.yaml index 530272c7..36bc1dff 100644 --- a/llms/mlx_lm/examples/lora_config.yaml +++ b/llms/mlx_lm/examples/lora_config.yaml @@ -7,6 +7,15 @@ train: true # The fine-tuning method: "lora", "dora", or "full". fine_tune_type: lora +# The Optimizer with its possible inputs +optimizer: adamw +# optimizer_config: +# adamw: +# betas: [0.9, 0.98] +# eps: 1e-6 +# weight_decay: 0.05 +# bias_correction: true + # Directory with {train, valid, test}.jsonl files data: "/path/to/training/data" diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index d32bfe6d..042b40e2 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -43,6 +43,11 @@ CONFIG_DEFAULTS = { "model": "mlx_model", "train": False, "fine_tune_type": "lora", + "optimizer": "adam", + "optimizer_config": { + "adam": {}, + "adamw": {}, + }, "data": "data/", "seed": 0, "num_layers": 16, @@ -95,14 +100,19 @@ def build_parser(): choices=["lora", "dora", "full"], help="Type of fine-tuning to perform: lora, dora, or full.", ) - + parser.add_argument( + "--optimizer", + type=str, + choices=["adam", "adamw"], + default=None, + help="Optimizer to use for training: adam or adamw", + ) parser.add_argument( "--mask-prompt", action="store_true", help="Mask the prompt in the loss when training", default=None, ) - parser.add_argument( "--num-layers", type=int, @@ -229,11 +239,21 @@ def train_model( ) model.train() - opt = optim.Adam( - learning_rate=( - build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate - ) - ) + + # Initialize the selected optimizer + lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate + + optimizer_name = args.optimizer.lower() + optimizer_config = args.optimizer_config.get(optimizer_name, {}) + + if optimizer_name == "adam": + opt_class = optim.Adam + elif optimizer_name == "adamw": + opt_class = optim.AdamW + else: + raise ValueError(f"Unsupported optimizer: {optimizer_name}") + + opt = opt_class(learning_rate=lr, **optimizer_config) # Train model train( From 32d10036de94af07733c247ca44702e8135d068a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 5 Mar 2025 14:00:09 -0800 Subject: [PATCH 15/18] fix flaky test (#1322) --- llms/tests/test_prompt_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py index de5694d5..c1860892 100644 --- a/llms/tests/test_prompt_cache.py +++ b/llms/tests/test_prompt_cache.py @@ -298,7 +298,7 @@ class TestPromptCache(unittest.TestCase): ): i += 1 self.assertEqual(tok, toks[i]) - self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2)) + self.assertTrue(mx.allclose(logits, all_logits[i], rtol=3e-2)) if 
__name__ == "__main__": From 877d2a345b8119ad9ed50e2c273a5064ddd3b48c Mon Sep 17 00:00:00 2001 From: cavit99 <35897738+cavit99@users.noreply.github.com> Date: Thu, 6 Mar 2025 14:49:35 +0000 Subject: [PATCH 16/18] Change DEFAULT_SEED to None for stochastic generation by default (#1323) * Change DEFAULT_SEED to None for stochastic generation by default * Update llms/mlx_lm/chat.py * Update llms/mlx_lm/generate.py --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/chat.py | 12 +++++++++--- llms/mlx_lm/generate.py | 13 ++++++++++--- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py index 5c0b78db..d8e1ccb9 100644 --- a/llms/mlx_lm/chat.py +++ b/llms/mlx_lm/chat.py @@ -11,7 +11,7 @@ from .utils import load, stream_generate DEFAULT_TEMP = 0.0 DEFAULT_TOP_P = 1.0 -DEFAULT_SEED = 0 +DEFAULT_SEED = None DEFAULT_MAX_TOKENS = 256 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" @@ -36,7 +36,12 @@ def setup_arg_parser(): parser.add_argument( "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p" ) - parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed") + parser.add_argument( + "--seed", + type=int, + default=DEFAULT_SEED, + help="PRNG seed", + ) parser.add_argument( "--max-kv-size", type=int, @@ -57,7 +62,8 @@ def main(): parser = setup_arg_parser() args = parser.parse_args() - mx.random.seed(args.seed) + if args.seed is not None: + mx.random.seed(args.seed) model, tokenizer = load( args.model, diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index bd11dcf0..7d58da82 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -16,7 +16,7 @@ DEFAULT_TEMP = 0.0 DEFAULT_TOP_P = 1.0 DEFAULT_MIN_P = 0.0 DEFAULT_MIN_TOKENS_TO_KEEP = 1 -DEFAULT_SEED = 0 +DEFAULT_SEED = None DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" DEFAULT_QUANTIZED_KV_START = 5000 @@ -87,7 +87,12 @@ def setup_arg_parser(): default=DEFAULT_MIN_TOKENS_TO_KEEP, help="Minimum tokens to keep for min-p sampling.", ) - parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed") + parser.add_argument( + "--seed", + type=int, + default=DEFAULT_SEED, + help="PRNG seed", + ) parser.add_argument( "--ignore-chat-template", action="store_true", @@ -160,7 +165,9 @@ def setup_arg_parser(): def main(): parser = setup_arg_parser() args = parser.parse_args() - mx.random.seed(args.seed) + + if args.seed is not None: + mx.random.seed(args.seed) # Load the prompt cache and metadata if a cache file is provided using_cache = args.prompt_cache_file is not None From 595f5da146bbf305b14fe18d343fe2777aa8a1ba Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 6 Mar 2025 15:35:47 -0800 Subject: [PATCH 17/18] remove lm head if unused (#1324) --- llms/mlx_lm/models/llama.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index 7b452ea4..117adf0f 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -196,9 +196,12 @@ class Model(nn.Module): def sanitize(self, weights): # Remove unused precomputed rotary freqs - return { + weights = { k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k } + if self.args.tie_word_embeddings: + weights.pop("lm_head.weight", None) + return weights @property def layers(self): From d2e02b3aae9741eea6f9c6123624406de3f10015 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 7 Mar 2025 08:35:48 -0800 Subject: [PATCH 18/18] fix mixed quant option 
(#1326) --- llms/mlx_lm/convert.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/llms/mlx_lm/convert.py b/llms/mlx_lm/convert.py index 86a96447..f268913b 100644 --- a/llms/mlx_lm/convert.py +++ b/llms/mlx_lm/convert.py @@ -1,27 +1,23 @@ # Copyright © 2023-2024 Apple Inc. import argparse -from enum import Enum -from .utils import convert, mixed_2_6, mixed_3_6 +from . import utils +from .utils import convert - -class MixedQuants(Enum): - mixed_3_6 = "mixed_3_6" - mixed_2_6 = "mixed_2_6" - - @classmethod - def recipe_names(cls): - return [member.name for member in cls] +QUANT_RECIPES = [ + "mixed_2_6", + "mixed_3_6", +] def quant_args(arg): - try: - return MixedQuants[arg].value - except KeyError: + if arg not in QUANT_RECIPES: raise argparse.ArgumentTypeError( - f"Invalid q-recipe {arg!r}. Choose from: {MixedQuants.recipe_names()}" + f"Invalid q-recipe {arg!r}. Choose from: {QUANT_RECIPES}" ) + else: + return getattr(utils, arg) def configure_parser() -> argparse.ArgumentParser: @@ -50,7 +46,7 @@ def configure_parser() -> argparse.ArgumentParser: ) parser.add_argument( "--quant-predicate", - help=f"Mixed-bit quantization recipe. Choices: {MixedQuants.recipe_names()}", + help=f"Mixed-bit quantization recipe. Choices: {QUANT_RECIPES}", type=quant_args, required=False, )
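
A minimal usage sketch for the mixed quantization recipes added in this series (the model path is a placeholder, and the `quantize`/`quant_predicate` keyword arguments on `convert()` are assumptions inferred from how the `-q` and `--quant-predicate` CLI flags appear to be wired through):

# Command line, using the flag added in convert.py:
#   python -m mlx_lm.convert --hf-path <hf-repo> -q --quant-predicate mixed_3_6
#
# Python API, using the predicates defined in utils.py:
from mlx_lm.utils import convert, mixed_3_6, mixed_quant_predicate_builder

# mixed_3_6 / mixed_2_6 quantize most layers to 3 or 2 bits and keep more bits
# for v_proj/down_proj in the first and last eighth of layers (and every third
# block in between), plus lm_head. A custom mix can be built the same way:
custom_predicate = mixed_quant_predicate_builder(low_bits=4, high_bits=6, group_size=64)

convert(
    hf_path="<hf-repo>",          # placeholder model repo
    mlx_path="mlx_model",
    quantize=True,                # assumed kwarg, mirroring the -q flag
    quant_predicate=mixed_3_6,    # or custom_predicate
)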