diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 851c995c..c6853710 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
- Markus Enzweiler: Added the `cvae` examples.
- Prince Canuma: Helped add support for `Starcoder2` models.
- Shiyu Li: Added the `Segment Anything Model`.
-- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1` and support for `full-fine-tuning`.
\ No newline at end of file
+- Gökdeniz Gülmez: Added support for the `MiniCPM`, `Helium`, `Mamba version 1`, and `OLMoE` architectures, and support for `full-fine-tuning`.
\ No newline at end of file
diff --git a/cifar/README.md b/cifar/README.md
index 763e641d..2016200d 100644
--- a/cifar/README.md
+++ b/cifar/README.md
@@ -48,3 +48,17 @@ Note this was run on an M1 Macbook Pro with 16GB RAM.
At the time of writing, `mlx` doesn't have built-in learning rate schedules.
We intend to update this example once these features are added.
+
+## Distributed training
+
+The example also supports distributed data parallel training. You can launch
+distributed training as follows:
+
+```shell
+$ cat >hostfile.json
+[
+ {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]},
+ {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]}
+]
+$ mlx.launch --verbose --hostfile hostfile.json main.py --batch 256 --epochs 5 --arch resnet20
+```
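+
+Each rank trains on a disjoint shard of CIFAR-10 (see `partition_if` in
+`dataset.py`) and gradients are averaged across ranks before every optimizer
+update. A minimal sketch of that pattern, with illustrative model, optimizer,
+and loss names, looks like:
+
+```python
+import mlx.core as mx
+import mlx.nn as nn
+
+group = mx.distributed.init()  # one rank per process started by `mlx.launch`
+
+def train_step(model, optimizer, loss_fn, inp, tgt):
+    loss, grads = nn.value_and_grad(model, loss_fn)(model, inp, tgt)
+    if group.size() > 1:
+        grads = nn.utils.average_gradients(grads)  # all-reduce across ranks
+    optimizer.update(model, grads)
+    return loss
+```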
diff --git a/cifar/dataset.py b/cifar/dataset.py
index 22b229f8..8967591e 100644
--- a/cifar/dataset.py
+++ b/cifar/dataset.py
@@ -1,3 +1,4 @@
+import mlx.core as mx
import numpy as np
from mlx.data.datasets import load_cifar10
@@ -12,8 +13,11 @@ def get_cifar10(batch_size, root=None):
x = x.astype("float32") / 255.0
return (x - mean) / std
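+    # Get the distributed group so each rank can load its own shard of the data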
+ group = mx.distributed.init()
+
tr_iter = (
tr.shuffle()
+ .partition_if(group.size() > 1, group.size(), group.rank())
.to_stream()
.image_random_h_flip("image", prob=0.5)
.pad("image", 0, 4, 4, 0.0)
@@ -25,6 +29,11 @@ def get_cifar10(batch_size, root=None):
)
test = load_cifar10(root=root, train=False)
- test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size)
+ test_iter = (
+ test.to_stream()
+ .partition_if(group.size() > 1, group.size(), group.rank())
+ .key_transform("image", normalize)
+ .batch(batch_size)
+ )
return tr_iter, test_iter
diff --git a/cifar/main.py b/cifar/main.py
index 378bc424..ac010636 100644
--- a/cifar/main.py
+++ b/cifar/main.py
@@ -23,6 +23,13 @@ parser.add_argument("--seed", type=int, default=0, help="random seed")
parser.add_argument("--cpu", action="store_true", help="use cpu only")
+def print_zero(group, *args, **kwargs):
+ if group.rank() != 0:
+ return
+ flush = kwargs.pop("flush", True)
+ print(*args, **kwargs, flush=flush)
+
+
def eval_fn(model, inp, tgt):
return mx.mean(mx.argmax(model(inp), axis=1) == tgt)
@@ -34,9 +41,20 @@ def train_epoch(model, train_iter, optimizer, epoch):
acc = mx.mean(mx.argmax(output, axis=1) == tgt)
return loss, acc
- losses = []
- accs = []
- samples_per_sec = []
+ world = mx.distributed.init()
+ losses = 0
+ accuracies = 0
+ samples_per_sec = 0
+ count = 0
+
+ def average_stats(stats, count):
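+        # With a single process just divide by the local count; otherwise sum
+        # the totals and counts across ranks first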
+ if world.size() == 1:
+ return [s / count for s in stats]
+
+ with mx.stream(mx.cpu):
+ stats = mx.distributed.all_sum(mx.array(stats))
+ count = mx.distributed.all_sum(count)
+ return (stats / count).tolist()
state = [model.state, optimizer.state]
@@ -44,6 +62,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
def step(inp, tgt):
train_step_fn = nn.value_and_grad(model, train_step)
(loss, acc), grads = train_step_fn(model, inp, tgt)
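+        # Average gradients across ranks so every process applies the same update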
+ grads = nn.utils.average_gradients(grads)
optimizer.update(model, grads)
return loss, acc
@@ -52,69 +71,79 @@ def train_epoch(model, train_iter, optimizer, epoch):
y = mx.array(batch["label"])
tic = time.perf_counter()
loss, acc = step(x, y)
- mx.eval(state)
+ mx.eval(loss, acc, state)
toc = time.perf_counter()
- loss = loss.item()
- acc = acc.item()
- losses.append(loss)
- accs.append(acc)
- throughput = x.shape[0] / (toc - tic)
- samples_per_sec.append(throughput)
+ losses += loss.item()
+ accuracies += acc.item()
+ samples_per_sec += x.shape[0] / (toc - tic)
+ count += 1
if batch_counter % 10 == 0:
- print(
+ l, a, s = average_stats(
+ [losses, accuracies, world.size() * samples_per_sec],
+ count,
+ )
+ print_zero(
+ world,
" | ".join(
(
f"Epoch {epoch:02d} [{batch_counter:03d}]",
- f"Train loss {loss:.3f}",
- f"Train acc {acc:.3f}",
- f"Throughput: {throughput:.2f} images/second",
+ f"Train loss {l:.3f}",
+ f"Train acc {a:.3f}",
+ f"Throughput: {s:.2f} images/second",
)
- )
+ ),
)
- mean_tr_loss = mx.mean(mx.array(losses))
- mean_tr_acc = mx.mean(mx.array(accs))
- samples_per_sec = mx.mean(mx.array(samples_per_sec))
- return mean_tr_loss, mean_tr_acc, samples_per_sec
+ return average_stats([losses, accuracies, world.size() * samples_per_sec], count)
def test_epoch(model, test_iter, epoch):
- accs = []
+ accuracies = 0
+ count = 0
for batch_counter, batch in enumerate(test_iter):
x = mx.array(batch["image"])
y = mx.array(batch["label"])
acc = eval_fn(model, x, y)
- acc_value = acc.item()
- accs.append(acc_value)
- mean_acc = mx.mean(mx.array(accs))
- return mean_acc
+ accuracies += acc.item()
+ count += 1
+
+ with mx.stream(mx.cpu):
+ accuracies = mx.distributed.all_sum(accuracies)
+ count = mx.distributed.all_sum(count)
+ return (accuracies / count).item()
def main(args):
mx.random.seed(args.seed)
+ # Initialize the distributed group and report the nodes that showed up
+ world = mx.distributed.init()
+ if world.size() > 1:
+ print(f"Starting rank {world.rank()} of {world.size()}", flush=True)
+
model = getattr(resnet, args.arch)()
- print("Number of params: {:0.04f} M".format(model.num_params() / 1e6))
+ print_zero(world, f"Number of params: {model.num_params() / 1e6:0.04f} M")
optimizer = optim.Adam(learning_rate=args.lr)
train_data, test_data = get_cifar10(args.batch_size)
for epoch in range(args.epochs):
tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
- print(
+ print_zero(
+ world,
" | ".join(
(
f"Epoch: {epoch}",
- f"avg. Train loss {tr_loss.item():.3f}",
- f"avg. Train acc {tr_acc.item():.3f}",
- f"Throughput: {throughput.item():.2f} images/sec",
+ f"avg. Train loss {tr_loss:.3f}",
+ f"avg. Train acc {tr_acc:.3f}",
+ f"Throughput: {throughput:.2f} images/sec",
)
- )
+ ),
)
test_acc = test_epoch(model, test_data, epoch)
- print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}")
+ print_zero(world, f"Epoch: {epoch} | Test acc {test_acc:.3f}")
train_data.reset()
test_data.reset()
diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py
index 89e6cd00..839089b6 100644
--- a/llms/mlx_lm/_version.py
+++ b/llms/mlx_lm/_version.py
@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
-__version__ = "0.21.5"
+__version__ = "0.21.6"
diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py
index e52ad10d..d8e1ccb9 100644
--- a/llms/mlx_lm/chat.py
+++ b/llms/mlx_lm/chat.py
@@ -11,7 +11,7 @@ from .utils import load, stream_generate
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
-DEFAULT_SEED = 0
+DEFAULT_SEED = None
DEFAULT_MAX_TOKENS = 256
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
@@ -36,7 +36,12 @@ def setup_arg_parser():
parser.add_argument(
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
)
- parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=DEFAULT_SEED,
+ help="PRNG seed",
+ )
parser.add_argument(
"--max-kv-size",
type=int,
@@ -57,7 +62,8 @@ def main():
parser = setup_arg_parser()
args = parser.parse_args()
- mx.random.seed(args.seed)
+ if args.seed is not None:
+ mx.random.seed(args.seed)
model, tokenizer = load(
args.model,
@@ -65,12 +71,25 @@ def main():
tokenizer_config={"trust_remote_code": True},
)
- print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
+ def print_help():
+ print("The command list:")
+ print("- 'q' to exit")
+ print("- 'r' to reset the chat")
+ print("- 'h' to display these commands")
+
+ print(f"[INFO] Starting chat session with {args.model}.")
+ print_help()
prompt_cache = make_prompt_cache(model, args.max_kv_size)
while True:
query = input(">> ")
if query == "q":
break
+ if query == "r":
+ prompt_cache = make_prompt_cache(model, args.max_kv_size)
+ continue
+ if query == "h":
+ print_help()
+ continue
messages = [{"role": "user", "content": query}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
for response in stream_generate(
diff --git a/llms/mlx_lm/convert.py b/llms/mlx_lm/convert.py
index 9bac77a5..f268913b 100644
--- a/llms/mlx_lm/convert.py
+++ b/llms/mlx_lm/convert.py
@@ -2,8 +2,23 @@
import argparse
+from . import utils
from .utils import convert
+QUANT_RECIPES = [
+ "mixed_2_6",
+ "mixed_3_6",
+]
+
+
+def quant_args(arg):
+ if arg not in QUANT_RECIPES:
+ raise argparse.ArgumentTypeError(
+            f"Invalid quantization recipe {arg!r}. Choose from: {QUANT_RECIPES}"
+ )
+ else:
+ return getattr(utils, arg)
+
def configure_parser() -> argparse.ArgumentParser:
"""
@@ -29,6 +44,12 @@ def configure_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--q-bits", help="Bits per weight for quantization.", type=int, default=4
)
+ parser.add_argument(
+ "--quant-predicate",
+ help=f"Mixed-bit quantization recipe. Choices: {QUANT_RECIPES}",
+ type=quant_args,
+ required=False,
+ )
parser.add_argument(
"--dtype",
help="Type to save the non-quantized parameters.",
diff --git a/llms/mlx_lm/evaluate.py b/llms/mlx_lm/evaluate.py
index 2f35ade2..cd6de7ec 100644
--- a/llms/mlx_lm/evaluate.py
+++ b/llms/mlx_lm/evaluate.py
@@ -289,17 +289,15 @@ class MLXLM(LM):
contexts, options = zip(*[req.args for req in requests])
# contrary to the doc the second element of the tuple contains
# {'do_sample': False, 'until': ['\n\n'], 'temperature': 0}
- keys = list(options[0].keys())
- assert "until" in keys
- untils = [x["until"] for x in options]
completions = []
- for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
+ for context, opt in tqdm(zip(contexts, options), total=len(contexts)):
+ until = opt["until"]
context = self.tokenizer.encode(
context, add_special_tokens=not self.use_chat_template
)
max_tokens = min(
- self._max_tokens,
+ opt.get("max_gen_tokens", self._max_tokens),
self.tokenizer.model_max_length - len(context),
)
text = ""
@@ -334,9 +332,9 @@ def main():
)
parser.add_argument(
"--limit",
- default=1.0,
+ default=100,
help="Limit the number of examples per task.",
- type=float,
+ type=int,
)
parser.add_argument("--seed", type=int, default=123, help="Random seed.")
parser.add_argument(
diff --git a/llms/mlx_lm/examples/lora_config.yaml b/llms/mlx_lm/examples/lora_config.yaml
index 530272c7..36bc1dff 100644
--- a/llms/mlx_lm/examples/lora_config.yaml
+++ b/llms/mlx_lm/examples/lora_config.yaml
@@ -7,6 +7,15 @@ train: true
# The fine-tuning method: "lora", "dora", or "full".
fine_tune_type: lora
+# The optimizer to use, along with its optional configuration
+optimizer: adamw
+# optimizer_config:
+# adamw:
+# betas: [0.9, 0.98]
+# eps: 1e-6
+# weight_decay: 0.05
+# bias_correction: true
+
# Directory with {train, valid, test}.jsonl files
data: "/path/to/training/data"
diff --git a/llms/mlx_lm/examples/tool_use.py b/llms/mlx_lm/examples/tool_use.py
new file mode 100644
index 00000000..624b9e5b
--- /dev/null
+++ b/llms/mlx_lm/examples/tool_use.py
@@ -0,0 +1,73 @@
+# Copyright © 2025 Apple Inc.
+
+import json
+
+from mlx_lm import generate, load
+from mlx_lm.models.cache import make_prompt_cache
+
+# Specify the checkpoint
+checkpoint = "mlx-community/Qwen2.5-32B-Instruct-4bit"
+
+# Load the corresponding model and tokenizer
+model, tokenizer = load(path_or_hf_repo=checkpoint)
+
+
+# An example tool, make sure to include a docstring and type hints
+def multiply(a: float, b: float):
+ """
+ A function that multiplies two numbers
+
+ Args:
+ a: The first number to multiply
+ b: The second number to multiply
+ """
+ return a * b
+
+
+tools = {"multiply": multiply}
+
+# Specify the prompt and conversation history
+prompt = "Multiply 12234585 and 48838483920."
+messages = [{"role": "user", "content": prompt}]
+
+prompt = tokenizer.apply_chat_template(
+ messages, add_generation_prompt=True, tools=list(tools.values())
+)
+
+prompt_cache = make_prompt_cache(model)
+
+# Generate the initial tool call:
+response = generate(
+ model=model,
+ tokenizer=tokenizer,
+ prompt=prompt,
+ max_tokens=2048,
+ verbose=True,
+ prompt_cache=prompt_cache,
+)
+
+# Parse the tool call:
+# (Note, the tool call format is model specific)
+tool_open = "<tool_call>"
+tool_close = "</tool_call>"
+start_tool = response.find(tool_open) + len(tool_open)
+end_tool = response.find(tool_close)
+tool_call = json.loads(response[start_tool:end_tool].strip())
+tool_result = tools[tool_call["name"]](**tool_call["arguments"])
+
+# Put the tool result in the prompt
+messages = [{"role": "tool", "name": tool_call["name"], "content": tool_result}]
+prompt = tokenizer.apply_chat_template(
+ messages,
+ add_generation_prompt=True,
+)
+
+# Generate the final response:
+response = generate(
+ model=model,
+ tokenizer=tokenizer,
+ prompt=prompt,
+ max_tokens=2048,
+ verbose=True,
+ prompt_cache=prompt_cache,
+)
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index d8f97e5e..7d58da82 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -16,7 +16,7 @@ DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_MIN_P = 0.0
DEFAULT_MIN_TOKENS_TO_KEEP = 1
-DEFAULT_SEED = 0
+DEFAULT_SEED = None
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
DEFAULT_QUANTIZED_KV_START = 5000
@@ -60,6 +60,11 @@ def setup_arg_parser():
default=DEFAULT_PROMPT,
help="Message to be processed by the model ('-' reads from stdin)",
)
+ parser.add_argument(
+ "--prefill-response",
+ default=None,
+        help="Prefill text for the assistant response; generation continues from it",
+ )
parser.add_argument(
"--max-tokens",
"-m",
@@ -82,7 +87,12 @@ def setup_arg_parser():
default=DEFAULT_MIN_TOKENS_TO_KEEP,
help="Minimum tokens to keep for min-p sampling.",
)
- parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=DEFAULT_SEED,
+ help="PRNG seed",
+ )
parser.add_argument(
"--ignore-chat-template",
action="store_true",
@@ -147,7 +157,7 @@ def setup_arg_parser():
"--num-draft-tokens",
type=int,
help="Number of tokens to draft when using speculative decoding.",
- default=2,
+ default=3,
)
return parser
@@ -155,7 +165,9 @@ def setup_arg_parser():
def main():
parser = setup_arg_parser()
args = parser.parse_args()
- mx.random.seed(args.seed)
+
+ if args.seed is not None:
+ mx.random.seed(args.seed)
# Load the prompt cache and metadata if a cache file is provided
using_cache = args.prompt_cache_file is not None
@@ -219,10 +231,14 @@ def main():
messages = []
messages.append({"role": "user", "content": prompt})
+ has_prefill = args.prefill_response is not None
+ if has_prefill:
+ messages.append({"role": "assistant", "content": args.prefill_response})
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
- add_generation_prompt=True,
+ continue_final_message=has_prefill,
+ add_generation_prompt=not has_prefill,
**template_kwargs,
)
@@ -233,7 +249,8 @@ def main():
test_prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
- add_generation_prompt=True,
+ continue_final_message=has_prefill,
+ add_generation_prompt=not has_prefill,
)
            prompt = prompt[test_prompt.index("<query>") :]
prompt = tokenizer.encode(prompt, add_special_tokens=False)
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index def3b6dd..042b40e2 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -43,6 +43,11 @@ CONFIG_DEFAULTS = {
"model": "mlx_model",
"train": False,
"fine_tune_type": "lora",
+ "optimizer": "adam",
+ "optimizer_config": {
+ "adam": {},
+ "adamw": {},
+ },
"data": "data/",
"seed": 0,
"num_layers": 16,
@@ -62,6 +67,7 @@ CONFIG_DEFAULTS = {
"grad_checkpoint": False,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
+ "mask_prompt": False,
}
@@ -94,14 +100,19 @@ def build_parser():
choices=["lora", "dora", "full"],
help="Type of fine-tuning to perform: lora, dora, or full.",
)
-
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ choices=["adam", "adamw"],
+ default=None,
+ help="Optimizer to use for training: adam or adamw",
+ )
parser.add_argument(
"--mask-prompt",
action="store_true",
help="Mask the prompt in the loss when training",
- default=False,
+ default=None,
)
-
parser.add_argument(
"--num-layers",
type=int,
@@ -228,11 +239,21 @@ def train_model(
)
model.train()
- opt = optim.Adam(
- learning_rate=(
- build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
- )
- )
+
+ # Initialize the selected optimizer
+ lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
+
+ optimizer_name = args.optimizer.lower()
+ optimizer_config = args.optimizer_config.get(optimizer_name, {})
+
+ if optimizer_name == "adam":
+ opt_class = optim.Adam
+ elif optimizer_name == "adamw":
+ opt_class = optim.AdamW
+ else:
+ raise ValueError(f"Unsupported optimizer: {optimizer_name}")
+
+ opt = opt_class(learning_rate=lr, **optimizer_config)
# Train model
train(
diff --git a/llms/mlx_lm/manage.py b/llms/mlx_lm/manage.py
index 9827f3dc..c06de6b3 100644
--- a/llms/mlx_lm/manage.py
+++ b/llms/mlx_lm/manage.py
@@ -2,7 +2,22 @@ import argparse
from typing import List, Union
from huggingface_hub import scan_cache_dir
-from transformers.commands.user import tabulate
+
+
+def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
+ """
+ Inspired by:
+ - stackoverflow.com/a/8356620/593036
+ - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
+ """
+ col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
+ row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
+ lines = []
+ lines.append(row_format.format(*headers))
+ lines.append(row_format.format(*["-" * w for w in col_widths]))
+ for row in rows:
+ lines.append(row_format.format(*row))
+ return "\n".join(lines)
def ask_for_confirmation(message: str) -> bool:
diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index ad7a4a65..8b40effb 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -33,13 +33,13 @@ def create_causal_mask(
linds = mx.arange(offset, offset + N) if offset else rinds
linds = linds[:, None]
rinds = rinds[None]
- mask = linds < rinds
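+    # Boolean mask: True where the query position may attend to the key (key index <= query index)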
+ mask = linds >= rinds
if window_size is not None:
- mask = mask | (linds > rinds + window_size)
+ mask = mask & (linds <= rinds + window_size)
if lengths is not None:
lengths = lengths[:, None, None, None]
- mask = mask | (rinds >= lengths)
- return mask * -1e9
+ mask = mask & (rinds < lengths)
+ return mask
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
@@ -55,7 +55,6 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
else:
offset = c.offset
mask = create_causal_mask(T, offset, window_size=window_size)
- mask = mask.astype(h.dtype)
else:
mask = None
return mask
diff --git a/llms/mlx_lm/models/deepseek_v3.py b/llms/mlx_lm/models/deepseek_v3.py
index 47e17236..5cd40a0d 100644
--- a/llms/mlx_lm/models/deepseek_v3.py
+++ b/llms/mlx_lm/models/deepseek_v3.py
@@ -181,30 +181,37 @@ class DeepseekV3Attention(nn.Module):
bias=config.attention_bias,
)
- mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
- scaling_factor = self.config.rope_scaling["factor"]
- if mscale_all_dim:
- mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
- self.scale = self.scale * mscale * mscale
+ if self.config.rope_scaling is not None:
+ mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
+ scaling_factor = self.config.rope_scaling["factor"]
+ if mscale_all_dim:
+ mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
+ self.scale = self.scale * mscale * mscale
- rope_kwargs = {
- key: self.config.rope_scaling[key]
- for key in [
- "original_max_position_embeddings",
- "beta_fast",
- "beta_slow",
- "mscale",
- "mscale_all_dim",
- ]
- if key in self.config.rope_scaling
- }
- self.rope = DeepseekV3YarnRotaryEmbedding(
- dim=self.qk_rope_head_dim,
- max_position_embeddings=self.max_position_embeddings,
- scaling_factor=scaling_factor,
- base=self.rope_theta,
- **rope_kwargs,
- )
+ rope_kwargs = {
+ key: self.config.rope_scaling[key]
+ for key in [
+ "original_max_position_embeddings",
+ "beta_fast",
+ "beta_slow",
+ "mscale",
+ "mscale_all_dim",
+ ]
+ if key in self.config.rope_scaling
+ }
+ self.rope = DeepseekV3YarnRotaryEmbedding(
+ dim=self.qk_rope_head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.rope_theta,
+ **rope_kwargs,
+ )
+ else:
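+            # No rope_scaling in the config: fall back to standard RoPE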
+ self.rope = nn.RoPE(
+ dims=self.qk_rope_head_dim,
+ base=self.rope_theta,
+ traditional=True,
+ )
def __call__(
self,
@@ -487,8 +494,12 @@ class Model(nn.Module):
]
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
- # Remove multi-token prediction layer
- return {k: v for k, v in weights.items() if not k.startswith("model.layers.61")}
+ # Remove multi-token prediction layer and any unused precomputed rotary freqs
+ return {
+ k: v
+ for k, v in weights.items()
+ if not k.startswith("model.layers.61") and "rotary_emb.inv_freq" not in k
+ }
@property
def layers(self):
diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index 7b452ea4..117adf0f 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -196,9 +196,12 @@ class Model(nn.Module):
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
- return {
+ weights = {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
+ if self.args.tie_word_embeddings:
+ weights.pop("lm_head.weight", None)
+ return weights
@property
def layers(self):
diff --git a/llms/mlx_lm/models/olmoe.py b/llms/mlx_lm/models/olmoe.py
new file mode 100644
index 00000000..b9c0fc69
--- /dev/null
+++ b/llms/mlx_lm/models/olmoe.py
@@ -0,0 +1,217 @@
+# Copyright © 2023-2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .rope_utils import initialize_rope
+from .switch_layers import SwitchGLU
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+ model_type: str
+ hidden_size: int
+ num_hidden_layers: int
+ intermediate_size: int
+ num_attention_heads: int
+ rms_norm_eps: float
+ vocab_size: int
+ num_experts: int
+ num_experts_per_tok: int
+ norm_topk_prob: bool = False
+ head_dim: Optional[int] = None
+ max_position_embeddings: Optional[int] = None
+ num_key_value_heads: Optional[int] = None
+ attention_bias: bool = False
+ mlp_bias: bool = False
+ rope_theta: float = 10000
+ rope_traditional: bool = False
+ rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+ tie_word_embeddings: bool = True
+
+ def __post_init__(self):
+ if self.num_key_value_heads is None:
+ self.num_key_value_heads = self.num_attention_heads
+
+
+class Attention(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+
+ dim = args.hidden_size
+ self.n_heads = n_heads = args.num_attention_heads
+ self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+
+ self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
+
+ self.scale = head_dim**-0.5
+
+ self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
+ self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
+ self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
+ self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
+
+ self.rope = initialize_rope(
+ self.head_dim,
+ args.rope_theta,
+ args.rope_traditional,
+ args.rope_scaling,
+ args.max_position_embeddings,
+ )
+
+ self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps)
+ self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps)
+
+ def __call__(
+ self,
+ x: mx.array,
+ mask: Optional[mx.array] = None,
+ cache: Optional[Any] = None,
+ ) -> mx.array:
+ B, L, D = x.shape
+ queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+ queries = self.q_norm(queries)
+ keys = self.k_norm(keys)
+ queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+ keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+ values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+ if cache is not None:
+ queries = self.rope(queries, offset=cache.offset)
+ keys = self.rope(keys, offset=cache.offset)
+ keys, values = cache.update_and_fetch(keys, values)
+ else:
+ queries = self.rope(queries)
+ keys = self.rope(keys)
+ output = scaled_dot_product_attention(
+ queries, keys, values, cache=cache, scale=self.scale, mask=mask
+ )
+ output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+ return self.o_proj(output)
+
+
+class OlmoeSparseMoeBlock(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.num_experts = args.num_experts
+ self.top_k = args.num_experts_per_tok
+ self.norm_topk_prob = args.norm_topk_prob
+
+ self.gate = nn.Linear(args.hidden_size, self.num_experts, bias=False)
+ self.switch_mlp = SwitchGLU(
+ args.hidden_size,
+ args.intermediate_size,
+ self.num_experts,
+ bias=args.mlp_bias,
+ )
+
+ def __call__(self, x: mx.array) -> mx.array:
+ B, L, D = x.shape
+ x_flat = x.reshape(-1, D)
+ router_logits = self.gate(x_flat)
+ routing_weights = mx.softmax(router_logits, axis=1, precise=True)
+ k = self.top_k
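+        # Keep the top-k experts per token; argpartition avoids a full sort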
+ indices = mx.stop_gradient(
+ mx.argpartition(-routing_weights, kth=k - 1, axis=-1)[..., :k]
+ )
+ scores = mx.take_along_axis(routing_weights, indices, axis=-1)
+ if self.norm_topk_prob:
+ scores = scores / scores.sum(axis=-1, keepdims=True)
+ y = self.switch_mlp(x_flat, indices)
+ y = (y * scores[..., None]).sum(axis=-2)
+ return y.reshape(B, L, D)
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.self_attn = Attention(args)
+ self.mlp = OlmoeSparseMoeBlock(args)
+ self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+ self.post_attention_layernorm = nn.RMSNorm(
+ args.hidden_size, eps=args.rms_norm_eps
+ )
+
+ def __call__(
+ self,
+ x: mx.array,
+ mask: Optional[mx.array] = None,
+ cache: Optional[Any] = None,
+ ) -> mx.array:
+ x = x + self.self_attn(self.input_layernorm(x), mask, cache)
+ x = x + self.mlp(self.post_attention_layernorm(x))
+ return x
+
+
+class OlmoeModel(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.args = args
+ self.vocab_size = args.vocab_size
+ self.num_hidden_layers = args.num_hidden_layers
+ assert self.vocab_size > 0
+ self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+ self.layers = [
+ TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
+ ]
+ self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+ def __call__(
+ self,
+ inputs: mx.array,
+ cache=None,
+ mask=None,
+ ):
+ h = self.embed_tokens(inputs)
+ if mask is None:
+ mask = create_attention_mask(h, cache)
+ if cache is None:
+ cache = [None] * len(self.layers)
+ for layer, c in zip(self.layers, cache):
+ h = layer(h, mask, cache=c)
+ return self.norm(h)
+
+
+class Model(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ self.args = args
+ self.model_type = args.model_type
+ self.model = OlmoeModel(args)
+ if not args.tie_word_embeddings:
+ self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+ def __call__(
+ self,
+ inputs: mx.array,
+ cache=None,
+ mask=None,
+ ):
+ out = self.model(inputs, cache, mask)
+ if self.args.tie_word_embeddings:
+ out = self.model.embed_tokens.as_linear(out)
+ else:
+ out = self.lm_head(out)
+ return out
+
+ def sanitize(self, weights):
+ if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
+ return weights
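+        # Stack the per-expert MLP weights into the fused switch_mlp layout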
+ for l in range(self.args.num_hidden_layers):
+ prefix = f"model.layers.{l}"
+ for n in ["up_proj", "down_proj", "gate_proj"]:
+ for k in ["weight", "scales", "biases"]:
+ if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
+ to_join = [
+ weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
+ for e in range(self.args.num_experts)
+ ]
+ weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
+ return weights
+
+ @property
+ def layers(self):
+ return self.model.layers
diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py
index d1c21e25..63e985de 100644
--- a/llms/mlx_lm/models/phi3.py
+++ b/llms/mlx_lm/models/phi3.py
@@ -23,8 +23,10 @@ class ModelArgs(BaseModelArgs):
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
+ partial_rotary_factor: float = 1.0
max_position_embeddings: int = 131072
original_max_position_embeddings: int = 4096
+ tie_word_embeddings: bool = False
def __post_init__(self):
if self.num_key_value_heads is None:
@@ -59,9 +61,10 @@ class Attention(nn.Module):
self.qkv_proj = nn.Linear(dim, op_size, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
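+        # Apply rotary embeddings only to the first partial_rotary_factor fraction of each head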
+ rope_dim = int(head_dim * args.partial_rotary_factor)
if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
self.rope = SuScaledRotaryEmbedding(
- head_dim,
+ rope_dim,
base=args.rope_theta,
max_position_embeddings=args.max_position_embeddings,
original_max_position_embeddings=args.original_max_position_embeddings,
@@ -74,7 +77,7 @@ class Attention(nn.Module):
assert isinstance(args.rope_scaling["factor"], float)
rope_scale = 1 / args.rope_scaling["factor"]
self.rope = nn.RoPE(
- head_dim,
+ rope_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
@@ -190,7 +193,8 @@ class Model(nn.Module):
super().__init__()
self.model_type = args.model_type
self.model = Phi3Model(args)
- self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+ if not args.tie_word_embeddings:
+ self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
self.args = args
def __call__(
@@ -200,7 +204,11 @@ class Model(nn.Module):
cache=None,
):
out = self.model(inputs, mask, cache)
- return self.lm_head(out)
+ if self.args.tie_word_embeddings:
+ out = self.model.embed_tokens.as_linear(out)
+ else:
+ out = self.lm_head(out)
+ return out
@property
def layers(self):
diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py
new file mode 100644
index 00000000..657fa02e
--- /dev/null
+++ b/llms/mlx_lm/models/plamo2.py
@@ -0,0 +1,608 @@
+# Copyright © 2025 Apple Inc.
+
+import math
+from dataclasses import dataclass
+from typing import Any, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.base import BaseModelArgs, create_attention_mask
+
+from .cache import KVCache, MambaCache
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+ model_type: str = "plamo2"
+ hidden_size: int = 4096
+ num_hidden_layers: int = 32
+ rms_norm_eps: float = 1e-6
+ tie_word_embeddings: bool = True
+ num_attention_heads: int = 32
+ num_key_value_heads: int = 4
+ hidden_size_per_head: int = 128
+ max_position_embeddings: int = 2048
+ attention_window_size: int = 2048
+ full_attention_idx: Optional[list[int]] = None
+ mamba_d_state: int = 64
+ mamba_d_conv: int = 4
+ mamba_num_heads: int = 64
+ mamba_step: int = 2
+ mamba_chunk_size: int = 256
+ mamba_enabled: bool = True
+ intermediate_size: int = 13312
+ vocab_size: int = 32000
+
+
+class RMSNorm(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ eps: float = 1e-6,
+ offset: float = 1.0,
+ ) -> None:
+ super().__init__()
+ self.weight = mx.zeros(hidden_size)
+ self.variance_epsilon = eps
+ self.offset = offset
+
+ def __call__(self, hidden_states: mx.array) -> mx.array:
+ return mx.fast.rms_norm(
+ hidden_states, self.weight + self.offset, self.variance_epsilon
+ )
+
+
+def _rms_norm(hidden_states: mx.array, eps: float) -> mx.array:
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.astype(mx.float32)
+ variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
+ hidden_states = hidden_states * mx.rsqrt(variance + eps)
+ hidden_states = hidden_states.astype(input_dtype)
+
+ return hidden_states
+
+
+def get_initial_dt_bias(num_heads: int) -> mx.array:
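+    # Sample dt log-uniformly in [dt_min, dt_max], then return its inverse softplus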
+ dt_min = 0.001
+ dt_max = 0.1
+ dt = mx.exp(
+ mx.random.uniform(shape=(num_heads,)) * (math.log(dt_max) - math.log(dt_min))
+ + math.log(dt_min)
+ )
+ dt = mx.clip(dt, a_min=1e-4, a_max=None)
+ inv_dt = dt + mx.log(-mx.expm1(-dt))
+ return inv_dt
+
+
+def get_initial_A(num_heads: int) -> mx.array:
+ A = mx.arange(1, num_heads + 1, dtype=mx.float32)
+ return mx.log(A)
+
+
+# From: https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/ops/triton/selective_state_update.py#L219
+def selective_state_update_ref(
+ state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
+) -> tuple[mx.array, mx.array]:
+ """
+ Argument:
+ state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
+ x: (batch, dim) or (batch, nheads, dim)
+ dt: (batch, dim) or (batch, nheads, dim)
+ A: (dim, dstate) or (nheads, dim, dstate)
+ B: (batch, dstate) or (batch, ngroups, dstate)
+ C: (batch, dstate) or (batch, ngroups, dstate)
+ D: (dim,) or (nheads, dim)
+ z: (batch, dim) or (batch, nheads, dim)
+ dt_bias: (dim,) or (nheads, dim)
+ Return:
+ out: (batch, dim) or (batch, nheads, dim)
+ """
+ has_heads = state.ndim > 3
+ if state.ndim == 3:
+ state = mx.expand_dims(state, 1)
+ if x.ndim == 2:
+ x = mx.expand_dims(x, 1)
+ if dt.ndim == 2:
+ dt = mx.expand_dims(dt, 1)
+ if A.ndim == 2:
+ A = mx.expand_dims(A, 0)
+ if B.ndim == 2:
+ B = mx.expand_dims(B, 1)
+ if C.ndim == 2:
+ C = mx.expand_dims(C, 1)
+ if D is not None and D.ndim == 1:
+ D = mx.expand_dims(D, 0)
+ if z is not None and z.ndim == 2:
+ z = mx.expand_dims(z, 1)
+ if dt_bias is not None and dt_bias.ndim == 1:
+ dt_bias = mx.expand_dims(dt_bias, 0)
+ batch, nheads, dim, dstate = state.shape
+ assert x.shape == (batch, nheads, dim)
+ assert dt.shape == x.shape
+ assert A.shape == (nheads, dim, dstate)
+ ngroups = B.shape[1]
+ assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
+ assert B.shape == (batch, ngroups, dstate)
+ assert C.shape == B.shape
+ if D is not None:
+ assert D.shape == (nheads, dim)
+ if z is not None:
+ assert z.shape == x.shape
+ if dt_bias is not None:
+ assert dt_bias.shape == (nheads, dim)
+ dt = dt + dt_bias
+ dt = nn.softplus(dt) if dt_softplus else dt
+ dA = mx.exp(mx.expand_dims(dt, axis=-1) * A) # (batch, nheads, dim, dstate)
+ B = mx.reshape(
+ mx.repeat(mx.expand_dims(B, axis=2), nheads // ngroups, 2),
+ (batch, nheads, dstate),
+ ) # (batch, nheads, dstate)
+ C = mx.reshape(
+ mx.repeat(mx.expand_dims(C, axis=2), nheads // ngroups, 2),
+ (batch, nheads, dstate),
+ ) # (batch, nheads, dstate)
+ dB = mx.expand_dims(dt, axis=-1) * mx.expand_dims(
+ B, axis=-2
+ ) # (batch, nheads, dim, dstate)
+ state = state * dA + dB * mx.expand_dims(x, axis=-1) # (batch, dim, dstate)
+ out = mx.einsum("bhdn,bhn->bhd", state.astype(C.dtype), C)
+ if D is not None:
+ out += (x * D).astype(out.dtype)
+ out = (out if z is None else out * nn.silu(z)).astype(x.dtype)
+ if not has_heads:
+ out = out.squeeze(1)
+ return out, state
+
+
+def ssd_update_state(
+ ssm_state: mx.array,
+ x: mx.array,
+ dt: mx.array,
+ A: mx.array,
+ B: mx.array,
+ C: mx.array,
+ D: mx.array,
+ z: mx.array,
+ dt_bias: mx.array,
+ dt_softplus: bool,
+) -> tuple[mx.array, mx.array]:
+ assert ssm_state.dtype == mx.float32
+ dtype = x.dtype
+
+ hidden_size_per_head = x.shape[-1]
+ d_state = B.shape[-1]
+ A = mx.broadcast_to(
+ A[:, None, None], (A.shape[0], hidden_size_per_head, d_state)
+ ).astype(mx.float32)
+ dt = mx.broadcast_to(
+ dt[..., None], (dt.shape[0], dt.shape[1], hidden_size_per_head)
+ )
+ dt_bias = mx.broadcast_to(
+ dt_bias[:, None], (dt_bias.shape[0], hidden_size_per_head)
+ )
+ D = mx.broadcast_to(D[:, None], (D.shape[0], hidden_size_per_head))
+ out, ssm_state = selective_state_update_ref(
+ ssm_state,
+ x.astype(dtype),
+ dt.astype(dtype),
+ A.astype(mx.float32),
+ B.astype(dtype),
+ C.astype(dtype),
+ D.astype(mx.float32),
+ z.astype(dtype),
+ dt_bias.astype(mx.float32),
+ dt_softplus=dt_softplus,
+ )
+ return out[:, None], ssm_state
+
+
+def ssd_chunk_scan_combined(
+ x: mx.array,
+ dt: mx.array,
+ A: mx.array,
+ B: mx.array,
+ C: mx.array,
+ D: mx.array,
+ z: mx.array,
+ dt_bias: mx.array,
+ dt_softplus: bool,
+ ssm_state: mx.array,
+) -> tuple[mx.array, mx.array]:
+ assert ssm_state.dtype == mx.float32
+ length = x.shape[1]
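+    # Sequential scan: update the SSM state one timestep at a time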
+ ys = []
+ for i in range(length):
+ y, ssm_state = ssd_update_state(
+ ssm_state,
+ x[:, i],
+ dt[:, i],
+ A,
+ B[:, i],
+ C[:, i],
+ D if D.ndim == 1 else D[:, i],
+ z=z[:, i],
+ dt_bias=dt_bias,
+ dt_softplus=dt_softplus,
+ )
+ ys.append(y)
+ return mx.concatenate(ys, axis=1), ssm_state
+
+
+def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]:
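+    # Prepend the cached conv state, run the depthwise conv, and keep the trailing window as the new state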
+ _, seqlen, dim = x.shape
+ state_len = conv_state.shape[-2]
+ x = mx.concatenate([conv_state, x], axis=-2)
+ conv_state = x[:, -state_len:]
+ out = mx.conv1d(
+ x,
+ weight,
+ padding=0,
+ groups=dim,
+ )[:, -seqlen:]
+ return nn.silu(out), conv_state
+
+
+class Mamba(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.d_state = config.mamba_d_state
+ self.d_conv = config.mamba_d_conv
+ self.chunk_size = config.mamba_chunk_size
+ self.num_heads = config.mamba_num_heads
+ self.hidden_size_per_head = config.hidden_size_per_head
+
+ self.intermediate_size = self.num_heads * self.hidden_size_per_head
+
+ self.in_proj = nn.Linear(
+ self.hidden_size, 2 * self.intermediate_size, bias=False
+ )
+ self.conv1d = nn.Conv1d(
+ in_channels=self.intermediate_size,
+ out_channels=self.intermediate_size,
+ bias=False,
+ kernel_size=self.d_conv,
+ groups=self.intermediate_size,
+ padding=0,
+ )
+ self.dt_dim = max(64, self.hidden_size // 16)
+ self.bcdt_proj = nn.Linear(
+ self.intermediate_size,
+ self.dt_dim + 2 * self.d_state,
+ bias=False,
+ )
+ self.dt_proj = nn.Linear(self.dt_dim, self.num_heads, bias=False)
+
+ self.dt_bias = get_initial_dt_bias(self.num_heads)
+ self.A_log = get_initial_A(self.num_heads)
+ self.D = mx.ones(self.num_heads, dtype=mx.float32)
+
+ self.dt_norm_weight = mx.ones(self.dt_dim)
+ self.B_norm_weight = mx.ones(self.d_state)
+ self.C_norm_weight = mx.ones(self.d_state)
+
+ self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+ def __call__(
+ self,
+ hidden_states: mx.array,
+ mask: Optional[mx.array] = None,
+ cache=None,
+ ):
+ bsize, length, _ = hidden_states.shape
+
+ if cache is not None and cache[0] is not None:
+ conv_state = cache[0]
+ ssm_state = cache[1]
+ else:
+ conv_state = mx.zeros(
+ (bsize, self.d_conv - 1, self.intermediate_size),
+ dtype=hidden_states.dtype,
+ )
+ ssm_state = mx.zeros(
+ (bsize, self.num_heads, self.hidden_size_per_head, self.d_state),
+ dtype=mx.float32,
+ )
+
+ zx = self.in_proj(hidden_states)
+ zx = zx.reshape(bsize, length, self.num_heads, -1)
+ # z: (bsize, length, num_heads, hidden_size_per_head)
+ # x: (bsize, length, num_heads, hidden_size_per_head)
+ z, x = mx.split(
+ zx,
+ [
+ self.hidden_size_per_head,
+ ],
+ axis=-1,
+ )
+
+ x = x.reshape(bsize, -1, self.num_heads * self.hidden_size_per_head)
+ x, conv_state = causal_conv1d_update(conv_state, x, self.conv1d.weight)
+ BCdt = self.bcdt_proj(x)
+ x = x.reshape(bsize, length, self.num_heads, -1)
+ B, C, dt = mx.split(BCdt, [self.d_state, self.d_state * 2], axis=-1)
+
+ A = -mx.exp(self.A_log.astype(mx.float32)) # (num_heads,)
+ dt = mx.fast.rms_norm(dt, self.dt_norm_weight, self.config.rms_norm_eps)
+ B = mx.fast.rms_norm(B, self.B_norm_weight, self.config.rms_norm_eps)
+ C = mx.fast.rms_norm(C, self.C_norm_weight, self.config.rms_norm_eps)
+
+ # (bsize, length, num_heads, 1)
+ dt = self.dt_proj(dt)[..., None]
+
+ out, ssm_state = ssd_chunk_scan_combined(
+ x,
+ dt.reshape(bsize, length, -1),
+ A,
+ B,
+ C,
+ D=self.D,
+ z=z,
+ dt_bias=self.dt_bias,
+ dt_softplus=True,
+ ssm_state=ssm_state,
+ )
+
+ if cache is not None:
+ cache[0] = conv_state
+ cache[1] = ssm_state
+ y = self.out_proj(out.reshape(bsize, length, -1))
+
+ return y
+
+
+class Attention(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ head_dim = config.hidden_size_per_head
+ self.max_position_embeddings = config.max_position_embeddings
+ self.scale = head_dim**-0.5
+
+ self.q_num_heads = config.num_attention_heads
+ self.qk_dim = self.v_dim = head_dim
+ self.k_num_heads = self.v_num_heads = config.num_key_value_heads
+ assert self.q_num_heads % self.k_num_heads == 0
+ self.n_group = self.q_num_heads // self.k_num_heads
+
+ self.q_proj_dim = self.q_num_heads * self.qk_dim
+ self.k_proj_dim = self.k_num_heads * self.qk_dim
+ self.v_proj_dim = self.k_num_heads * self.v_dim
+ self.qkv_proj = nn.Linear(
+ self.hidden_size,
+ self.q_proj_dim + self.k_proj_dim + self.v_proj_dim,
+ bias=False,
+ )
+ self.o_proj = nn.Linear(
+ self.q_num_heads * self.v_dim, self.hidden_size, bias=False
+ )
+
+ self.q_weight = mx.ones((self.q_num_heads, self.qk_dim))
+ self.k_weight = mx.ones((self.k_num_heads, self.qk_dim))
+
+ self.rope = nn.RoPE(self.qk_dim)
+
+ def __call__(
+ self,
+ hidden_states: mx.array,
+ mask: Optional[mx.array] = None,
+ cache=None,
+ ):
+ B, T, _ = hidden_states.shape
+
+ qkv = self.qkv_proj(hidden_states)
+ q, k, v = mx.split(
+ qkv, [self.q_proj_dim, self.q_proj_dim + self.k_proj_dim], axis=-1
+ )
+ q = q.reshape(B, T, self.q_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
+ k = k.reshape(B, T, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
+ v = v.reshape(B, T, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3)
+
+ q = _rms_norm(q, 1e-6) * self.q_weight[:, None]
+ k = _rms_norm(k, 1e-6) * self.k_weight[:, None]
+
+ if cache is not None:
+ q = self.rope(q, offset=cache.offset)
+ k = self.rope(k, offset=cache.offset)
+ k, v = cache.update_and_fetch(k, v)
+ else:
+ q = self.rope(q)
+ k = self.rope(k)
+
+ output = mx.fast.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ scale=self.scale,
+ mask=mask,
+ )
+ output = output.transpose(0, 2, 1, 3).reshape(
+ B, T, self.q_num_heads * self.v_dim
+ )
+ return self.o_proj(output)
+
+
+class MLP(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_up_proj = nn.Linear(
+ self.hidden_size, self.intermediate_size * 2, bias=False
+ )
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+ def __call__(self, x: mx.array) -> mx.array:
+ h = self.gate_up_proj(x)
+ hs = mx.split(h, 2, axis=-1)
+ return self.down_proj(nn.silu(hs[0]) * hs[1])
+
+
+class PlamoDecoderLayer(nn.Module):
+ def __init__(self, config: ModelArgs, is_mamba: bool) -> None:
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.is_mamba = is_mamba
+ self.mixer: nn.Module
+ if is_mamba:
+ self.mixer = Mamba(config)
+ else:
+ self.mixer = Attention(config)
+ self.mlp = MLP(config)
+ self.pre_mixer_norm = RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, offset=1.0
+ )
+ self.post_mixer_norm = RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5
+ )
+ self.pre_mlp_norm = RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, offset=1.0
+ )
+ self.post_mlp_norm = RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5)
+ )
+
+ def __call__(
+ self,
+ hidden_states: mx.array,
+ mask: Optional[mx.array] = None,
+ cache=None,
+ ):
+ residual = hidden_states
+ hidden_states = self.pre_mixer_norm(hidden_states)
+
+ hidden_states_sa = self.mixer(
+ hidden_states=hidden_states,
+ mask=mask,
+ cache=cache,
+ )
+
+ hidden_states_sa = self.post_mixer_norm(hidden_states_sa)
+ hidden_states = residual + hidden_states_sa
+
+ residual = hidden_states
+ hidden_states = self.pre_mlp_norm(hidden_states)
+
+ # Fully Connected
+ hidden_states_mlp = self.mlp(hidden_states)
+
+ # Residual
+ hidden_states_mlp = self.post_mlp_norm(hidden_states_mlp)
+ return residual + hidden_states_mlp
+
+
+def is_mamba(config: ModelArgs, i: int) -> bool:
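+    # Attention at layers where i % mamba_step == mamba_step // 2 (or only the last layer for very shallow models); Mamba elsewhere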
+ if not config.mamba_enabled:
+ return False
+ assert config.mamba_step > 1
+ assert i < config.num_hidden_layers
+
+ if config.num_hidden_layers <= (config.mamba_step // 2):
+ # use attention in last layer
+ return i != config.num_hidden_layers - 1
+ return (i % config.mamba_step) != (config.mamba_step // 2)
+
+
+class PlamoDecoder(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+
+ self.layers = [
+ PlamoDecoderLayer(config, is_mamba=is_mamba(config, i))
+ for i in range(config.num_hidden_layers)
+ ]
+
+ def __call__(self, x: mx.array, mask: mx.array, cache):
+ for i, decoder_layer in enumerate(self.layers):
+ x = decoder_layer(
+ x,
+ mask=mask,
+ cache=cache[i],
+ )
+ return x
+
+
+class PlamoModel(nn.Module):
+ def __init__(self, config: ModelArgs):
+ super().__init__()
+
+ self.config = config
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+ self.layers = PlamoDecoder(config) # type: ignore
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def __call__(
+ self,
+ inputs: mx.array,
+ mask: Optional[mx.array] = None,
+ cache=None,
+ ):
+ batch_size, seq_length = inputs.shape
+
+ h = self.embed_tokens(inputs)
+
+ if mask is None:
+ mask = create_attention_mask(h, [cache[1]] if cache is not None else None)
+
+ if cache is None:
+ cache = [None] * len(self.layers.layers)
+
+ # decoder layers
+ out = self.layers(
+ h,
+ mask,
+ cache,
+ )
+
+ return self.norm(out)
+
+
+class Model(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.config = config
+ self.model_type = config.model_type
+ self.model = PlamoModel(config)
+
+ self.vocab_size = config.vocab_size
+
+ if not config.tie_word_embeddings:
+ self.lm_head: nn.Module = nn.Linear(
+ config.hidden_size, self.vocab_size, bias=False
+ )
+
+ def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]:
+ for k, v in weights.items():
+ if "conv1d.weight" in k and v.shape[-1] != 1:
+ weights[k] = v.moveaxis(2, 1)
+ return weights
+
+ def make_cache(self):
+        # TODO: use RotatingKVCache if not full_attn
+ # full_attn = self.layer_idx in self.config.full_attention_idx
+ return [MambaCache() if l.is_mamba else KVCache() for l in self.layers]
+
+ def __call__(
+ self, inputs: mx.array, mask: Optional[mx.array] = None, cache=None
+ ) -> mx.array:
+ outputs = self.model(
+ inputs=inputs,
+ mask=None,
+ cache=cache,
+ )
+ if self.config.tie_word_embeddings:
+ logits = self.model.embed_tokens.as_linear(outputs)
+ else:
+ logits = self.lm_head(outputs)
+
+ return logits
+
+ @property
+ def layers(self):
+ return self.model.layers.layers
diff --git a/llms/mlx_lm/models/su_rope.py b/llms/mlx_lm/models/su_rope.py
index 9c414afd..6340c77b 100644
--- a/llms/mlx_lm/models/su_rope.py
+++ b/llms/mlx_lm/models/su_rope.py
@@ -51,11 +51,13 @@ class SuScaledRotaryEmbedding(nn.Module):
+ math.log(max_position_embeddings / original_max_position_embeddings)
/ math.log(original_max_position_embeddings)
)
+ self.dim = dims
def __call__(self, x, offset: int = 0):
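+        # Scale and rotate only the first self.dim features; the rest pass through unchanged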
+ x[..., : self.dim] = self.scale * x[..., : self.dim]
return mx.fast.rope(
- self.scale * x,
- x.shape[-1],
+ x,
+ self.dim,
traditional=False,
base=None,
scale=1.0,
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index f5df11e3..cc7c6c20 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -98,6 +98,7 @@ def linear_to_lora_layers(
"minicpm",
"deepseek",
"olmo2",
+ "olmoe",
"internlm3",
]:
keys = set(["self_attn.q_proj", "self_attn.v_proj"])
@@ -106,6 +107,8 @@ def linear_to_lora_layers(
if model.model_type == "qwen2_moe":
keys.add("mlp.gate")
keys.add("mlp.shared_expert_gate")
+ if model.model_type == "olmoe":
+ keys.add("mlp.gate")
elif model.model_type == "gpt_bigcode":
keys = set(["attn.c_attn"])
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 1fae76fa..05fac92f 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -191,7 +191,9 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
"*.py",
"tokenizer.model",
"*.tiktoken",
+ "tiktoken.model",
"*.txt",
+ "*.jsonl",
],
)
)
@@ -1014,6 +1016,46 @@ def save_config(
json.dump(config, fid, indent=4)
+def mixed_quant_predicate_builder(
+ low_bits: int = 4, high_bits: int = 4, group_size: int = 64
+) -> Callable[[str, nn.Module, dict], Union[bool, dict]]:
+ def mixed_quant_predicate(
+ path: str,
+ module: nn.Module,
+ config: dict,
+ ) -> Union[bool, dict]:
+ """Implements mixed quantization predicates with similar choices to, for example, llama.cpp's Q4_K_M.
+ Ref: https://github.com/ggerganov/llama.cpp/blob/917786f43d0f29b7c77a0c56767c0fa4df68b1c5/src/llama.cpp#L5265
+ By Alex Barron: https://gist.github.com/barronalex/84addb8078be21969f1690c1454855f3
+ """
+
+ if not hasattr(module, "to_quantized"):
+ return False
+
+ index = int(path.split(".")[2]) if len(path.split(".")) > 2 else 0
+
+ num_layers = config["num_hidden_layers"]
+ use_more_bits = (
+ index < num_layers // 8
+ or index >= 7 * num_layers // 8
+ or (index - num_layers // 8) % 3 == 2
+ )
+ if "v_proj" in path and use_more_bits:
+ return {"group_size": group_size, "bits": high_bits}
+ if "down_proj" in path and use_more_bits:
+ return {"group_size": group_size, "bits": high_bits}
+ if "lm_head" in path:
+ return {"group_size": group_size, "bits": high_bits}
+
+ return {"group_size": group_size, "bits": low_bits}
+
+ return mixed_quant_predicate
+
+
+mixed_3_6 = mixed_quant_predicate_builder(low_bits=3)
+mixed_2_6 = mixed_quant_predicate_builder(low_bits=2)
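+# Example (illustrative): `mlx_lm.convert --hf-path <repo> -q --quant-predicate mixed_3_6`;
+# convert.py maps the recipe name to one of the callables above via getattr(utils, name).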
+
+
def convert(
hf_path: str,
mlx_path: str = "mlx_model",
diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py
index d8cf6820..0c0fc601 100644
--- a/llms/tests/test_models.py
+++ b/llms/tests/test_models.py
@@ -183,7 +183,7 @@ class TestModels(unittest.TestCase):
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
- if model_type != "mamba":
+ if model_type not in ("mamba", "plamo2"):
mask = create_causal_mask(inputs.shape[1], 0).astype(t)
outputs = model(inputs, mask=mask)
self.assertEqual(outputs.shape, (1, 2, vocab_size))
@@ -372,6 +372,23 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
+ def test_plamo2(self):
+ from mlx_lm.models import plamo2
+
+ args = plamo2.ModelArgs(
+ model_type="plamo2",
+ hidden_size=1024,
+ num_hidden_layers=4,
+ intermediate_size=2048,
+ num_attention_heads=8,
+ rms_norm_eps=1e-5,
+ vocab_size=10_000,
+ )
+ model = plamo2.Model(args)
+ self.model_test_runner(
+ model, args.model_type, args.vocab_size, args.num_hidden_layers
+ )
+
def test_stablelm(self):
from mlx_lm.models import stablelm
diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py
index de5694d5..c1860892 100644
--- a/llms/tests/test_prompt_cache.py
+++ b/llms/tests/test_prompt_cache.py
@@ -298,7 +298,7 @@ class TestPromptCache(unittest.TestCase):
):
i += 1
self.assertEqual(tok, toks[i])
- self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2))
+ self.assertTrue(mx.allclose(logits, all_logits[i], rtol=3e-2))
if __name__ == "__main__":
diff --git a/lora/lora.py b/lora/lora.py
index 723e783d..6f91ccca 100644
--- a/lora/lora.py
+++ b/lora/lora.py
@@ -3,6 +3,7 @@
import argparse
import json
import math
+import sys
import time
from pathlib import Path
@@ -14,6 +15,9 @@ import utils as lora_utils
from mlx.utils import tree_flatten
from models import LoRALinear
+# Disable output buffering to see print statements in real-time
+sys.stdout.reconfigure(line_buffering=True)
+
def build_parser():
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")