From 65aa2ec84918d4438a73d7504bae2f8e9f0d396b Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 4 Mar 2025 12:47:32 -0800 Subject: [PATCH 1/9] use a bool mask for attention (#1319) --- llms/mlx_lm/generate.py | 2 +- llms/mlx_lm/models/base.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index e40332dd..bd11dcf0 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -152,7 +152,7 @@ def setup_arg_parser(): "--num-draft-tokens", type=int, help="Number of tokens to draft when using speculative decoding.", - default=2, + default=3, ) return parser diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index ad7a4a65..8b40effb 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -33,13 +33,13 @@ def create_causal_mask( linds = mx.arange(offset, offset + N) if offset else rinds linds = linds[:, None] rinds = rinds[None] - mask = linds < rinds + mask = linds >= rinds if window_size is not None: - mask = mask | (linds > rinds + window_size) + mask = mask & (linds <= rinds + window_size) if lengths is not None: lengths = lengths[:, None, None, None] - mask = mask | (rinds >= lengths) - return mask * -1e9 + mask = mask & (rinds < lengths) + return mask def create_attention_mask(h: mx.array, cache: Optional[Any] = None): @@ -55,7 +55,6 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None): else: offset = c.offset mask = create_causal_mask(T, offset, window_size=window_size) - mask = mask.astype(h.dtype) else: mask = None return mask From f621218ff5284306c0f78ea4a34cd22c033e4b9d Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 4 Mar 2025 13:53:20 -0800 Subject: [PATCH 2/9] Tool use example (#1316) * tool use example * nits --- llms/mlx_lm/examples/tool_use.py | 73 ++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 llms/mlx_lm/examples/tool_use.py diff --git a/llms/mlx_lm/examples/tool_use.py b/llms/mlx_lm/examples/tool_use.py new file mode 100644 index 00000000..624b9e5b --- /dev/null +++ b/llms/mlx_lm/examples/tool_use.py @@ -0,0 +1,73 @@ +# Copyright © 2025 Apple Inc. + +import json + +from mlx_lm import generate, load +from mlx_lm.models.cache import make_prompt_cache + +# Specify the checkpoint +checkpoint = "mlx-community/Qwen2.5-32B-Instruct-4bit" + +# Load the corresponding model and tokenizer +model, tokenizer = load(path_or_hf_repo=checkpoint) + + +# An example tool, make sure to include a docstring and type hints +def multiply(a: float, b: float): + """ + A function that multiplies two numbers + + Args: + a: The first number to multiply + b: The second number to multiply + """ + return a * b + + +tools = {"multiply": multiply} + +# Specify the prompt and conversation history +prompt = "Multiply 12234585 and 48838483920." 
+messages = [{"role": "user", "content": prompt}]
+
+prompt = tokenizer.apply_chat_template(
+    messages, add_generation_prompt=True, tools=list(tools.values())
+)
+
+prompt_cache = make_prompt_cache(model)
+
+# Generate the initial tool call:
+response = generate(
+    model=model,
+    tokenizer=tokenizer,
+    prompt=prompt,
+    max_tokens=2048,
+    verbose=True,
+    prompt_cache=prompt_cache,
+)
+
+# Parse the tool call:
+# (Note, the tool call format is model specific)
+tool_open = "<tool_call>"
+tool_close = "</tool_call>"
+start_tool = response.find(tool_open) + len(tool_open)
+end_tool = response.find(tool_close)
+tool_call = json.loads(response[start_tool:end_tool].strip())
+tool_result = tools[tool_call["name"]](**tool_call["arguments"])
+
+# Put the tool result in the prompt
+messages = [{"role": "tool", "name": tool_call["name"], "content": tool_result}]
+prompt = tokenizer.apply_chat_template(
+    messages,
+    add_generation_prompt=True,
+)
+
+# Generate the final response:
+response = generate(
+    model=model,
+    tokenizer=tokenizer,
+    prompt=prompt,
+    max_tokens=2048,
+    verbose=True,
+    prompt_cache=prompt_cache,
+)

From e7267d30f83bc3f22ff6f0f8132ca0bcd9c38115 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Wed, 5 Mar 2025 13:33:15 -0800
Subject: [PATCH 3/9] Distributed support cifar (#1301)

---
 cifar/README.md  | 14 ++++++++
 cifar/dataset.py | 11 +++++-
 cifar/main.py    | 91 +++++++++++++++++++++++++++++++-----------------
 3 files changed, 84 insertions(+), 32 deletions(-)

diff --git a/cifar/README.md b/cifar/README.md
index 763e641d..2016200d 100644
--- a/cifar/README.md
+++ b/cifar/README.md
@@ -48,3 +48,17 @@ Note this was run on an M1 Macbook Pro with 16GB RAM.
 
 At the time of writing, `mlx` doesn't have built-in learning rate schedules.
 We intend to update this example once these features are added.
+
+## Distributed training
+
+The example also supports distributed data parallel training.
You can launch a +distributed training as follows: + +```shell +$ cat >hostfile.json +[ + {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]}, + {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]} +] +$ mlx.launch --verbose --hostfile hostfile.json main.py --batch 256 --epochs 5 --arch resnet20 +``` diff --git a/cifar/dataset.py b/cifar/dataset.py index 22b229f8..8967591e 100644 --- a/cifar/dataset.py +++ b/cifar/dataset.py @@ -1,3 +1,4 @@ +import mlx.core as mx import numpy as np from mlx.data.datasets import load_cifar10 @@ -12,8 +13,11 @@ def get_cifar10(batch_size, root=None): x = x.astype("float32") / 255.0 return (x - mean) / std + group = mx.distributed.init() + tr_iter = ( tr.shuffle() + .partition_if(group.size() > 1, group.size(), group.rank()) .to_stream() .image_random_h_flip("image", prob=0.5) .pad("image", 0, 4, 4, 0.0) @@ -25,6 +29,11 @@ def get_cifar10(batch_size, root=None): ) test = load_cifar10(root=root, train=False) - test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size) + test_iter = ( + test.to_stream() + .partition_if(group.size() > 1, group.size(), group.rank()) + .key_transform("image", normalize) + .batch(batch_size) + ) return tr_iter, test_iter diff --git a/cifar/main.py b/cifar/main.py index 378bc424..ac010636 100644 --- a/cifar/main.py +++ b/cifar/main.py @@ -23,6 +23,13 @@ parser.add_argument("--seed", type=int, default=0, help="random seed") parser.add_argument("--cpu", action="store_true", help="use cpu only") +def print_zero(group, *args, **kwargs): + if group.rank() != 0: + return + flush = kwargs.pop("flush", True) + print(*args, **kwargs, flush=flush) + + def eval_fn(model, inp, tgt): return mx.mean(mx.argmax(model(inp), axis=1) == tgt) @@ -34,9 +41,20 @@ def train_epoch(model, train_iter, optimizer, epoch): acc = mx.mean(mx.argmax(output, axis=1) == tgt) return loss, acc - losses = [] - accs = [] - samples_per_sec = [] + world = mx.distributed.init() + losses = 0 + accuracies = 0 + samples_per_sec = 0 + count = 0 + + def average_stats(stats, count): + if world.size() == 1: + return [s / count for s in stats] + + with mx.stream(mx.cpu): + stats = mx.distributed.all_sum(mx.array(stats)) + count = mx.distributed.all_sum(count) + return (stats / count).tolist() state = [model.state, optimizer.state] @@ -44,6 +62,7 @@ def train_epoch(model, train_iter, optimizer, epoch): def step(inp, tgt): train_step_fn = nn.value_and_grad(model, train_step) (loss, acc), grads = train_step_fn(model, inp, tgt) + grads = nn.utils.average_gradients(grads) optimizer.update(model, grads) return loss, acc @@ -52,69 +71,79 @@ def train_epoch(model, train_iter, optimizer, epoch): y = mx.array(batch["label"]) tic = time.perf_counter() loss, acc = step(x, y) - mx.eval(state) + mx.eval(loss, acc, state) toc = time.perf_counter() - loss = loss.item() - acc = acc.item() - losses.append(loss) - accs.append(acc) - throughput = x.shape[0] / (toc - tic) - samples_per_sec.append(throughput) + losses += loss.item() + accuracies += acc.item() + samples_per_sec += x.shape[0] / (toc - tic) + count += 1 if batch_counter % 10 == 0: - print( + l, a, s = average_stats( + [losses, accuracies, world.size() * samples_per_sec], + count, + ) + print_zero( + world, " | ".join( ( f"Epoch {epoch:02d} [{batch_counter:03d}]", - f"Train loss {loss:.3f}", - f"Train acc {acc:.3f}", - f"Throughput: {throughput:.2f} images/second", + f"Train loss {l:.3f}", + f"Train acc {a:.3f}", + f"Throughput: {s:.2f} images/second", ) - ) + ), ) - mean_tr_loss = mx.mean(mx.array(losses)) - 
mean_tr_acc = mx.mean(mx.array(accs)) - samples_per_sec = mx.mean(mx.array(samples_per_sec)) - return mean_tr_loss, mean_tr_acc, samples_per_sec + return average_stats([losses, accuracies, world.size() * samples_per_sec], count) def test_epoch(model, test_iter, epoch): - accs = [] + accuracies = 0 + count = 0 for batch_counter, batch in enumerate(test_iter): x = mx.array(batch["image"]) y = mx.array(batch["label"]) acc = eval_fn(model, x, y) - acc_value = acc.item() - accs.append(acc_value) - mean_acc = mx.mean(mx.array(accs)) - return mean_acc + accuracies += acc.item() + count += 1 + + with mx.stream(mx.cpu): + accuracies = mx.distributed.all_sum(accuracies) + count = mx.distributed.all_sum(count) + return (accuracies / count).item() def main(args): mx.random.seed(args.seed) + # Initialize the distributed group and report the nodes that showed up + world = mx.distributed.init() + if world.size() > 1: + print(f"Starting rank {world.rank()} of {world.size()}", flush=True) + model = getattr(resnet, args.arch)() - print("Number of params: {:0.04f} M".format(model.num_params() / 1e6)) + print_zero(world, f"Number of params: {model.num_params() / 1e6:0.04f} M") optimizer = optim.Adam(learning_rate=args.lr) train_data, test_data = get_cifar10(args.batch_size) for epoch in range(args.epochs): tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch) - print( + print_zero( + world, " | ".join( ( f"Epoch: {epoch}", - f"avg. Train loss {tr_loss.item():.3f}", - f"avg. Train acc {tr_acc.item():.3f}", - f"Throughput: {throughput.item():.2f} images/sec", + f"avg. Train loss {tr_loss:.3f}", + f"avg. Train acc {tr_acc:.3f}", + f"Throughput: {throughput:.2f} images/sec", ) - ) + ), ) test_acc = test_epoch(model, test_data, epoch) - print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}") + print_zero(world, f"Epoch: {epoch} | Test acc {test_acc:.3f}") train_data.reset() test_data.reset() From 56d2db23e1348f046fc91d8c8c7794722e9fbe43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com> Date: Wed, 5 Mar 2025 22:46:06 +0100 Subject: [PATCH 4/9] adding OLMoE architecture (#1321) * initial commit * udpate ACKNOWLEDGMENTS.md * adding olmoe to training * clean up * faster generation * remove sanitize method * more clean ups * adding SwitchGLU * clean up * a little faster and adding norm_topk_prob * formated --- ACKNOWLEDGMENTS.md | 2 +- llms/mlx_lm/models/olmoe.py | 217 ++++++++++++++++++++++++++++++++++++ llms/mlx_lm/tuner/utils.py | 3 + 3 files changed, 221 insertions(+), 1 deletion(-) create mode 100644 llms/mlx_lm/models/olmoe.py diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 851c995c..c6853710 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals: - Markus Enzweiler: Added the `cvae` examples. - Prince Canuma: Helped add support for `Starcoder2` models. - Shiyu Li: Added the `Segment Anything Model`. -- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1` and support for `full-fine-tuning`. \ No newline at end of file +- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1`, `OLMoE` archtectures and support for `full-fine-tuning`. 
\ No newline at end of file diff --git a/llms/mlx_lm/models/olmoe.py b/llms/mlx_lm/models/olmoe.py new file mode 100644 index 00000000..b9c0fc69 --- /dev/null +++ b/llms/mlx_lm/models/olmoe.py @@ -0,0 +1,217 @@ +# Copyright © 2023-2024 Apple Inc. + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .rope_utils import initialize_rope +from .switch_layers import SwitchGLU + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + vocab_size: int + num_experts: int + num_experts_per_tok: int + norm_topk_prob: bool = False + head_dim: Optional[int] = None + max_position_embeddings: Optional[int] = None + num_key_value_heads: Optional[int] = None + attention_bias: bool = False + mlp_bias: bool = False + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = True + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads + + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias) + + self.rope = initialize_rope( + self.head_dim, + args.rope_theta, + args.rope_traditional, + args.rope_scaling, + args.max_position_embeddings, + ) + + self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps) + self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + queries = self.q_norm(queries) + keys = self.k_norm(keys) + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class OlmoeSparseMoeBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_experts = args.num_experts + self.top_k = args.num_experts_per_tok + self.norm_topk_prob = args.norm_topk_prob + + self.gate = nn.Linear(args.hidden_size, self.num_experts, bias=False) + self.switch_mlp = SwitchGLU( + args.hidden_size, + args.intermediate_size, + 
self.num_experts, + bias=args.mlp_bias, + ) + + def __call__(self, x: mx.array) -> mx.array: + B, L, D = x.shape + x_flat = x.reshape(-1, D) + router_logits = self.gate(x_flat) + routing_weights = mx.softmax(router_logits, axis=1, precise=True) + k = self.top_k + indices = mx.stop_gradient( + mx.argpartition(-routing_weights, kth=k - 1, axis=-1)[..., :k] + ) + scores = mx.take_along_axis(routing_weights, indices, axis=-1) + if self.norm_topk_prob: + scores = scores / scores.sum(axis=-1, keepdims=True) + y = self.switch_mlp(x_flat, indices) + y = (y * scores[..., None]).sum(axis=-2) + return y.reshape(B, L, D) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.self_attn = Attention(args) + self.mlp = OlmoeSparseMoeBlock(args) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + x = x + self.self_attn(self.input_layernorm(x), mask, cache) + x = x + self.mlp(self.post_attention_layernorm(x)) + return x + + +class OlmoeModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + mask=None, + ): + h = self.embed_tokens(inputs) + if mask is None: + mask = create_attention_mask(h, cache) + if cache is None: + cache = [None] * len(self.layers) + for layer, c in zip(self.layers, cache): + h = layer(h, mask, cache=c) + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = OlmoeModel(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + mask=None, + ): + out = self.model(inputs, cache, mask) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out + + def sanitize(self, weights): + if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: + return weights + for l in range(self.args.num_hidden_layers): + prefix = f"model.layers.{l}" + for n in ["up_proj", "down_proj", "gate_proj"]: + for k in ["weight", "scales", "biases"]: + if f"{prefix}.mlp.experts.0.{n}.{k}" in weights: + to_join = [ + weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}") + for e in range(self.args.num_experts) + ] + weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join) + return weights + + @property + def layers(self): + return self.model.layers diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index f5df11e3..cc7c6c20 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -98,6 +98,7 @@ def linear_to_lora_layers( "minicpm", "deepseek", "olmo2", + "olmoe", "internlm3", ]: keys = set(["self_attn.q_proj", "self_attn.v_proj"]) @@ -106,6 +107,8 @@ def linear_to_lora_layers( if model.model_type == "qwen2_moe": keys.add("mlp.gate") 
keys.add("mlp.shared_expert_gate") + if model.model_type == "olmoe": + keys.add("mlp.gate") elif model.model_type == "gpt_bigcode": keys = set(["attn.c_attn"]) From e15062109568571aec0e2f099533ad580f0fcaf5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com> Date: Wed, 5 Mar 2025 22:54:54 +0100 Subject: [PATCH 5/9] Adding multiple optimizers to mlx lm (#1315) * initial commmit * adding more customized YAML configuartion * update YAML example file * Changed the switch to set opt_class * removing muon * using default arguments * udpate --- llms/mlx_lm/examples/lora_config.yaml | 9 +++++++ llms/mlx_lm/lora.py | 34 +++++++++++++++++++++------ 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/llms/mlx_lm/examples/lora_config.yaml b/llms/mlx_lm/examples/lora_config.yaml index 530272c7..36bc1dff 100644 --- a/llms/mlx_lm/examples/lora_config.yaml +++ b/llms/mlx_lm/examples/lora_config.yaml @@ -7,6 +7,15 @@ train: true # The fine-tuning method: "lora", "dora", or "full". fine_tune_type: lora +# The Optimizer with its possible inputs +optimizer: adamw +# optimizer_config: +# adamw: +# betas: [0.9, 0.98] +# eps: 1e-6 +# weight_decay: 0.05 +# bias_correction: true + # Directory with {train, valid, test}.jsonl files data: "/path/to/training/data" diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index d32bfe6d..042b40e2 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -43,6 +43,11 @@ CONFIG_DEFAULTS = { "model": "mlx_model", "train": False, "fine_tune_type": "lora", + "optimizer": "adam", + "optimizer_config": { + "adam": {}, + "adamw": {}, + }, "data": "data/", "seed": 0, "num_layers": 16, @@ -95,14 +100,19 @@ def build_parser(): choices=["lora", "dora", "full"], help="Type of fine-tuning to perform: lora, dora, or full.", ) - + parser.add_argument( + "--optimizer", + type=str, + choices=["adam", "adamw"], + default=None, + help="Optimizer to use for training: adam or adamw", + ) parser.add_argument( "--mask-prompt", action="store_true", help="Mask the prompt in the loss when training", default=None, ) - parser.add_argument( "--num-layers", type=int, @@ -229,11 +239,21 @@ def train_model( ) model.train() - opt = optim.Adam( - learning_rate=( - build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate - ) - ) + + # Initialize the selected optimizer + lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate + + optimizer_name = args.optimizer.lower() + optimizer_config = args.optimizer_config.get(optimizer_name, {}) + + if optimizer_name == "adam": + opt_class = optim.Adam + elif optimizer_name == "adamw": + opt_class = optim.AdamW + else: + raise ValueError(f"Unsupported optimizer: {optimizer_name}") + + opt = opt_class(learning_rate=lr, **optimizer_config) # Train model train( From 32d10036de94af07733c247ca44702e8135d068a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 5 Mar 2025 14:00:09 -0800 Subject: [PATCH 6/9] fix flaky test (#1322) --- llms/tests/test_prompt_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py index de5694d5..c1860892 100644 --- a/llms/tests/test_prompt_cache.py +++ b/llms/tests/test_prompt_cache.py @@ -298,7 +298,7 @@ class TestPromptCache(unittest.TestCase): ): i += 1 self.assertEqual(tok, toks[i]) - self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2)) + self.assertTrue(mx.allclose(logits, all_logits[i], rtol=3e-2)) if __name__ == 
"__main__": From 877d2a345b8119ad9ed50e2c273a5064ddd3b48c Mon Sep 17 00:00:00 2001 From: cavit99 <35897738+cavit99@users.noreply.github.com> Date: Thu, 6 Mar 2025 14:49:35 +0000 Subject: [PATCH 7/9] Change DEFAULT_SEED to None for stochastic generation by default (#1323) * Change DEFAULT_SEED to None for stochastic generation by default * Update llms/mlx_lm/chat.py * Update llms/mlx_lm/generate.py --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/chat.py | 12 +++++++++--- llms/mlx_lm/generate.py | 13 ++++++++++--- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py index 5c0b78db..d8e1ccb9 100644 --- a/llms/mlx_lm/chat.py +++ b/llms/mlx_lm/chat.py @@ -11,7 +11,7 @@ from .utils import load, stream_generate DEFAULT_TEMP = 0.0 DEFAULT_TOP_P = 1.0 -DEFAULT_SEED = 0 +DEFAULT_SEED = None DEFAULT_MAX_TOKENS = 256 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" @@ -36,7 +36,12 @@ def setup_arg_parser(): parser.add_argument( "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p" ) - parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed") + parser.add_argument( + "--seed", + type=int, + default=DEFAULT_SEED, + help="PRNG seed", + ) parser.add_argument( "--max-kv-size", type=int, @@ -57,7 +62,8 @@ def main(): parser = setup_arg_parser() args = parser.parse_args() - mx.random.seed(args.seed) + if args.seed is not None: + mx.random.seed(args.seed) model, tokenizer = load( args.model, diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index bd11dcf0..7d58da82 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -16,7 +16,7 @@ DEFAULT_TEMP = 0.0 DEFAULT_TOP_P = 1.0 DEFAULT_MIN_P = 0.0 DEFAULT_MIN_TOKENS_TO_KEEP = 1 -DEFAULT_SEED = 0 +DEFAULT_SEED = None DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" DEFAULT_QUANTIZED_KV_START = 5000 @@ -87,7 +87,12 @@ def setup_arg_parser(): default=DEFAULT_MIN_TOKENS_TO_KEEP, help="Minimum tokens to keep for min-p sampling.", ) - parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed") + parser.add_argument( + "--seed", + type=int, + default=DEFAULT_SEED, + help="PRNG seed", + ) parser.add_argument( "--ignore-chat-template", action="store_true", @@ -160,7 +165,9 @@ def setup_arg_parser(): def main(): parser = setup_arg_parser() args = parser.parse_args() - mx.random.seed(args.seed) + + if args.seed is not None: + mx.random.seed(args.seed) # Load the prompt cache and metadata if a cache file is provided using_cache = args.prompt_cache_file is not None From 595f5da146bbf305b14fe18d343fe2777aa8a1ba Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 6 Mar 2025 15:35:47 -0800 Subject: [PATCH 8/9] remove lm head if unused (#1324) --- llms/mlx_lm/models/llama.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index 7b452ea4..117adf0f 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -196,9 +196,12 @@ class Model(nn.Module): def sanitize(self, weights): # Remove unused precomputed rotary freqs - return { + weights = { k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k } + if self.args.tie_word_embeddings: + weights.pop("lm_head.weight", None) + return weights @property def layers(self): From d2e02b3aae9741eea6f9c6123624406de3f10015 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 7 Mar 2025 08:35:48 -0800 Subject: [PATCH 9/9] fix mixed quant option (#1326) --- 
llms/mlx_lm/convert.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/llms/mlx_lm/convert.py b/llms/mlx_lm/convert.py index 86a96447..f268913b 100644 --- a/llms/mlx_lm/convert.py +++ b/llms/mlx_lm/convert.py @@ -1,27 +1,23 @@ # Copyright © 2023-2024 Apple Inc. import argparse -from enum import Enum -from .utils import convert, mixed_2_6, mixed_3_6 +from . import utils +from .utils import convert - -class MixedQuants(Enum): - mixed_3_6 = "mixed_3_6" - mixed_2_6 = "mixed_2_6" - - @classmethod - def recipe_names(cls): - return [member.name for member in cls] +QUANT_RECIPES = [ + "mixed_2_6", + "mixed_3_6", +] def quant_args(arg): - try: - return MixedQuants[arg].value - except KeyError: + if arg not in QUANT_RECIPES: raise argparse.ArgumentTypeError( - f"Invalid q-recipe {arg!r}. Choose from: {MixedQuants.recipe_names()}" + f"Invalid q-recipe {arg!r}. Choose from: {QUANT_RECIPES}" ) + else: + return getattr(utils, arg) def configure_parser() -> argparse.ArgumentParser: @@ -50,7 +46,7 @@ def configure_parser() -> argparse.ArgumentParser: ) parser.add_argument( "--quant-predicate", - help=f"Mixed-bit quantization recipe. Choices: {MixedQuants.recipe_names()}", + help=f"Mixed-bit quantization recipe. Choices: {QUANT_RECIPES}", type=quant_args, required=False, )
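
Note on the patch above: after this change the value of `--quant-predicate` is resolved with `getattr(utils, arg)`, so what reaches `convert` is the actual predicate function (`mixed_2_6` or `mixed_3_6`) exported by `mlx_lm.utils` rather than an enum member. A minimal sketch of the equivalent call from Python follows — the Hugging Face repo id, output path, and the `quantize`/`quant_predicate` keyword names are illustrative assumptions, not taken from this patch:

```python
# Hypothetical sketch: apply a mixed-bit recipe programmatically.
# The repo id, output directory, and the `quantize` / `quant_predicate`
# keyword names below are assumptions for illustration only.
from mlx_lm.utils import convert, mixed_3_6

convert(
    "mistralai/Mistral-7B-Instruct-v0.3",  # placeholder source model
    mlx_path="mlx_model_mixed_3_6",        # placeholder output directory
    quantize=True,
    quant_predicate=mixed_3_6,
)
```

From the command line, the same recipe is selected with `--quant-predicate mixed_3_6`, i.e. one of the strings validated against `QUANT_RECIPES` above.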