Merge branch 'ml-explore:main' into adding-reporting-to-wandb

Author: Gökdeniz Gülmez, 2025-03-12 14:35:39 +01:00 (committed by GitHub)
Commit: 0e28fdb345
17 changed files with 561 additions and 139 deletions


@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
- Markus Enzweiler: Added the `cvae` examples.
- Prince Canuma: Helped add support for `Starcoder2` models.
- Shiyu Li: Added the `Segment Anything Model`.
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1` and support for `full-fine-tuning`.
- Gökdeniz Gülmez: Added support for the `MiniCPM`, `Helium`, `Mamba version 1`, and `OLMoE` architectures, and support for `full-fine-tuning`.


@ -48,3 +48,17 @@ Note this was run on an M1 MacBook Pro with 16GB RAM.
At the time of writing, `mlx` doesn't have built-in learning rate schedules.
We intend to update this example once these features are added.
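Until then, one workaround is to adjust the learning rate manually between epochs. Below is a minimal sketch, assuming the optimizer's `learning_rate` attribute can be reassigned (worth verifying against your `mlx` version); the `step_decay` helper is hypothetical and not part of this example:
```python
import mlx.optimizers as optim

def step_decay(base_lr, epoch, drop=0.1, every=30):
    # Hypothetical helper: multiply the base rate by `drop` every `every` epochs.
    return base_lr * (drop ** (epoch // every))

optimizer = optim.Adam(learning_rate=1e-3)
for epoch in range(90):
    optimizer.learning_rate = step_decay(1e-3, epoch)
    # ... run one training epoch with `optimizer` ...
```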
## Distributed training
The example also supports distributed data-parallel training. You can launch
distributed training as follows:
```shell
$ cat >hostfile.json
[
{"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]},
{"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]}
]
$ mlx.launch --verbose --hostfile hostfile.json main.py --batch 256 --epochs 5 --arch resnet20
```
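Each entry in `hostfile.json` describes one node to run on: `mlx.launch` connects to the listed hosts over SSH, starts `main.py` on each of them, and every process then picks up its rank and world size from `mx.distributed.init()` (see the training-script changes below).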


@ -1,3 +1,4 @@
import mlx.core as mx
import numpy as np
from mlx.data.datasets import load_cifar10
@ -12,8 +13,11 @@ def get_cifar10(batch_size, root=None):
x = x.astype("float32") / 255.0
return (x - mean) / std
group = mx.distributed.init()
tr_iter = (
tr.shuffle()
.partition_if(group.size() > 1, group.size(), group.rank())
.to_stream()
.image_random_h_flip("image", prob=0.5)
.pad("image", 0, 4, 4, 0.0)
@ -25,6 +29,11 @@ def get_cifar10(batch_size, root=None):
)
test = load_cifar10(root=root, train=False)
test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size)
test_iter = (
test.to_stream()
.partition_if(group.size() > 1, group.size(), group.rank())
.key_transform("image", normalize)
.batch(batch_size)
)
return tr_iter, test_iter


@ -23,6 +23,13 @@ parser.add_argument("--seed", type=int, default=0, help="random seed")
parser.add_argument("--cpu", action="store_true", help="use cpu only")
def print_zero(group, *args, **kwargs):
if group.rank() != 0:
return
flush = kwargs.pop("flush", True)
print(*args, **kwargs, flush=flush)
def eval_fn(model, inp, tgt):
return mx.mean(mx.argmax(model(inp), axis=1) == tgt)
@ -34,9 +41,20 @@ def train_epoch(model, train_iter, optimizer, epoch):
acc = mx.mean(mx.argmax(output, axis=1) == tgt)
return loss, acc
losses = []
accs = []
samples_per_sec = []
world = mx.distributed.init()
losses = 0
accuracies = 0
samples_per_sec = 0
count = 0
def average_stats(stats, count):
if world.size() == 1:
return [s / count for s in stats]
with mx.stream(mx.cpu):
stats = mx.distributed.all_sum(mx.array(stats))
count = mx.distributed.all_sum(count)
return (stats / count).tolist()
state = [model.state, optimizer.state]
@ -44,6 +62,7 @@ def train_epoch(model, train_iter, optimizer, epoch):
def step(inp, tgt):
train_step_fn = nn.value_and_grad(model, train_step)
(loss, acc), grads = train_step_fn(model, inp, tgt)
grads = nn.utils.average_gradients(grads)
optimizer.update(model, grads)
return loss, acc
@ -52,69 +71,79 @@ def train_epoch(model, train_iter, optimizer, epoch):
y = mx.array(batch["label"])
tic = time.perf_counter()
loss, acc = step(x, y)
mx.eval(state)
mx.eval(loss, acc, state)
toc = time.perf_counter()
loss = loss.item()
acc = acc.item()
losses.append(loss)
accs.append(acc)
throughput = x.shape[0] / (toc - tic)
samples_per_sec.append(throughput)
losses += loss.item()
accuracies += acc.item()
samples_per_sec += x.shape[0] / (toc - tic)
count += 1
if batch_counter % 10 == 0:
print(
l, a, s = average_stats(
[losses, accuracies, world.size() * samples_per_sec],
count,
)
print_zero(
world,
" | ".join(
(
f"Epoch {epoch:02d} [{batch_counter:03d}]",
f"Train loss {loss:.3f}",
f"Train acc {acc:.3f}",
f"Throughput: {throughput:.2f} images/second",
f"Train loss {l:.3f}",
f"Train acc {a:.3f}",
f"Throughput: {s:.2f} images/second",
)
)
),
)
mean_tr_loss = mx.mean(mx.array(losses))
mean_tr_acc = mx.mean(mx.array(accs))
samples_per_sec = mx.mean(mx.array(samples_per_sec))
return mean_tr_loss, mean_tr_acc, samples_per_sec
return average_stats([losses, accuracies, world.size() * samples_per_sec], count)
def test_epoch(model, test_iter, epoch):
accs = []
accuracies = 0
count = 0
for batch_counter, batch in enumerate(test_iter):
x = mx.array(batch["image"])
y = mx.array(batch["label"])
acc = eval_fn(model, x, y)
acc_value = acc.item()
accs.append(acc_value)
mean_acc = mx.mean(mx.array(accs))
return mean_acc
accuracies += acc.item()
count += 1
with mx.stream(mx.cpu):
accuracies = mx.distributed.all_sum(accuracies)
count = mx.distributed.all_sum(count)
return (accuracies / count).item()
def main(args):
mx.random.seed(args.seed)
# Initialize the distributed group and report the nodes that showed up
world = mx.distributed.init()
if world.size() > 1:
print(f"Starting rank {world.rank()} of {world.size()}", flush=True)
model = getattr(resnet, args.arch)()
print("Number of params: {:0.04f} M".format(model.num_params() / 1e6))
print_zero(world, f"Number of params: {model.num_params() / 1e6:0.04f} M")
optimizer = optim.Adam(learning_rate=args.lr)
train_data, test_data = get_cifar10(args.batch_size)
for epoch in range(args.epochs):
tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
print(
print_zero(
world,
" | ".join(
(
f"Epoch: {epoch}",
f"avg. Train loss {tr_loss.item():.3f}",
f"avg. Train acc {tr_acc.item():.3f}",
f"Throughput: {throughput.item():.2f} images/sec",
f"avg. Train loss {tr_loss:.3f}",
f"avg. Train acc {tr_acc:.3f}",
f"Throughput: {throughput:.2f} images/sec",
)
)
),
)
test_acc = test_epoch(model, test_data, epoch)
print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}")
print_zero(world, f"Epoch: {epoch} | Test acc {test_acc:.3f}")
train_data.reset()
test_data.reset()


@ -11,7 +11,7 @@ from .utils import load, stream_generate
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_SEED = None
DEFAULT_MAX_TOKENS = 256
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
@ -36,7 +36,12 @@ def setup_arg_parser():
parser.add_argument(
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
)
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
parser.add_argument(
"--seed",
type=int,
default=DEFAULT_SEED,
help="PRNG seed",
)
parser.add_argument(
"--max-kv-size",
type=int,
@ -57,7 +62,8 @@ def main():
parser = setup_arg_parser()
args = parser.parse_args()
mx.random.seed(args.seed)
if args.seed is not None:
mx.random.seed(args.seed)
model, tokenizer = load(
args.model,


@ -1,27 +1,23 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
from enum import Enum
from .utils import convert, mixed_2_6, mixed_3_6
from . import utils
from .utils import convert
class MixedQuants(Enum):
mixed_3_6 = "mixed_3_6"
mixed_2_6 = "mixed_2_6"
@classmethod
def recipe_names(cls):
return [member.name for member in cls]
QUANT_RECIPES = [
"mixed_2_6",
"mixed_3_6",
]
def quant_args(arg):
try:
return MixedQuants[arg].value
except KeyError:
if arg not in QUANT_RECIPES:
raise argparse.ArgumentTypeError(
f"Invalid q-recipe {arg!r}. Choose from: {MixedQuants.recipe_names()}"
f"Invalid q-recipe {arg!r}. Choose from: {QUANT_RECIPES}"
)
else:
return getattr(utils, arg)
def configure_parser() -> argparse.ArgumentParser:
@ -50,7 +46,7 @@ def configure_parser() -> argparse.ArgumentParser:
)
parser.add_argument(
"--quant-predicate",
help=f"Mixed-bit quantization recipe. Choices: {MixedQuants.recipe_names()}",
help=f"Mixed-bit quantization recipe. Choices: {QUANT_RECIPES}",
type=quant_args,
required=False,
)


@ -7,6 +7,15 @@ train: true
# The fine-tuning method: "lora", "dora", or "full".
fine_tune_type: lora
# The optimizer to use and its optional settings
optimizer: adamw
# optimizer_config:
# adamw:
# betas: [0.9, 0.98]
# eps: 1e-6
# weight_decay: 0.05
# bias_correction: true
# Directory with {train, valid, test}.jsonl files
data: "/path/to/training/data"


@ -0,0 +1,73 @@
# Copyright © 2025 Apple Inc.
import json
from mlx_lm import generate, load
from mlx_lm.models.cache import make_prompt_cache
# Specify the checkpoint
checkpoint = "mlx-community/Qwen2.5-32B-Instruct-4bit"
# Load the corresponding model and tokenizer
model, tokenizer = load(path_or_hf_repo=checkpoint)
# An example tool; make sure to include a docstring and type hints
def multiply(a: float, b: float):
"""
A function that multiplies two numbers
Args:
a: The first number to multiply
b: The second number to multiply
"""
return a * b
tools = {"multiply": multiply}
# Specify the prompt and conversation history
prompt = "Multiply 12234585 and 48838483920."
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tools=list(tools.values())
)
prompt_cache = make_prompt_cache(model)
# Generate the initial tool call:
response = generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=2048,
verbose=True,
prompt_cache=prompt_cache,
)
# Parse the tool call:
# (Note: the tool call format is model-specific)
tool_open = "<tool_call>"
tool_close = "</tool_call>"
start_tool = response.find(tool_open) + len(tool_open)
end_tool = response.find(tool_close)
tool_call = json.loads(response[start_tool:end_tool].strip())
tool_result = tools[tool_call["name"]](**tool_call["arguments"])
# Put the tool result in the prompt
messages = [{"role": "tool", "name": tool_call["name"], "content": tool_result}]
prompt = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
)
# Generate the final response:
response = generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=2048,
verbose=True,
prompt_cache=prompt_cache,
)
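# Note: because the same `prompt_cache` is reused for both `generate` calls, the
# second call continues from the cached conversation state and only processes the
# new tool-result message instead of re-encoding the whole exchange.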


@ -16,7 +16,7 @@ DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_MIN_P = 0.0
DEFAULT_MIN_TOKENS_TO_KEEP = 1
DEFAULT_SEED = 0
DEFAULT_SEED = None
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
DEFAULT_QUANTIZED_KV_START = 5000
@ -87,7 +87,12 @@ def setup_arg_parser():
default=DEFAULT_MIN_TOKENS_TO_KEEP,
help="Minimum tokens to keep for min-p sampling.",
)
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
parser.add_argument(
"--seed",
type=int,
default=DEFAULT_SEED,
help="PRNG seed",
)
parser.add_argument(
"--ignore-chat-template",
action="store_true",
@ -152,7 +157,7 @@ def setup_arg_parser():
"--num-draft-tokens",
type=int,
help="Number of tokens to draft when using speculative decoding.",
default=2,
default=3,
)
return parser
@ -160,7 +165,9 @@ def setup_arg_parser():
def main():
parser = setup_arg_parser()
args = parser.parse_args()
mx.random.seed(args.seed)
if args.seed is not None:
mx.random.seed(args.seed)
# Load the prompt cache and metadata if a cache file is provided
using_cache = args.prompt_cache_file is not None


@ -43,6 +43,11 @@ CONFIG_DEFAULTS = {
"model": "mlx_model",
"train": False,
"fine_tune_type": "lora",
"optimizer": "adam",
"optimizer_config": {
"adam": {},
"adamw": {},
},
"data": "data/",
"seed": 0,
"num_layers": 16,
@ -96,14 +101,19 @@ def build_parser():
choices=["lora", "dora", "full"],
help="Type of fine-tuning to perform: lora, dora, or full.",
)
parser.add_argument(
"--optimizer",
type=str,
choices=["adam", "adamw"],
default=None,
help="Optimizer to use for training: adam or adamw",
)
parser.add_argument(
"--mask-prompt",
action="store_true",
help="Mask the prompt in the loss when training",
default=None,
)
parser.add_argument(
"--num-layers",
type=int,
@ -236,11 +246,21 @@ def train_model(
)
model.train()
opt = optim.Adam(
learning_rate=(
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
)
)
# Initialize the selected optimizer
lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
optimizer_name = args.optimizer.lower()
optimizer_config = args.optimizer_config.get(optimizer_name, {})
if optimizer_name == "adam":
opt_class = optim.Adam
elif optimizer_name == "adamw":
opt_class = optim.AdamW
else:
raise ValueError(f"Unsupported optimizer: {optimizer_name}")
opt = opt_class(learning_rate=lr, **optimizer_config)
# Train model
train(


@ -33,13 +33,13 @@ def create_causal_mask(
linds = mx.arange(offset, offset + N) if offset else rinds
linds = linds[:, None]
rinds = rinds[None]
mask = linds < rinds
mask = linds >= rinds
if window_size is not None:
mask = mask | (linds > rinds + window_size)
mask = mask & (linds <= rinds + window_size)
if lengths is not None:
lengths = lengths[:, None, None, None]
mask = mask | (rinds >= lengths)
return mask * -1e9
mask = mask & (rinds < lengths)
return mask
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
@ -55,7 +55,6 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
else:
offset = c.offset
mask = create_causal_mask(T, offset, window_size=window_size)
mask = mask.astype(h.dtype)
else:
mask = None
return mask
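As a quick illustration of the new boolean convention (`True` marks positions that may be attended to, replacing the old additive `-1e9` mask), here is what the core of `create_causal_mask` produces for a 4-token sequence with no offset or window:
```python
import mlx.core as mx

N = 4
rinds = mx.arange(N)
linds = rinds[:, None]
mask = linds >= rinds[None]
# Lower-triangular boolean matrix: row i is True for columns 0..i, so each
# position can attend to itself and to everything before it.
```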


@ -196,9 +196,12 @@ class Model(nn.Module):
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {
weights = {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
if self.args.tie_word_embeddings:
weights.pop("lm_head.weight", None)
return weights
@property
def layers(self):

llms/mlx_lm/models/olmoe.py (new file, 217 lines)

@ -0,0 +1,217 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
num_experts: int
num_experts_per_tok: int
norm_topk_prob: bool = False
head_dim: Optional[int] = None
max_position_embeddings: Optional[int] = None
num_key_value_heads: Optional[int] = None
attention_bias: bool = False
mlp_bias: bool = False
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = True
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
self.rope = initialize_rope(
self.head_dim,
args.rope_theta,
args.rope_traditional,
args.rope_scaling,
args.max_position_embeddings,
)
self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps)
self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = self.q_norm(queries)
keys = self.k_norm(keys)
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class OlmoeSparseMoeBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_experts = args.num_experts
self.top_k = args.num_experts_per_tok
self.norm_topk_prob = args.norm_topk_prob
self.gate = nn.Linear(args.hidden_size, self.num_experts, bias=False)
self.switch_mlp = SwitchGLU(
args.hidden_size,
args.intermediate_size,
self.num_experts,
bias=args.mlp_bias,
)
def __call__(self, x: mx.array) -> mx.array:
B, L, D = x.shape
x_flat = x.reshape(-1, D)
router_logits = self.gate(x_flat)
routing_weights = mx.softmax(router_logits, axis=1, precise=True)
k = self.top_k
indices = mx.stop_gradient(
mx.argpartition(-routing_weights, kth=k - 1, axis=-1)[..., :k]
)
scores = mx.take_along_axis(routing_weights, indices, axis=-1)
if self.norm_topk_prob:
scores = scores / scores.sum(axis=-1, keepdims=True)
y = self.switch_mlp(x_flat, indices)
y = (y * scores[..., None]).sum(axis=-2)
return y.reshape(B, L, D)
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.self_attn = Attention(args)
self.mlp = OlmoeSparseMoeBlock(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
x = x + self.self_attn(self.input_layernorm(x), mask, cache)
x = x + self.mlp(self.post_attention_layernorm(x))
return x
class OlmoeModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
mask=None,
):
h = self.embed_tokens(inputs)
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = OlmoeModel(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
mask=None,
):
out = self.model(inputs, cache, mask)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out
def sanitize(self, weights):
if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
return weights
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n in ["up_proj", "down_proj", "gate_proj"]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
for e in range(self.args.num_experts)
]
weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
return weights
@property
def layers(self):
return self.model.layers


@ -35,14 +35,25 @@ def make_sampler(
"""
if temp == 0:
return lambda x: mx.argmax(x, axis=-1)
elif top_p > 0 and top_p < 1.0:
return lambda x: top_p_sampling(x, top_p, temp)
elif min_p != 0.0:
return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp)
elif top_k > 0:
return lambda x: top_k_sampling(x, top_k, temp)
else:
return lambda x: categorical_sampling(x, temp)
# Create sampler chain
sampling_methods = []
if top_k > 0:
sampling_methods.append(lambda x: apply_top_k(x, top_k))
if top_p > 0 and top_p < 1.0:
sampling_methods.append(lambda x: apply_top_p(x, top_p))
if min_p != 0.0:
sampling_methods.append(lambda x: apply_min_p(x, min_p, min_tokens_to_keep))
# Apply the sampling methods
def sampler(logits):
for method in sampling_methods:
logits = method(logits)
# Return the sampled token
return categorical_sampling(logits, temp)
return sampler
def make_logits_processors(
@ -85,10 +96,9 @@ def make_logits_processors(
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def top_k_sampling(
def apply_top_k(
logprobs: mx.array,
top_k: int,
temperature=1.0,
) -> mx.array:
"""
Keep only the top K tokens ranked by probability, masking out the rest.
@ -103,20 +113,18 @@ def top_k_sampling(
f"`top_k` has to be an integer in the (0, {vocab_size}] interval,"
f" but is {top_k}."
)
logprobs = logprobs * (1 / temperature)
mask_idx = mx.argpartition(-logprobs, kth=top_k - 1, axis=-1)[..., top_k:]
masked_logprobs = mx.put_along_axis(
logprobs, mask_idx, mx.array(-float("inf"), logprobs.dtype), axis=-1
)
return mx.random.categorical(masked_logprobs, axis=-1)
return masked_logprobs
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def min_p_sampling(
def apply_min_p(
logprobs: mx.array,
min_p: float,
min_tokens_to_keep: int = 1,
temperature=1.0,
) -> mx.array:
"""
Apply min-p filtering to the logprobs.
@ -144,8 +152,6 @@ def min_p_sampling(
)
# reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605
logprobs = logprobs * (1 / temperature)
# Indices sorted in decreasing order
sorted_indices = mx.argsort(-logprobs, axis=-1)
sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)
@ -163,25 +169,31 @@ def min_p_sampling(
# Create pool of tokens with probability less than scaled min_p
selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
# Return sampled tokens
sorted_tokens = mx.random.categorical(selected_logprobs, axis=-1)[:, None]
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
# Create a mapping to rearrange back to original indices
# Use argsort of sorted_indices to get the inverse permutation
inverse_indices = mx.argsort(sorted_indices, axis=-1)
# Rearrange selected_logprobs back to original order
original_order_logprobs = mx.take_along_axis(
selected_logprobs, inverse_indices, axis=-1
)
return original_order_logprobs
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
def apply_top_p(logits: mx.array, top_p: float) -> mx.array:
"""
Apply top-p (nucleus) filtering to logits.
Args:
logits: The logits from the model's output.
top_p: The cumulative probability threshold for top-p filtering.
temperature: Temperature parameter for softmax distribution reshaping.
Returns:
The logits with tokens outside the top-p set masked to -inf.
"""
# referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
probs = mx.softmax(logits * (1 / temperature), axis=-1)
probs = mx.softmax(logits, axis=-1)
# sort probs in ascending order
sorted_indices = mx.argsort(probs, axis=-1)
@ -196,8 +208,15 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
0,
)
sorted_tokens = mx.random.categorical(mx.log(top_probs), axis=-1)[:, None]
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
# Create a mapping to rearrange back to original indices
# Use argsort of sorted_indices to get the inverse permutation
inverse_indices = mx.argsort(sorted_indices, axis=-1)
# Rearrange top_probs back to original order
original_order_probs = mx.take_along_axis(top_probs, inverse_indices, axis=-1)
# Convert back to logits and return
return mx.log(original_order_probs)
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)


@ -98,6 +98,7 @@ def linear_to_lora_layers(
"minicpm",
"deepseek",
"olmo2",
"olmoe",
"internlm3",
]:
keys = set(["self_attn.q_proj", "self_attn.v_proj"])
@ -106,6 +107,8 @@ def linear_to_lora_layers(
if model.model_type == "qwen2_moe":
keys.add("mlp.gate")
keys.add("mlp.shared_expert_gate")
if model.model_type == "olmoe":
keys.add("mlp.gate")
elif model.model_type == "gpt_bigcode":
keys = set(["attn.c_attn"])


@ -298,7 +298,7 @@ class TestPromptCache(unittest.TestCase):
):
i += 1
self.assertEqual(tok, toks[i])
self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2))
self.assertTrue(mx.allclose(logits, all_logits[i], rtol=3e-2))
if __name__ == "__main__":


@ -1,79 +1,97 @@
import unittest
import mlx.core as mx
from mlx_lm.sample_utils import min_p_sampling, top_k_sampling, top_p_sampling
from mlx_lm.sample_utils import apply_min_p, apply_top_k, apply_top_p
class TestSampleUtils(unittest.TestCase):
def test_top_p_sampling(self):
def test_apply_top_p(self):
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)
temperature = 1.0
token = top_p_sampling(logits, 0.3, temperature).item()
self.assertEqual(token, 0)
new_logits = apply_top_p(logits, 0.3)
actual_probs = mx.softmax(new_logits.squeeze())
self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0])
token = top_p_sampling(logits, 0.95, temperature).item()
self.assertTrue(token in (0, 3))
new_logits = apply_top_p(logits, 0.95)
actual_probs = mx.softmax(new_logits.squeeze())
self.assertTrue(mx.allclose(probs.squeeze(), actual_probs))
probs = mx.array([0.0, 0.5, 0.4, 0.1])[None]
logits = mx.log(probs)
new_logits = apply_top_p(logits, 0.4)
actual_probs = mx.softmax(new_logits.squeeze())
self.assertEqual(actual_probs.tolist(), [0.0, 1.0, 0.0, 0.0])
token = top_p_sampling(logits, 0.4, temperature).item()
self.assertEqual(token, 1)
new_logits = apply_top_p(logits, 0.6)
actual_probs = mx.softmax(new_logits.squeeze())
self.assertEqual(
[round(p, 4) for p in actual_probs.tolist()], [0.0, 0.5556, 0.4444, 0.0]
)
token = top_p_sampling(logits, 0.6, temperature).item()
self.assertTrue(token in (1, 2))
new_logits = apply_top_p(logits, 0.95)
actual_probs = mx.softmax(new_logits.squeeze())
actual_rounded = [round(p, 4) for p in actual_probs.tolist()]
expected_rounded = [0.0, 0.5, 0.4, 0.1]
self.assertEqual(actual_rounded, expected_rounded)
self.assertAlmostEqual(sum(actual_probs.tolist()), 1.0)
token = top_p_sampling(logits, 0.95, temperature).item()
self.assertTrue(token in (1, 2, 3))
# Batch mode works
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.1, 0.1]])
logits = mx.log(probs)
new_logits = apply_top_p(logits, 0.5)
actual_probs = mx.softmax(new_logits, axis=-1)
self.assertEqual(
actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
)
def test_apply_min_p(self):
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)
new_logits = apply_min_p(logits, 0.8)
actual_probs = mx.softmax(new_logits.squeeze())
self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0])
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)
new_logits = apply_min_p(logits, 0.05)
actual_probs = mx.softmax(new_logits.squeeze())
self.assertTrue(mx.allclose(actual_probs, mx.squeeze(probs)))
# Batch mode works
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
logits = mx.log(probs)
tokens = top_p_sampling(logits, 0.5, temperature)
self.assertEqual(tokens.tolist(), [0, 1])
new_logits = apply_min_p(logits, 0.7)
actual_probs = mx.softmax(new_logits, axis=-1)
self.assertEqual(
actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
)
def test_min_p_sampling(self):
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)
temperature = 1.0
token = min_p_sampling(logits, 0.8)
self.assertEqual(token, 0)
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)
temperature = 1.0
for _ in range(5):
token = min_p_sampling(logits, 0.05)
self.assertTrue(token in (0, 3))
# Batch mode works
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
logits = mx.log(probs)
tokens = min_p_sampling(logits, 0.7)
self.assertEqual(tokens.tolist(), [0, 1])
def test_top_k_sampling(self):
def test_apply_top_k(self):
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)
token = top_k_sampling(logits, 1).item()
self.assertEqual(token, 0)
new_logits = apply_top_k(logits, 1)
actual_probs = mx.softmax(new_logits.squeeze())
self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0])
probs = mx.array([0.5, 0.0, 0.0, 0.5])[None]
tokens = set()
for _ in range(100):
token = top_k_sampling(logits, 2)
tokens.add(token.item())
self.assertEqual(tokens, {0, 3})
probs = mx.array([0.6, 0.0, 0.1, 0.3])[None]
logits = mx.log(probs)
new_logits = apply_top_k(logits, 2)
actual_probs = mx.softmax(new_logits.squeeze())
self.assertEqual(
[round(p, 4) for p in actual_probs.tolist()], [0.6667, 0.0, 0.0, 0.3333]
)
# Batch mode works
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
logits = mx.log(probs)
tokens = top_k_sampling(logits, 1)
self.assertEqual(tokens.tolist(), [0, 1])
new_logits = apply_top_k(logits, 1)
actual_probs = mx.softmax(new_logits, axis=-1)
self.assertEqual(
actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
)
if __name__ == "__main__":