Merge branch 'main' into adding-support-for-mamba2

commit 64a0b0cddb
Gökdeniz Gülmez, 2025-03-10 19:41:27 +01:00, committed by GitHub
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
24 changed files with 626 additions and 106 deletions

View File

@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
 - Markus Enzweiler: Added the `cvae` examples.
 - Prince Canuma: Helped add support for `Starcoder2` models.
 - Shiyu Li: Added the `Segment Anything Model`.
-- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1`, `Mamba version 2` and support for `full-fine-tuning`.
+- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1`, `Mamba version 2`, `OLMoE` architectures and support for `full-fine-tuning`.

View File

@@ -48,3 +48,17 @@ Note this was run on an M1 Macbook Pro with 16GB RAM.
 At the time of writing, `mlx` doesn't have built-in learning rate schedules.
 We intend to update this example once these features are added.
+
+## Distributed training
+
+The example also supports distributed data parallel training. You can launch a
+distributed training as follows:
+
+```shell
+$ cat >hostfile.json
+[
+    {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]},
+    {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]}
+]
+$ mlx.launch --verbose --hostfile hostfile.json main.py --batch 256 --epochs 5 --arch resnet20
+```

View File

@@ -1,3 +1,4 @@
+import mlx.core as mx
 import numpy as np
 from mlx.data.datasets import load_cifar10
@@ -12,8 +13,11 @@ def get_cifar10(batch_size, root=None):
         x = x.astype("float32") / 255.0
         return (x - mean) / std

+    group = mx.distributed.init()
+
     tr_iter = (
         tr.shuffle()
+        .partition_if(group.size() > 1, group.size(), group.rank())
         .to_stream()
         .image_random_h_flip("image", prob=0.5)
         .pad("image", 0, 4, 4, 0.0)
@@ -25,6 +29,11 @@ def get_cifar10(batch_size, root=None):
     )

     test = load_cifar10(root=root, train=False)
-    test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size)
+    test_iter = (
+        test.to_stream()
+        .partition_if(group.size() > 1, group.size(), group.rank())
+        .key_transform("image", normalize)
+        .batch(batch_size)
+    )

     return tr_iter, test_iter

View File

@@ -23,6 +23,13 @@ parser.add_argument("--seed", type=int, default=0, help="random seed")
 parser.add_argument("--cpu", action="store_true", help="use cpu only")


+def print_zero(group, *args, **kwargs):
+    if group.rank() != 0:
+        return
+    flush = kwargs.pop("flush", True)
+    print(*args, **kwargs, flush=flush)
+
+
 def eval_fn(model, inp, tgt):
     return mx.mean(mx.argmax(model(inp), axis=1) == tgt)
@@ -34,9 +41,20 @@ def train_epoch(model, train_iter, optimizer, epoch):
         acc = mx.mean(mx.argmax(output, axis=1) == tgt)
         return loss, acc

-    losses = []
-    accs = []
-    samples_per_sec = []
+    world = mx.distributed.init()
+    losses = 0
+    accuracies = 0
+    samples_per_sec = 0
+    count = 0
+
+    def average_stats(stats, count):
+        if world.size() == 1:
+            return [s / count for s in stats]
+        with mx.stream(mx.cpu):
+            stats = mx.distributed.all_sum(mx.array(stats))
+            count = mx.distributed.all_sum(count)
+        return (stats / count).tolist()

     state = [model.state, optimizer.state]
@@ -44,6 +62,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
     def step(inp, tgt):
         train_step_fn = nn.value_and_grad(model, train_step)
         (loss, acc), grads = train_step_fn(model, inp, tgt)
+        grads = nn.utils.average_gradients(grads)
         optimizer.update(model, grads)
         return loss, acc
@@ -52,69 +71,79 @@ def train_epoch(model, train_iter, optimizer, epoch):
         y = mx.array(batch["label"])
         tic = time.perf_counter()
         loss, acc = step(x, y)
-        mx.eval(state)
+        mx.eval(loss, acc, state)
         toc = time.perf_counter()
-        loss = loss.item()
-        acc = acc.item()
-        losses.append(loss)
-        accs.append(acc)
-        throughput = x.shape[0] / (toc - tic)
-        samples_per_sec.append(throughput)
+        losses += loss.item()
+        accuracies += acc.item()
+        samples_per_sec += x.shape[0] / (toc - tic)
+        count += 1
         if batch_counter % 10 == 0:
-            print(
+            l, a, s = average_stats(
+                [losses, accuracies, world.size() * samples_per_sec],
+                count,
+            )
+            print_zero(
+                world,
                 " | ".join(
                     (
                         f"Epoch {epoch:02d} [{batch_counter:03d}]",
-                        f"Train loss {loss:.3f}",
-                        f"Train acc {acc:.3f}",
-                        f"Throughput: {throughput:.2f} images/second",
+                        f"Train loss {l:.3f}",
+                        f"Train acc {a:.3f}",
+                        f"Throughput: {s:.2f} images/second",
                     )
-                )
+                ),
             )

-    mean_tr_loss = mx.mean(mx.array(losses))
-    mean_tr_acc = mx.mean(mx.array(accs))
-    samples_per_sec = mx.mean(mx.array(samples_per_sec))
-    return mean_tr_loss, mean_tr_acc, samples_per_sec
+    return average_stats([losses, accuracies, world.size() * samples_per_sec], count)


 def test_epoch(model, test_iter, epoch):
-    accs = []
+    accuracies = 0
+    count = 0
     for batch_counter, batch in enumerate(test_iter):
         x = mx.array(batch["image"])
         y = mx.array(batch["label"])
         acc = eval_fn(model, x, y)
-        acc_value = acc.item()
-        accs.append(acc_value)
-    mean_acc = mx.mean(mx.array(accs))
-    return mean_acc
+        accuracies += acc.item()
+        count += 1
+
+    with mx.stream(mx.cpu):
+        accuracies = mx.distributed.all_sum(accuracies)
+        count = mx.distributed.all_sum(count)
+    return (accuracies / count).item()


 def main(args):
     mx.random.seed(args.seed)

+    # Initialize the distributed group and report the nodes that showed up
+    world = mx.distributed.init()
+    if world.size() > 1:
+        print(f"Starting rank {world.rank()} of {world.size()}", flush=True)
+
     model = getattr(resnet, args.arch)()
-    print("Number of params: {:0.04f} M".format(model.num_params() / 1e6))
+    print_zero(world, f"Number of params: {model.num_params() / 1e6:0.04f} M")

     optimizer = optim.Adam(learning_rate=args.lr)

     train_data, test_data = get_cifar10(args.batch_size)
     for epoch in range(args.epochs):
         tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
-        print(
+        print_zero(
+            world,
             " | ".join(
                 (
                     f"Epoch: {epoch}",
-                    f"avg. Train loss {tr_loss.item():.3f}",
-                    f"avg. Train acc {tr_acc.item():.3f}",
-                    f"Throughput: {throughput.item():.2f} images/sec",
+                    f"avg. Train loss {tr_loss:.3f}",
+                    f"avg. Train acc {tr_acc:.3f}",
+                    f"Throughput: {throughput:.2f} images/sec",
                 )
-            )
+            ),
         )

         test_acc = test_epoch(model, test_data, epoch)
-        print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}")
+        print_zero(world, f"Epoch: {epoch} | Test acc {test_acc:.3f}")

         train_data.reset()
         test_data.reset()

View File

@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.

-__version__ = "0.21.5"
+__version__ = "0.21.6"

View File

@@ -11,7 +11,7 @@ from .utils import load, stream_generate

 DEFAULT_TEMP = 0.0
 DEFAULT_TOP_P = 1.0
-DEFAULT_SEED = 0
+DEFAULT_SEED = None
 DEFAULT_MAX_TOKENS = 256
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
@@ -36,7 +36,12 @@ def setup_arg_parser():
     parser.add_argument(
         "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
     )
-    parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=DEFAULT_SEED,
+        help="PRNG seed",
+    )
     parser.add_argument(
         "--max-kv-size",
         type=int,
@@ -57,6 +62,7 @@ def main():
     parser = setup_arg_parser()
     args = parser.parse_args()
-    mx.random.seed(args.seed)
+    if args.seed is not None:
+        mx.random.seed(args.seed)

     model, tokenizer = load(
@@ -65,12 +71,25 @@ def main():
         tokenizer_config={"trust_remote_code": True},
     )

-    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
+    def print_help():
+        print("The command list:")
+        print("- 'q' to exit")
+        print("- 'r' to reset the chat")
+        print("- 'h' to display these commands")
+
+    print(f"[INFO] Starting chat session with {args.model}.")
+    print_help()
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
         query = input(">> ")
         if query == "q":
             break
+        if query == "r":
+            prompt_cache = make_prompt_cache(model, args.max_kv_size)
+            continue
+        if query == "h":
+            print_help()
+            continue
         messages = [{"role": "user", "content": query}]
         prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
         for response in stream_generate(

View File

@@ -2,8 +2,23 @@

 import argparse

+from . import utils
 from .utils import convert

+QUANT_RECIPES = [
+    "mixed_2_6",
+    "mixed_3_6",
+]
+
+
+def quant_args(arg):
+    if arg not in QUANT_RECIPES:
+        raise argparse.ArgumentTypeError(
+            f"Invalid q-recipe {arg!r}. Choose from: {QUANT_RECIPES}"
+        )
+    else:
+        return getattr(utils, arg)
+

 def configure_parser() -> argparse.ArgumentParser:
     """
@@ -29,6 +44,12 @@ def configure_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--q-bits", help="Bits per weight for quantization.", type=int, default=4
     )
+    parser.add_argument(
+        "--quant-predicate",
+        help=f"Mixed-bit quantization recipe. Choices: {QUANT_RECIPES}",
+        type=quant_args,
+        required=False,
+    )
     parser.add_argument(
        "--dtype",
        help="Type to save the non-quantized parameters.",

View File

@@ -289,17 +289,15 @@ class MLXLM(LM):
         contexts, options = zip(*[req.args for req in requests])
         # contrary to the doc the second element of the tuple contains
         # {'do_sample': False, 'until': ['\n\n'], 'temperature': 0}
-        keys = list(options[0].keys())
-        assert "until" in keys
-        untils = [x["until"] for x in options]
         completions = []

-        for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
+        for context, opt in tqdm(zip(contexts, options), total=len(contexts)):
+            until = opt["until"]
             context = self.tokenizer.encode(
                 context, add_special_tokens=not self.use_chat_template
             )
             max_tokens = min(
-                self._max_tokens,
+                opt.get("max_gen_tokens", self._max_tokens),
                 self.tokenizer.model_max_length - len(context),
             )
             text = ""
@@ -334,9 +332,9 @@ def main():
     )
     parser.add_argument(
         "--limit",
-        default=1.0,
+        default=100,
         help="Limit the number of examples per task.",
-        type=float,
+        type=int,
     )
     parser.add_argument("--seed", type=int, default=123, help="Random seed.")
     parser.add_argument(

View File

@@ -7,6 +7,15 @@ train: true
 # The fine-tuning method: "lora", "dora", or "full".
 fine_tune_type: lora

+# The Optimizer with its possible inputs
+optimizer: adamw
+# optimizer_config:
+#   adamw:
+#     betas: [0.9, 0.98]
+#     eps: 1e-6
+#     weight_decay: 0.05
+#     bias_correction: true
+
 # Directory with {train, valid, test}.jsonl files
 data: "/path/to/training/data"

View File

@@ -0,0 +1,73 @@
+# Copyright © 2025 Apple Inc.
+
+import json
+
+from mlx_lm import generate, load
+from mlx_lm.models.cache import make_prompt_cache
+
+# Specify the checkpoint
+checkpoint = "mlx-community/Qwen2.5-32B-Instruct-4bit"
+
+# Load the corresponding model and tokenizer
+model, tokenizer = load(path_or_hf_repo=checkpoint)
+
+
+# An example tool, make sure to include a docstring and type hints
+def multiply(a: float, b: float):
+    """
+    A function that multiplies two numbers
+
+    Args:
+        a: The first number to multiply
+        b: The second number to multiply
+    """
+    return a * b
+
+
+tools = {"multiply": multiply}
+
+# Specify the prompt and conversation history
+prompt = "Multiply 12234585 and 48838483920."
+messages = [{"role": "user", "content": prompt}]
+
+prompt = tokenizer.apply_chat_template(
+    messages, add_generation_prompt=True, tools=list(tools.values())
+)
+
+prompt_cache = make_prompt_cache(model)
+
+# Generate the initial tool call:
+response = generate(
+    model=model,
+    tokenizer=tokenizer,
+    prompt=prompt,
+    max_tokens=2048,
+    verbose=True,
+    prompt_cache=prompt_cache,
+)
+
+# Parse the tool call:
+# (Note, the tool call format is model specific)
+tool_open = "<tool_call>"
+tool_close = "</tool_call>"
+start_tool = response.find(tool_open) + len(tool_open)
+end_tool = response.find(tool_close)
+tool_call = json.loads(response[start_tool:end_tool].strip())
+tool_result = tools[tool_call["name"]](**tool_call["arguments"])
+
+# Put the tool result in the prompt
+messages = [{"role": "tool", "name": tool_call["name"], "content": tool_result}]
+prompt = tokenizer.apply_chat_template(
+    messages,
+    add_generation_prompt=True,
+)
+
+# Generate the final response:
+response = generate(
+    model=model,
+    tokenizer=tokenizer,
+    prompt=prompt,
+    max_tokens=2048,
+    verbose=True,
+    prompt_cache=prompt_cache,
+)
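
For reference, the `<tool_call>` payload that the parsing step above expects is a small JSON object naming the tool and its keyword arguments. An illustrative (not model-generated) example of that round trip:

```python
import json

# Illustrative payload only; a real model emits this between <tool_call> tags.
raw = '{"name": "multiply", "arguments": {"a": 12234585, "b": 48838483920}}'
tool_call = json.loads(raw)
assert tool_call["name"] == "multiply"
print(tool_call["arguments"])  # {'a': 12234585, 'b': 48838483920}
```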

View File

@@ -16,7 +16,7 @@ DEFAULT_TEMP = 0.0
 DEFAULT_TOP_P = 1.0
 DEFAULT_MIN_P = 0.0
 DEFAULT_MIN_TOKENS_TO_KEEP = 1
-DEFAULT_SEED = 0
+DEFAULT_SEED = None
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 DEFAULT_QUANTIZED_KV_START = 5000
@@ -60,6 +60,11 @@ def setup_arg_parser():
         default=DEFAULT_PROMPT,
         help="Message to be processed by the model ('-' reads from stdin)",
     )
+    parser.add_argument(
+        "--prefill-response",
+        default=None,
+        help="Prefill response to be used for the chat template",
+    )
     parser.add_argument(
         "--max-tokens",
         "-m",
@@ -82,7 +87,12 @@ def setup_arg_parser():
         default=DEFAULT_MIN_TOKENS_TO_KEEP,
         help="Minimum tokens to keep for min-p sampling.",
     )
-    parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=DEFAULT_SEED,
+        help="PRNG seed",
+    )
     parser.add_argument(
         "--ignore-chat-template",
         action="store_true",
@@ -147,7 +157,7 @@ def setup_arg_parser():
         "--num-draft-tokens",
         type=int,
         help="Number of tokens to draft when using speculative decoding.",
-        default=2,
+        default=3,
     )
     return parser
@@ -155,6 +165,8 @@ def setup_arg_parser():
 def main():
     parser = setup_arg_parser()
     args = parser.parse_args()
-    mx.random.seed(args.seed)
+    if args.seed is not None:
+        mx.random.seed(args.seed)

     # Load the prompt cache and metadata if a cache file is provided
@@ -219,10 +231,14 @@ def main():
         messages = []
         messages.append({"role": "user", "content": prompt})
+
+        has_prefill = args.prefill_response is not None
+        if has_prefill:
+            messages.append({"role": "assistant", "content": args.prefill_response})
         prompt = tokenizer.apply_chat_template(
             messages,
             tokenize=False,
-            add_generation_prompt=True,
+            continue_final_message=has_prefill,
+            add_generation_prompt=not has_prefill,
             **template_kwargs,
         )
@@ -233,7 +249,8 @@ def main():
             test_prompt = tokenizer.apply_chat_template(
                 messages,
                 tokenize=False,
-                add_generation_prompt=True,
+                continue_final_message=has_prefill,
+                add_generation_prompt=not has_prefill,
             )
             prompt = prompt[test_prompt.index("<query>") :]
         prompt = tokenizer.encode(prompt, add_special_tokens=False)
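
The new `--prefill-response` flag appends a partial assistant turn and asks the chat template to continue it instead of opening a fresh turn. A minimal sketch of the same pattern outside the CLI (assuming a Hugging Face tokenizer recent enough to support `continue_final_message`; the model repo and prefill text are illustrative):

```python
from mlx_lm import generate, load

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

messages = [
    {"role": "user", "content": "List three prime numbers."},
    # Partial assistant message the model should continue rather than restart.
    {"role": "assistant", "content": "Sure, here are three primes:"},
]
prompt = tokenizer.apply_chat_template(
    messages,
    continue_final_message=True,   # continue the prefilled assistant turn
    add_generation_prompt=False,   # so do not open a new one
)
print(generate(model, tokenizer, prompt=prompt, max_tokens=64))
```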

View File

@@ -43,6 +43,11 @@ CONFIG_DEFAULTS = {
     "model": "mlx_model",
     "train": False,
     "fine_tune_type": "lora",
+    "optimizer": "adam",
+    "optimizer_config": {
+        "adam": {},
+        "adamw": {},
+    },
     "data": "data/",
     "seed": 0,
     "num_layers": 16,
@@ -62,6 +67,7 @@ CONFIG_DEFAULTS = {
     "grad_checkpoint": False,
     "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
+    "mask_prompt": False,
 }
@@ -94,14 +100,19 @@ def build_parser():
         choices=["lora", "dora", "full"],
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
+    parser.add_argument(
+        "--optimizer",
+        type=str,
+        choices=["adam", "adamw"],
+        default=None,
+        help="Optimizer to use for training: adam or adamw",
+    )
     parser.add_argument(
         "--mask-prompt",
         action="store_true",
         help="Mask the prompt in the loss when training",
-        default=False,
+        default=None,
     )
     parser.add_argument(
         "--num-layers",
         type=int,
@@ -228,11 +239,21 @@ def train_model(
     )
     model.train()

-    opt = optim.Adam(
-        learning_rate=(
-            build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
-        )
-    )
+    # Initialize the selected optimizer
+    lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
+
+    optimizer_name = args.optimizer.lower()
+    optimizer_config = args.optimizer_config.get(optimizer_name, {})
+
+    if optimizer_name == "adam":
+        opt_class = optim.Adam
+    elif optimizer_name == "adamw":
+        opt_class = optim.AdamW
+    else:
+        raise ValueError(f"Unsupported optimizer: {optimizer_name}")
+
+    opt = opt_class(learning_rate=lr, **optimizer_config)

     # Train model
     train(
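
The `optimizer_config` table from the example YAML earlier in this commit is forwarded as keyword arguments to the selected optimizer class. A rough sketch of the equivalent direct construction (the learning rate is illustrative, and whether your installed `mlx` accepts every field, e.g. `bias_correction`, may depend on the version):

```python
import mlx.optimizers as optim

# Mirrors: optimizer: adamw  +  the commented optimizer_config.adamw block in the YAML example.
opt = optim.AdamW(
    learning_rate=1e-5,   # illustrative; normally args.learning_rate or a built schedule
    betas=[0.9, 0.98],
    eps=1e-6,
    weight_decay=0.05,
    bias_correction=True,
)
```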

View File

@@ -2,7 +2,22 @@ import argparse
 from typing import List, Union

 from huggingface_hub import scan_cache_dir
-from transformers.commands.user import tabulate
+
+
+def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
+    """
+    Inspired by:
+    - stackoverflow.com/a/8356620/593036
+    - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
+    """
+    col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
+    row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
+    lines = []
+    lines.append(row_format.format(*headers))
+    lines.append(row_format.format(*["-" * w for w in col_widths]))
+    for row in rows:
+        lines.append(row_format.format(*row))
+    return "\n".join(lines)


 def ask_for_confirmation(message: str) -> bool:

View File

@@ -33,13 +33,13 @@ def create_causal_mask(
     linds = mx.arange(offset, offset + N) if offset else rinds
     linds = linds[:, None]
     rinds = rinds[None]
-    mask = linds < rinds
+    mask = linds >= rinds
     if window_size is not None:
-        mask = mask | (linds > rinds + window_size)
+        mask = mask & (linds <= rinds + window_size)
     if lengths is not None:
         lengths = lengths[:, None, None, None]
-        mask = mask | (rinds >= lengths)
-    return mask * -1e9
+        mask = mask & (rinds < lengths)
+    return mask


 def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
@@ -55,7 +55,6 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
         else:
             offset = c.offset
         mask = create_causal_mask(T, offset, window_size=window_size)
-        mask = mask.astype(h.dtype)
     else:
         mask = None
     return mask
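
The causal mask is now boolean rather than additive: `True` marks a key a query may attend to, and the window and length constraints intersect with `&` instead of stacking `-1e9` penalties. A small standalone sketch of the resulting pattern (mirroring the logic above, not calling the library function):

```python
import mlx.core as mx

N = 4
rinds = mx.arange(N)[None]      # key positions, shape (1, N)
linds = mx.arange(N)[:, None]   # query positions, shape (N, 1)

mask = linds >= rinds           # lower-triangular booleans: True = may attend
print(mask)

# With a sliding window, a query only keeps keys at most `window` positions behind it.
window = 2
print(mask & (linds <= rinds + window))
```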

View File

@@ -181,6 +181,7 @@ class DeepseekV3Attention(nn.Module):
             bias=config.attention_bias,
         )

+        if self.config.rope_scaling is not None:
             mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
             scaling_factor = self.config.rope_scaling["factor"]
             if mscale_all_dim:
@@ -205,6 +206,12 @@ class DeepseekV3Attention(nn.Module):
                 base=self.rope_theta,
                 **rope_kwargs,
             )
+        else:
+            self.rope = nn.RoPE(
+                dims=self.qk_rope_head_dim,
+                base=self.rope_theta,
+                traditional=True,
+            )

     def __call__(
         self,
@@ -487,8 +494,12 @@ class Model(nn.Module):
                 ]
                 weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)

-        # Remove multi-token prediction layer
-        return {k: v for k, v in weights.items() if not k.startswith("model.layers.61")}
+        # Remove multi-token prediction layer and any unused precomputed rotary freqs
+        return {
+            k: v
+            for k, v in weights.items()
+            if not k.startswith("model.layers.61") and "rotary_emb.inv_freq" not in k
+        }

     @property
     def layers(self):

View File

@@ -196,9 +196,12 @@ class Model(nn.Module):
     def sanitize(self, weights):
         # Remove unused precomputed rotary freqs
-        return {
+        weights = {
             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
         }
+        if self.args.tie_word_embeddings:
+            weights.pop("lm_head.weight", None)
+        return weights

     @property
     def layers(self):

llms/mlx_lm/models/olmoe.py (new file, 217 lines)
View File

@@ -0,0 +1,217 @@
+# Copyright © 2023-2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .rope_utils import initialize_rope
+from .switch_layers import SwitchGLU
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    model_type: str
+    hidden_size: int
+    num_hidden_layers: int
+    intermediate_size: int
+    num_attention_heads: int
+    rms_norm_eps: float
+    vocab_size: int
+    num_experts: int
+    num_experts_per_tok: int
+    norm_topk_prob: bool = False
+    head_dim: Optional[int] = None
+    max_position_embeddings: Optional[int] = None
+    num_key_value_heads: Optional[int] = None
+    attention_bias: bool = False
+    mlp_bias: bool = False
+    rope_theta: float = 10000
+    rope_traditional: bool = False
+    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    tie_word_embeddings: bool = True
+
+    def __post_init__(self):
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
+
+
+class Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        dim = args.hidden_size
+        self.n_heads = n_heads = args.num_attention_heads
+        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+        self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
+        self.scale = head_dim**-0.5
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
+
+        self.rope = initialize_rope(
+            self.head_dim,
+            args.rope_theta,
+            args.rope_traditional,
+            args.rope_scaling,
+            args.max_position_embeddings,
+        )
+
+        self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps)
+        self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+        queries = self.q_norm(queries)
+        keys = self.k_norm(keys)
+        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
+
+
+class OlmoeSparseMoeBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.num_experts = args.num_experts
+        self.top_k = args.num_experts_per_tok
+        self.norm_topk_prob = args.norm_topk_prob
+
+        self.gate = nn.Linear(args.hidden_size, self.num_experts, bias=False)
+        self.switch_mlp = SwitchGLU(
+            args.hidden_size,
+            args.intermediate_size,
+            self.num_experts,
+            bias=args.mlp_bias,
+        )
+
+    def __call__(self, x: mx.array) -> mx.array:
+        B, L, D = x.shape
+        x_flat = x.reshape(-1, D)
+
+        router_logits = self.gate(x_flat)
+        routing_weights = mx.softmax(router_logits, axis=1, precise=True)
+
+        k = self.top_k
+        indices = mx.stop_gradient(
+            mx.argpartition(-routing_weights, kth=k - 1, axis=-1)[..., :k]
+        )
+        scores = mx.take_along_axis(routing_weights, indices, axis=-1)
+        if self.norm_topk_prob:
+            scores = scores / scores.sum(axis=-1, keepdims=True)
+
+        y = self.switch_mlp(x_flat, indices)
+        y = (y * scores[..., None]).sum(axis=-2)
+        return y.reshape(B, L, D)
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.self_attn = Attention(args)
+        self.mlp = OlmoeSparseMoeBlock(args)
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        x = x + self.self_attn(self.input_layernorm(x), mask, cache)
+        x = x + self.mlp(self.post_attention_layernorm(x))
+        return x
+
+
+class OlmoeModel(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.vocab_size = args.vocab_size
+        self.num_hidden_layers = args.num_hidden_layers
+        assert self.vocab_size > 0
+        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.layers = [
+            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
+        ]
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+        mask=None,
+    ):
+        h = self.embed_tokens(inputs)
+
+        if mask is None:
+            mask = create_attention_mask(h, cache)
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, cache=c)
+
+        return self.norm(h)
+
+
+class Model(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.model_type = args.model_type
+        self.model = OlmoeModel(args)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache=None,
+        mask=None,
+    ):
+        out = self.model(inputs, cache, mask)
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return out
+
+    def sanitize(self, weights):
+        if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
+            return weights
+        for l in range(self.args.num_hidden_layers):
+            prefix = f"model.layers.{l}"
+            for n in ["up_proj", "down_proj", "gate_proj"]:
+                for k in ["weight", "scales", "biases"]:
+                    if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
+                        to_join = [
+                            weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
+                            for e in range(self.args.num_experts)
+                        ]
+                        weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
+        return weights
+
+    @property
+    def layers(self):
+        return self.model.layers
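
The routing step in `OlmoeSparseMoeBlock` selects the top-k experts per token with `argpartition` on the negated router probabilities, then gathers those probabilities as mixing weights. A small standalone sketch of just that selection (toy numbers, not model weights):

```python
import mlx.core as mx

# Toy router output: 2 tokens routed over 4 experts.
logits = mx.array([[0.1, 2.0, 0.3, 1.5],
                   [1.0, 0.2, 0.1, 0.4]])
routing_weights = mx.softmax(logits, axis=-1)

k = 2
# argpartition on the negated weights puts the k largest first (order within the k is unspecified).
indices = mx.argpartition(-routing_weights, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(routing_weights, indices, axis=-1)

print(indices)  # expert ids chosen per token
print(scores)   # their softmax probabilities (optionally renormalized to sum to 1)
```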

View File

@@ -23,8 +23,10 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 10000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
+    partial_rotary_factor: float = 1.0
     max_position_embeddings: int = 131072
     original_max_position_embeddings: int = 4096
+    tie_word_embeddings: bool = False

     def __post_init__(self):
         if self.num_key_value_heads is None:
@@ -59,9 +61,10 @@ class Attention(nn.Module):
         self.qkv_proj = nn.Linear(dim, op_size, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

+        rope_dim = int(head_dim * args.partial_rotary_factor)
         if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
             self.rope = SuScaledRotaryEmbedding(
-                head_dim,
+                rope_dim,
                 base=args.rope_theta,
                 max_position_embeddings=args.max_position_embeddings,
                 original_max_position_embeddings=args.original_max_position_embeddings,
@@ -74,7 +77,7 @@ class Attention(nn.Module):
                 assert isinstance(args.rope_scaling["factor"], float)
                 rope_scale = 1 / args.rope_scaling["factor"]
             self.rope = nn.RoPE(
-                head_dim,
+                rope_dim,
                 traditional=args.rope_traditional,
                 base=args.rope_theta,
                 scale=rope_scale,
@@ -190,6 +193,7 @@ class Model(nn.Module):
         super().__init__()
         self.model_type = args.model_type
         self.model = Phi3Model(args)
+        if not args.tie_word_embeddings:
             self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
         self.args = args
@@ -200,7 +204,11 @@ class Model(nn.Module):
         cache=None,
     ):
         out = self.model(inputs, mask, cache)
-        return self.lm_head(out)
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return out

     @property
     def layers(self):

View File

@@ -2,7 +2,7 @@

 import math
 from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any, Optional

 import mlx.core as mx
 import mlx.nn as nn
@@ -32,7 +32,6 @@ class ModelArgs(BaseModelArgs):
     mamba_enabled: bool = True
     intermediate_size: int = 13312
     vocab_size: int = 32000
-    max_position_embeddings: int = 10 * 1024 * 1024


 class RMSNorm(nn.Module):
@@ -53,6 +52,16 @@ class RMSNorm(nn.Module):
         )


+def _rms_norm(hidden_states: mx.array, eps: float) -> mx.array:
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.astype(mx.float32)
+    variance = mx.power(hidden_states, 2).mean(-1, keepdims=True)
+    hidden_states = hidden_states * mx.rsqrt(variance + eps)
+    hidden_states = hidden_states.astype(input_dtype)
+    return hidden_states
+
+
 def get_initial_dt_bias(num_heads: int) -> mx.array:
     dt_min = 0.001
     dt_max = 0.1
@@ -220,8 +229,7 @@ def ssd_chunk_scan_combined(

 def causal_conv1d_update(conv_state, x, weight) -> tuple[mx.array, mx.array]:
-    batch, seqlen, dim = x.shape
-    width = weight.shape[1]
+    _, seqlen, dim = x.shape
     state_len = conv_state.shape[-2]
     x = mx.concatenate([conv_state, x], axis=-2)
     conv_state = x[:, -state_len:]
@@ -392,8 +400,8 @@ class Attention(nn.Module):
         k = k.reshape(B, T, self.k_num_heads, self.qk_dim).transpose(0, 2, 1, 3)
         v = v.reshape(B, T, self.v_num_heads, self.v_dim).transpose(0, 2, 1, 3)

-        q = mx.fast.layer_norm(q, None, None, 1e-6) * self.q_weight[:, None]
-        k = mx.fast.layer_norm(k, None, None, 1e-6) * self.k_weight[:, None]
+        q = _rms_norm(q, 1e-6) * self.q_weight[:, None]
+        k = _rms_norm(k, 1e-6) * self.k_weight[:, None]

         if cache is not None:
             q = self.rope(q, offset=cache.offset)
@@ -556,7 +564,6 @@ class PlamoModel(nn.Module):

 class Model(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
         self.config = config
@@ -567,7 +574,7 @@ class Model(nn.Module):
         if not config.tie_word_embeddings:
             self.lm_head: nn.Module = nn.Linear(
-                config.hidden_size, vocab_size, bias=False
+                config.hidden_size, self.vocab_size, bias=False
             )

     def sanitize(self, weights: dict[Any, Any]) -> dict[Any, Any]:

View File

@@ -51,11 +51,13 @@ class SuScaledRotaryEmbedding(nn.Module):
             + math.log(max_position_embeddings / original_max_position_embeddings)
             / math.log(original_max_position_embeddings)
         )
+        self.dim = dims

     def __call__(self, x, offset: int = 0):
+        x[..., : self.dim] = self.scale * x[..., : self.dim]
         return mx.fast.rope(
-            self.scale * x,
-            x.shape[-1],
+            x,
+            self.dim,
             traditional=False,
             base=None,
             scale=1.0,

View File

@@ -98,6 +98,7 @@ def linear_to_lora_layers(
         "minicpm",
         "deepseek",
         "olmo2",
+        "olmoe",
         "internlm3",
     ]:
         keys = set(["self_attn.q_proj", "self_attn.v_proj"])
@@ -106,6 +107,8 @@ def linear_to_lora_layers(
         if model.model_type == "qwen2_moe":
             keys.add("mlp.gate")
             keys.add("mlp.shared_expert_gate")
+        if model.model_type == "olmoe":
+            keys.add("mlp.gate")

     elif model.model_type == "gpt_bigcode":
         keys = set(["attn.c_attn"])

View File

@@ -191,6 +191,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
                 "*.py",
                 "tokenizer.model",
                 "*.tiktoken",
+                "tiktoken.model",
                 "*.txt",
                 "*.jsonl",
             ],
@@ -1015,6 +1016,46 @@ def save_config(
         json.dump(config, fid, indent=4)


+def mixed_quant_predicate_builder(
+    low_bits: int = 4, high_bits: int = 4, group_size: int = 64
+) -> Callable[[str, nn.Module, dict], Union[bool, dict]]:
+    def mixed_quant_predicate(
+        path: str,
+        module: nn.Module,
+        config: dict,
+    ) -> Union[bool, dict]:
+        """Implements mixed quantization predicates with similar choices to, for example, llama.cpp's Q4_K_M.
+        Ref: https://github.com/ggerganov/llama.cpp/blob/917786f43d0f29b7c77a0c56767c0fa4df68b1c5/src/llama.cpp#L5265
+        By Alex Barron: https://gist.github.com/barronalex/84addb8078be21969f1690c1454855f3
+        """
+
+        if not hasattr(module, "to_quantized"):
+            return False
+
+        index = int(path.split(".")[2]) if len(path.split(".")) > 2 else 0
+        num_layers = config["num_hidden_layers"]
+        use_more_bits = (
+            index < num_layers // 8
+            or index >= 7 * num_layers // 8
+            or (index - num_layers // 8) % 3 == 2
+        )
+        if "v_proj" in path and use_more_bits:
+            return {"group_size": group_size, "bits": high_bits}
+        if "down_proj" in path and use_more_bits:
+            return {"group_size": group_size, "bits": high_bits}
+        if "lm_head" in path:
+            return {"group_size": group_size, "bits": high_bits}
+
+        return {"group_size": group_size, "bits": low_bits}
+
+    return mixed_quant_predicate
+
+
+mixed_3_6 = mixed_quant_predicate_builder(low_bits=3)
+mixed_2_6 = mixed_quant_predicate_builder(low_bits=2)
+
+
 def convert(
     hf_path: str,
     mlx_path: str = "mlx_model",

View File

@@ -298,7 +298,7 @@ class TestPromptCache(unittest.TestCase):
         ):
             i += 1
             self.assertEqual(tok, toks[i])
-            self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2))
+            self.assertTrue(mx.allclose(logits, all_logits[i], rtol=3e-2))


 if __name__ == "__main__":

View File

@@ -3,6 +3,7 @@
 import argparse
 import json
 import math
+import sys
 import time
 from pathlib import Path
@@ -14,6 +15,9 @@ import utils as lora_utils
 from mlx.utils import tree_flatten
 from models import LoRALinear

+# Disable output buffering to see print statements in real-time
+sys.stdout.reconfigure(line_buffering=True)
+

 def build_parser():
     parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")