Merge branch 'ml-explore:main' into adding-support-for-mamba2

Gökdeniz Gülmez authored 2025-02-12 11:09:20 +01:00, committed by GitHub
commit c26e188417
27 changed files with 794 additions and 253 deletions

View File

@@ -1,10 +1,10 @@
 repos:
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 24.8.0
+    rev: 25.1.0
    hooks:
      - id: black
   - repo: https://github.com/pycqa/isort
-    rev: 5.13.2
+    rev: 6.0.0
    hooks:
      - id: isort
        args:

View File

@@ -45,7 +45,7 @@ Some more useful examples are listed below.

 ### Hugging Face

-Note: You can now directly download a few converted checkpoints from the [MLX
+You can directly use or download converted checkpoints from the [MLX
 Community](https://huggingface.co/mlx-community) organization on Hugging Face.
 We encourage you to join the community and [contribute new
 models](https://github.com/ml-explore/mlx-examples/issues/155).

View File

@@ -164,7 +164,7 @@ mlx_lm.convert \
 ```

 Models can also be converted and quantized directly in the
-[mlx-my-repo]https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
+[mlx-my-repo](https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging
 Face Space.

 ### Long Prompts and Generations

View File

@@ -101,6 +101,14 @@ You can specify the output location with `--adapter-path`.
 You can resume fine-tuning with an existing adapter with
 `--resume-adapter-file <path_to_adapters.safetensors>`.

+#### Prompt Masking
+
+The default training computes a loss for every token in the sample. You can
+ignore the prompt and compute loss for just the completion by passing
+`--mask-prompt`. Note this is only supported for `chat` and `completion`
+datasets. For `chat` datasets the final message in the message list is
+considered the completion. See the [dataset section](#Data) for more details.
+
 ### Evaluate

 To compute test set perplexity use:
@@ -315,11 +323,27 @@ hf_dataset:
 - Use `prompt_feature` and `completion_feature` to specify keys for a
   `completions` dataset. Use `text_feature` to specify the key for a `text`
-  dataset.
+  dataset. Use `chat_feature` to specify the key for a chat dataset.
 - To specify the train, valid, or test splits, set the corresponding
   `{train,valid,test}_split` argument.

+You can specify a list of Hugging Face datasets with a list of records each
+with the same structure as above. For example:
+
+```yaml
+hf_dataset:
+  - name: "Open-Orca/OpenOrca"
+    train_split: "train[:90%]"
+    valid_split: "train[-10%:]"
+    prompt_feature: "question"
+    completion_feature: "response"
+  - name: "trl-lib/ultrafeedback_binarized"
+    train_split: "train[:90%]"
+    valid_split: "train[-10%:]"
+    chat_feature: "chosen"
+```
+
 - Arguments specified in `config` will be passed as keyword arguments to
   [`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
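To make the `--mask-prompt` option above concrete, here is a minimal sketch, not the library's code, of how a prompt offset and a sequence length become a per-token loss mask; it mirrors the masked loss in `mlx_lm/tuner/trainer.py` shown later in this diff, and the token ids are made up for illustration.

```python
import mlx.core as mx

# One sample: 3 prompt tokens (101, 15, 27) followed by 3 completion tokens (42, 99, 7).
tokens = mx.array([[101, 15, 27, 42, 99, 7]])
lengths = mx.array([[3, 6]])  # (prompt offset, total length) per sample

# Targets are the inputs shifted by one position, as in the trainer's default loss.
inputs, targets = tokens[:, :-1], tokens[:, 1:]

steps = mx.arange(1, targets.shape[1] + 1)
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])

# mask is [[False, False, True, True, True]]: only the predictions of the
# completion tokens (42, 99, 7) contribute to the loss; the prompt is ignored.
print(mask)
```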

View File

@@ -152,7 +152,7 @@ def main():
     print("Saving...")
     metadata = {}
     metadata["model"] = args.model
-    metadata["chat_template"] = tokenizer.chat_template
+    metadata["chat_template"] = json.dumps(tokenizer.chat_template)
     metadata["tokenizer_config"] = json.dumps(tokenizer_config)
     save_prompt_cache(args.prompt_cache_file, cache, metadata)

View File

@@ -295,7 +295,9 @@ class MLXLM(LM):
         completions = []

         for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
-            context = self._tokenize(context)
+            context = self.tokenizer.encode(
+                context, add_special_tokens=not self.use_chat_template
+            )
             max_tokens = min(
                 self._max_tokens,
                 self.tokenizer.model_max_length - len(context),

View File

@@ -23,7 +23,6 @@ response = generate(
     tokenizer,
     prompt=prompt,
     verbose=True,
-    temp=0.0,
     prompt_cache=prompt_cache,
 )

View File

@@ -4,10 +4,11 @@

 Run with:

 ```
-/path/to/mpirun \
- -np 2 \
+mlx.launch \
  --hostfile /path/to/hosts.txt \
-python /path/to/pipeline_generate.py --prompt "hello world"
+ --backend mpi \
+ /path/to/pipeline_generate.py \
+ --prompt "hello world"
 ```

 Make sure you can run MLX over MPI on two hosts. For more information see the
@@ -17,62 +18,110 @@ https://ml-explore.github.io/mlx/build/html/usage/distributed.html).
 """

 import argparse
+import json
+from pathlib import Path

 import mlx.core as mx
+from huggingface_hub import snapshot_download
+from mlx.utils import tree_flatten
+
 from mlx_lm import load, stream_generate
+from mlx_lm.utils import load_model, load_tokenizer

-parser = argparse.ArgumentParser(description="LLM pipelined inference example")
-parser.add_argument(
-    "--model",
-    default="mlx-community/DeepSeek-R1-3bit",
-    help="HF repo or path to local model.",
-)
-parser.add_argument(
-    "--prompt",
-    "-p",
-    default="Write a quicksort in C++.",
-    help="Message to be processed by the model ('-' reads from stdin)",
-)
-parser.add_argument(
-    "--max-tokens",
-    "-m",
-    type=int,
-    default=256,
-    help="Maximum number of tokens to generate",
-)
-args = parser.parse_args()
-
-model, tokenizer = load(args.model, lazy=True)
-
-messages = [{"role": "user", "content": args.prompt}]
-prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
-
-group = mx.distributed.init()
-rank = group.rank()
-
-model.model.pipeline(group)
-mx.eval(model.parameters())
-
-# Synchronize processes before generation to avoid timeout if downloading
-# model for the first time.
-mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
-
-
-def rprint(*args, **kwargs):
-    if rank == 0:
-        print(*args, **kwargs)
-
-
-for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens):
-    rprint(response.text, end="", flush=True)
-
-rprint()
-rprint("=" * 10)
-rprint(
-    f"Prompt: {response.prompt_tokens} tokens, "
-    f"{response.prompt_tps:.3f} tokens-per-sec"
-)
-rprint(
-    f"Generation: {response.generation_tokens} tokens, "
-    f"{response.generation_tps:.3f} tokens-per-sec"
-)
-rprint(f"Peak memory: {response.peak_memory:.3f} GB")
+
+def download(repo: str, allow_patterns: list[str]) -> Path:
+    return Path(
+        snapshot_download(
+            repo,
+            allow_patterns=allow_patterns,
+        )
+    )
+
+
+def shard_and_load(repo):
+    # Get model path with everything but weight safetensors
+    model_path = download(
+        args.model,
+        allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
+    )
+
+    # Lazy load and shard model to figure out
+    # which weights we need
+    model, _ = load_model(model_path, lazy=True, strict=False)
+
+    group = mx.distributed.init(backend="mpi")
+    rank = group.rank()
+    model.model.pipeline(group)
+
+    # Figure out which files we need for the local shard
+    with open(model_path / "model.safetensors.index.json", "r") as fid:
+        weight_index = json.load(fid)["weight_map"]
+
+    local_files = set()
+    for k, _ in tree_flatten(model.parameters()):
+        local_files.add(weight_index[k])
+
+    # Download weights for local shard
+    download(args.model, allow_patterns=local_files)
+
+    # Load and shard the model, and load the weights
+    tokenizer = load_tokenizer(model_path)
+    model, _ = load_model(model_path, lazy=True, strict=False)
+    model.model.pipeline(group)
+    mx.eval(model.parameters())
+
+    # Synchronize processes before generation to avoid timeout if downloading
+    # model for the first time.
+    mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
+    return model, tokenizer
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="LLM pipelined inference example")
+    parser.add_argument(
+        "--model",
+        default="mlx-community/DeepSeek-R1-3bit",
+        help="HF repo or path to local model.",
+    )
+    parser.add_argument(
+        "--prompt",
+        "-p",
+        default="Write a quicksort in C++.",
+        help="Message to be processed by the model ('-' reads from stdin)",
+    )
+    parser.add_argument(
+        "--max-tokens",
+        "-m",
+        type=int,
+        default=256,
+        help="Maximum number of tokens to generate",
+    )
+    args = parser.parse_args()
+
+    group = mx.distributed.init(backend="mpi")
+    rank = group.rank()
+
+    def rprint(*args, **kwargs):
+        if rank == 0:
+            print(*args, **kwargs)
+
+    model, tokenizer = shard_and_load(args.model)
+
+    messages = [{"role": "user", "content": args.prompt}]
+    prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
+
+    for response in stream_generate(
+        model, tokenizer, prompt, max_tokens=args.max_tokens
+    ):
+        rprint(response.text, end="", flush=True)
+
+    rprint()
+    rprint("=" * 10)
+    rprint(
+        f"Prompt: {response.prompt_tokens} tokens, "
+        f"{response.prompt_tps:.3f} tokens-per-sec"
+    )
+    rprint(
+        f"Generation: {response.generation_tokens} tokens, "
+        f"{response.generation_tps:.3f} tokens-per-sec"
+    )
+    rprint(f"Peak memory: {response.peak_memory:.3f} GB")

View File

@@ -93,6 +93,12 @@ def setup_arg_parser():
         action="store_true",
         help="Use the default chat template",
     )
+    parser.add_argument(
+        "--chat-template-config",
+        help="Additional config for `apply_chat_template`. Should be a dictionary of"
+        " string keys to values represented as a JSON decodable string.",
+        default=None,
+    )
     parser.add_argument(
         "--verbose",
         type=str2bool,
@@ -149,7 +155,6 @@ def setup_arg_parser():

 def main():
     parser = setup_arg_parser()
     args = parser.parse_args()
-
     mx.random.seed(args.seed)

     # Load the prompt cache and metadata if a cache file is provided
@@ -195,11 +200,15 @@ def main():
     for eos_token in args.extra_eos_token:
         tokenizer.add_eos_token(eos_token)

+    template_kwargs = {}
+    if args.chat_template_config is not None:
+        template_kwargs = json.loads(args.chat_template_config)
+
     if args.use_default_chat_template:
         if tokenizer.chat_template is None:
             tokenizer.chat_template = tokenizer.default_chat_template
     elif using_cache:
-        tokenizer.chat_template = metadata["chat_template"]
+        tokenizer.chat_template = json.loads(metadata["chat_template"])

     prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t")
     prompt = sys.stdin.read() if prompt == "-" else prompt
@@ -209,8 +218,12 @@ def main():
         else:
             messages = []
         messages.append({"role": "user", "content": prompt})
+
         prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+            **template_kwargs,
         )

         # Treat the prompt as a suffix assuming that the prefix is in the
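As a usage illustration of the new `--chat-template-config` flag, the sketch below shows, outside of the CLI, how a JSON string becomes extra keyword arguments for `apply_chat_template`. The `enable_thinking` key is hypothetical; valid keys depend entirely on the model's chat template, which simply ignores variables it does not use.

```python
import json

from mlx_lm import load

# Model name taken from the tests in this diff; any chat model works.
model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

# Equivalent of passing --chat-template-config '{"enable_thinking": false}'
template_kwargs = json.loads('{"enable_thinking": false}')

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "hello"}],
    tokenize=False,
    add_generation_prompt=True,
    **template_kwargs,  # forwarded to the template as extra variables
)
print(prompt)
```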

View File

@@ -94,6 +94,14 @@ def build_parser():
         choices=["lora", "dora", "full"],
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
+    parser.add_argument(
+        "--mask-prompt",
+        action="store_true",
+        help="Mask the prompt in the loss when training",
+        default=False,
+    )
     parser.add_argument(
         "--num-layers",
         type=int,
@@ -219,6 +227,7 @@ def train_model(
             build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
         )
     )

+    # Train model
     train(
         model=model,

View File

@@ -282,12 +282,12 @@ class MoEGate(nn.Module):
         if self.topk_method == "group_limited_greedy":
             bsz, seq_len = x.shape[:2]
             scores = scores.reshape(bsz, seq_len, self.n_group, -1)
-            group_scores = scores.max(axis=-1)
+            group_scores = scores.max(axis=-1, keepdims=True)
             k = self.n_group - self.topk_group
-            group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
-            batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
-            seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
-            scores[batch_idx, seq_idx, group_idx] = 0.0
+            group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
+            scores = mx.put_along_axis(
+                scores, group_idx, mx.array(0.0, scores.dtype), axis=-2
+            )
             scores = scores.reshape(bsz, seq_len, -1)

         k = self.top_k
@ -364,8 +364,32 @@ class DeepseekV2Model(nn.Module):
DeepseekV2DecoderLayer(config, idx) DeepseekV2DecoderLayer(config, idx)
for idx in range(config.num_hidden_layers) for idx in range(config.num_hidden_layers)
] ]
self.start_idx = 0
self.end_idx = len(self.layers)
self.num_layers = self.end_idx
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pipeline_rank = 0
self.pipeline_size = 1
def pipeline(self, group):
# Split layers in reverse so rank=0 gets the last layers and
# rank=pipeline_size-1 gets the first
self.pipeline_rank = group.rank()
self.pipeline_size = group.size()
layers_per_rank = len(self.layers) // self.pipeline_size
extra = len(self.layers) - layers_per_rank * self.pipeline_size
if self.pipeline_rank < extra:
layers_per_rank += 1
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
self.end_idx = self.start_idx + layers_per_rank
self.num_layers = layers_per_rank
self.layers = self.layers[: self.end_idx]
self.layers[: self.start_idx] = [None] * self.start_idx
self.num_layers = len(self.layers) - self.start_idx
def __call__( def __call__(
self, self,
x: mx.array, x: mx.array,
@ -374,14 +398,31 @@ class DeepseekV2Model(nn.Module):
) -> mx.array: ) -> mx.array:
h = self.embed_tokens(x) h = self.embed_tokens(x)
pipeline_rank = self.pipeline_rank
pipeline_size = self.pipeline_size
# Hack to avoid time-outs during prompt-processing
dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
if mask is None: if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
cache = [None] * len(self.layers) cache = [None] * self.num_layers
for layer, c in zip(self.layers, cache): # Receive from the previous process in the pipeline
h = layer(h, mask, c) if pipeline_rank < pipeline_size - 1:
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
for i in range(self.num_layers):
h = self.layers[self.start_idx + i](h, mask, cache[i])
# Send to the next process in the pipeline
if pipeline_rank != 0:
h = mx.distributed.send(
h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
)
# Broadcast h while keeping it in the graph
h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
return self.norm(h) return self.norm(h)
@@ -418,4 +459,4 @@ class Model(nn.Module):

     @property
     def layers(self):
-        return self.model.layers
+        return self.model.layers[self.model.start_idx : self.model.end_idx]

View File

@ -271,6 +271,38 @@ class DeepseekV3MLP(nn.Module):
return down_proj return down_proj
@mx.compile
def group_expert_select(
gates,
e_score_correction_bias,
top_k,
n_group,
topk_group,
routed_scaling_factor,
norm_topk_prob,
):
k = top_k
scores = mx.sigmoid(gates.astype(mx.float32))
scores = scores + e_score_correction_bias
scores = mx.unflatten(scores, axis=-1, shape=(n_group, -1))
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1, keepdims=True)
k = n_group - topk_group
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
scores = mx.put_along_axis(scores, group_idx, mx.array(0.0), axis=-2)
scores = mx.flatten(scores, -2, -1)
k = top_k
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(scores, inds, axis=-1)
if top_k > 1 and norm_topk_prob:
denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
scores = scores / denominator
scores = scores * routed_scaling_factor
return inds, scores
class MoEGate(nn.Module): class MoEGate(nn.Module):
def __init__(self, config: ModelArgs): def __init__(self, config: ModelArgs):
super().__init__() super().__init__()
@ -279,38 +311,22 @@ class MoEGate(nn.Module):
self.norm_topk_prob = config.norm_topk_prob self.norm_topk_prob = config.norm_topk_prob
self.n_routed_experts = config.n_routed_experts self.n_routed_experts = config.n_routed_experts
self.routed_scaling_factor = config.routed_scaling_factor self.routed_scaling_factor = config.routed_scaling_factor
self.topk_method = config.topk_method
self.n_group = config.n_group self.n_group = config.n_group
self.topk_group = config.topk_group self.topk_group = config.topk_group
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size)) self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
self.e_score_correction_bias = mx.zeros((self.n_routed_experts,)) self.e_score_correction_bias = mx.zeros((self.n_routed_experts,))
assert config.topk_method == "noaux_tc", "Unsupported topk method."
def __call__(self, x): def __call__(self, x):
gates = x @ self.weight.T return group_expert_select(
x @ self.weight.T,
scores = mx.sigmoid(gates.astype(mx.float32)) self.e_score_correction_bias,
self.top_k,
assert self.topk_method == "noaux_tc", "Unsupported topk method." self.n_group,
bsz, seq_len = x.shape[:2] self.topk_group,
scores = scores + self.e_score_correction_bias self.routed_scaling_factor,
scores = scores.reshape(bsz, seq_len, self.n_group, -1) self.norm_topk_prob,
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1) )
k = self.n_group - self.topk_group
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
scores[batch_idx, seq_idx, group_idx] = 0.0
scores = scores.reshape(bsz, seq_len, -1)
k = self.top_k
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(scores, inds, axis=-1)
if self.top_k > 1 and self.norm_topk_prob:
denominator = scores.sum(axis=-1, keepdims=True) + 1e-20
scores = scores / denominator
scores = scores * self.routed_scaling_factor
return inds, scores
class DeepseekV3MoE(nn.Module): class DeepseekV3MoE(nn.Module):
@ -381,6 +397,10 @@ class DeepseekV3Model(nn.Module):
DeepseekV3DecoderLayer(config, idx) DeepseekV3DecoderLayer(config, idx)
for idx in range(config.num_hidden_layers) for idx in range(config.num_hidden_layers)
] ]
self.start_idx = 0
self.end_idx = len(self.layers)
self.num_layers = self.end_idx
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pipeline_rank = 0 self.pipeline_rank = 0
self.pipeline_size = 1 self.pipeline_size = 1
@ -390,11 +410,15 @@ class DeepseekV3Model(nn.Module):
# rank=pipeline_size-1 gets the first # rank=pipeline_size-1 gets the first
self.pipeline_rank = group.rank() self.pipeline_rank = group.rank()
self.pipeline_size = group.size() self.pipeline_size = group.size()
layers_per_rank = ( layers_per_rank = len(self.layers) // self.pipeline_size
len(self.layers) + self.pipeline_size - 1 extra = len(self.layers) - layers_per_rank * self.pipeline_size
) // self.pipeline_size if self.pipeline_rank < extra:
start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank layers_per_rank += 1
self.layers = self.layers[start : start + layers_per_rank] self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
self.end_idx = self.start_idx + layers_per_rank
self.layers = self.layers[: self.end_idx]
self.layers[: self.start_idx] = [None] * self.start_idx
self.num_layers = len(self.layers) - self.start_idx
def __call__( def __call__(
self, self,
@ -412,15 +436,15 @@ class DeepseekV3Model(nn.Module):
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
if cache is None: if cache is None:
cache = [None] * len(self.layers) cache = [None] * self.num_layers
# Receive from the previous process in the pipeline # Receive from the previous process in the pipeline
if pipeline_rank < pipeline_size - 1: if pipeline_rank < pipeline_size - 1:
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream) h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
for layer, c in zip(self.layers, cache): for i in range(self.num_layers):
h = layer(h, mask, c) h = self.layers[self.start_idx + i](h, mask, cache[i])
# Send to the next process in the pipeline # Send to the next process in the pipeline
if pipeline_rank != 0: if pipeline_rank != 0:
@@ -468,4 +492,4 @@ class Model(nn.Module):

     @property
     def layers(self):
-        return self.model.layers
+        return self.model.layers[self.model.start_idx : self.model.end_idx]

View File

@ -0,0 +1,195 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
logits_scaling: float
attention_multiplier: float
embedding_multiplier: float
residual_multiplier: float
max_position_embeddings: int
num_key_value_heads: int
attention_bias: bool
mlp_bias: bool
rope_theta: float
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = True
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.hidden_size // n_heads
self.scale = args.attention_multiplier
attention_bias = args.attention_bias
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
self.rope = initialize_rope(
self.head_dim,
args.rope_theta,
False,
args.rope_scaling,
args.max_position_embeddings,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
hidden_dim = args.intermediate_size
if hasattr(args, "mlp_bias"):
mlp_bias = args.mlp_bias
else:
mlp_bias = False
self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.residual_multiplier = args.residual_multiplier
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r * self.residual_multiplier
r = self.mlp(self.post_attention_layernorm(h))
out = h + r * self.residual_multiplier
return out
class GraniteModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.embedding_multiplier = args.embedding_multiplier
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
h = self.embed_tokens(inputs) * self.embedding_multiplier
if mask is None:
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = GraniteModel(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
self.logits_scaling = args.logits_scaling
def __call__(
self,
inputs: mx.array,
mask: mx.array = None,
cache=None,
):
out = self.model(inputs, mask, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out / self.logits_scaling
@property
def layers(self):
return self.model.layers

View File

@@ -1,3 +1,5 @@
+# Copyright © 2025 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple

View File

@ -76,7 +76,6 @@ class Attention(nn.Module):
head_dim = args.hidden_size // n_heads head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5 self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias) self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
if kv_proj: if kv_proj:
self.k_proj = nn.Linear( self.k_proj = nn.Linear(
@ -107,7 +106,6 @@ class Attention(nn.Module):
B, L, D = x.shape B, L, D = x.shape
queries = self.q_proj(x) queries = self.q_proj(x)
if kv_states is None: if kv_states is None:
keys, values = self.k_proj(x), self.v_proj(x) keys, values = self.k_proj(x), self.v_proj(x)
kv_states = keys, values kv_states = keys, values
@ -198,7 +196,10 @@ class DecoderLayer(nn.Module):
super().__init__() super().__init__()
self.hidden_size = args.hidden_size self.hidden_size = args.hidden_size
self.self_attn = Attention(kv_proj, args) self.self_attn = Attention(kv_proj, args)
self.mlp = MoeBlock(args) if args.num_experts == 1:
self.mlp = MLP(args.hidden_size, args.intermediate_size)
else:
self.mlp = MoeBlock(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm( self.post_attention_layernorm = nn.RMSNorm(
@ -231,7 +232,10 @@ class HunYuanModel(nn.Module):
assert self.vocab_size > 0 assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ self.layers = [
DecoderLayer(args=args, kv_proj=(i % args.cla_share_factor) == 0) DecoderLayer(
args=args,
kv_proj=(not args.use_cla) or (i % args.cla_share_factor) == 0,
)
for i in range(args.num_hidden_layers) for i in range(args.num_hidden_layers)
] ]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
@ -251,7 +255,7 @@ class HunYuanModel(nn.Module):
cache = [None] * len(self.layers) cache = [None] * len(self.layers)
for i, (layer, c) in enumerate(zip(self.layers, cache)): for i, (layer, c) in enumerate(zip(self.layers, cache)):
if i % self.args.cla_share_factor == 0: if (not self.args.use_cla) or i % self.args.cla_share_factor == 0:
shared_kv_states = None shared_kv_states = None
h, shared_kv_states = layer(h, mask, c, shared_kv_states) h, shared_kv_states = layer(h, mask, c, shared_kv_states)
@ -275,6 +279,29 @@ class Model(nn.Module):
return self.model.embed_tokens.as_linear(out) return self.model.embed_tokens.as_linear(out)
def sanitize(self, weights): def sanitize(self, weights):
if "model.layers.0.mlp.gate_and_up_proj.weight" in weights:
new_weights = {}
D = self.args.hidden_size
n_kv_heads = self.args.num_key_value_heads
n_kv_groups = self.args.num_attention_heads // n_kv_heads
head_dim = D // self.args.num_attention_heads
for k, v in weights.items():
if "qkv_proj" in k:
v = v.reshape(n_kv_heads, n_kv_groups + 2, head_dim, -1)
splits = v.split([n_kv_groups, n_kv_groups + 1], axis=1)
for k_up, v_new in zip(["q_proj", "k_proj", "v_proj"], splits):
k_new = k.replace("qkv_proj", k_up)
new_weights[k_new] = mx.flatten(v_new, 0, 2)
elif "gate_and_up_proj" in k:
splits = v.split(2, axis=0)
for k_up, v_new in zip(["up_proj", "gate_proj"], splits):
k_new = k.replace("gate_and_up_proj", k_up)
new_weights[k_new] = v_new
else:
new_weights[k] = v
weights = new_weights
if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
return weights return weights
for l in range(self.args.num_hidden_layers): for l in range(self.args.num_hidden_layers):

View File

@@ -1,4 +1,4 @@
-# Copyright © 2024 Apple Inc.
+# Copyright © 2024-2025 Apple Inc.

 import math
 from dataclasses import dataclass
@ -123,17 +123,16 @@ class MambaBlock(nn.Module):
self.intermediate_size, self.hidden_size, bias=args.use_bias self.intermediate_size, self.hidden_size, bias=args.use_bias
) )
def ssm_step(self, x, state=None): def ssm_step(self, x, A, state=None):
A = -mx.exp(self.A_log)
D = self.D D = self.D
deltaBC = self.x_proj(x) deltaBC = self.x_proj(x)
delta, B, C = mx.split( delta, B, C = map(
deltaBC, self.mixer_norm if self.use_bcdt_rms else lambda x: x,
indices_or_sections=[ mx.split(
self.time_step_rank, deltaBC,
self.time_step_rank + self.ssm_state_size, [self.time_step_rank, self.time_step_rank + self.ssm_state_size],
], axis=-1,
axis=-1, ),
) )
if self.use_bcdt_rms: if self.use_bcdt_rms:
delta, B, C = map(self.mixer_norm, (delta, B, C)) delta, B, C = map(self.mixer_norm, (delta, B, C))
@ -145,25 +144,40 @@ class MambaBlock(nn.Module):
y = y + D * x y = y + D * x
return y, new_state return y, new_state
def __call__(self, x, cache): def _process_sequence(self, x, conv_cache, state_cache):
B, T, D = x.shape B, T, D = x.shape
if cache is None: xz = self.in_proj(x)
cache = [None, None] x, z = xz.split(indices_or_sections=2, axis=-1)
conv_out, new_conv_cache = self.conv1d(x, conv_cache)
x = nn.silu(conv_out)
A = -mx.exp(self.A_log)
outputs = [] outputs = []
current_state = state_cache
y = []
for t in range(T): for t in range(T):
xt = x[:, t, :] y_t, current_state = self.ssm_step(x[:, t], A, current_state)
xz = self.in_proj(xt) y.append(y_t)
x_t, z_t = xz.split(indices_or_sections=2, axis=1) y = mx.stack(y, axis=1)
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0]) z = self.out_proj(nn.silu(z) * y)
x_t = conv_out.squeeze(1) return z, (new_conv_cache, current_state)
x_t = nn.silu(x_t)
y_t, cache[1] = self.ssm_step(x_t, cache[1]) def __call__(self, x, cache):
z_t = nn.silu(z_t) if cache is None:
output_t = y_t * z_t conv_cache, state_cache = None, None
output_t = self.out_proj(output_t) else:
outputs.append(output_t) conv_cache, state_cache = cache[0], cache[1]
output = mx.stack(outputs, axis=1)
output, (new_conv_cache, new_state_cache) = self._process_sequence(
x, conv_cache, state_cache
)
if isinstance(cache, MambaCache):
cache[0] = new_conv_cache
cache[1] = new_state_cache
return output return output

View File

@@ -1,4 +1,4 @@
-# Copyright © 2023-2024 Apple Inc.
+# Copyright © 2023-2025 Apple Inc.

 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union

View File

@@ -1,5 +1,6 @@
 import json
 from functools import partial
+from typing import List

 from transformers import AutoTokenizer

@@ -368,3 +369,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None):
         detokenizer_class,
         eos_token_ids=eos_token_ids,
     )
+
+
+def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List:
+    removed_bos = sequence if sequence[0] != bos else sequence[1:]
+    return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos

View File

@ -1,6 +1,8 @@
import itertools
import json import json
import types
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional from typing import Any, Dict, List, Optional
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
@ -34,14 +36,24 @@ class ChatDataset:
https://platform.openai.com/docs/guides/fine-tuning/example-format https://platform.openai.com/docs/guides/fine-tuning/example-format
""" """
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer): def __init__(
self._data = [ self,
tokenizer.apply_chat_template( data: List[Dict[str, str]],
d["messages"], tokenizer: PreTrainedTokenizer,
tools=d.get("tools", None), chat_key: str = "messages",
) mask_prompt: bool = False,
for d in data ):
] self._data = []
for d in data:
messages = d[chat_key]
tools = d.get("tools", None)
tokens = tokenizer.apply_chat_template(messages, tools=tools)
if mask_prompt:
messages = messages[:-1]
offset = len(tokenizer.apply_chat_template(messages, tools=tools))
self._data.append((tokens, offset))
else:
self._data.append(tokens)
def __getitem__(self, idx: int): def __getitem__(self, idx: int):
return self._data[idx] return self._data[idx]
@ -63,16 +75,36 @@ class CompletionsDataset:
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
prompt_key: str, prompt_key: str,
completion_key: str, completion_key: str,
mask_prompt: bool,
): ):
self._data = [ self._data = []
tokenizer.apply_chat_template( for d in data:
tokens = tokenizer.apply_chat_template(
[ [
{"role": "user", "content": d[prompt_key]}, {"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[completion_key]}, {"role": "assistant", "content": d[completion_key]},
], ],
) )
for d in data if mask_prompt:
] offset = len(
tokenizer.apply_chat_template(
[{"role": "user", "content": d[prompt_key]}]
)
)
self._data.append((tokens, offset))
else:
self._data.append(tokens)
def __getitem__(self, idx: int):
return self._data[idx]
def __len__(self):
return len(self._data)
class ConcatenatedDataset:
def __init__(self, data: List[Any]):
self._data = list(itertools.chain(*data))
def __getitem__(self, idx: int): def __getitem__(self, idx: int):
return self._data[idx] return self._data[idx]
@ -84,18 +116,26 @@ class CompletionsDataset:
def create_dataset( def create_dataset(
data, data,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None, config,
completion_feature: Optional[str] = None,
): ):
prompt_feature = prompt_feature or "prompt" mask_prompt = getattr(config, "mask_prompt", False)
completion_feature = completion_feature or "completion" prompt_feature = getattr(config, "prompt_feature", "prompt")
text_feature = getattr(config, "text_feature", "text")
completion_feature = getattr(config, "completion_feature", "completion")
chat_feature = getattr(config, "chat_feature", "messages")
sample = data[0] sample = data[0]
if "messages" in sample: if prompt_feature in sample and completion_feature in sample:
return ChatDataset(data, tokenizer) return CompletionsDataset(
elif prompt_feature in sample and completion_feature in sample: data, tokenizer, prompt_feature, completion_feature, mask_prompt
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature) )
elif "text" in sample: elif chat_feature in sample:
return Dataset(data, tokenizer) return ChatDataset(
data, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt
)
elif text_feature in sample:
if mask_prompt:
raise ValueError("Prompt masking not supported for text dataset.")
return Dataset(data, tokenizer, text_key=text_feature)
else: else:
raise ValueError( raise ValueError(
"Unsupported data format, check the supported formats here:\n" "Unsupported data format, check the supported formats here:\n"
@ -106,15 +146,14 @@ def create_dataset(
def load_local_dataset( def load_local_dataset(
data_path: Path, data_path: Path,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None, config,
completion_feature: Optional[str] = None,
): ):
def load_subset(path): def load_subset(path):
if not path.exists(): if not path.exists():
return [] return []
with open(path, "r") as fid: with open(path, "r") as fid:
data = [json.loads(l) for l in fid] data = [json.loads(l) for l in fid]
return create_dataset(data, tokenizer, prompt_feature, completion_feature) return create_dataset(data, tokenizer, config)
names = ("train", "valid", "test") names = ("train", "valid", "test")
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names] train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
@ -124,8 +163,7 @@ def load_local_dataset(
def load_hf_dataset( def load_hf_dataset(
data_id: str, data_id: str,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
prompt_feature: Optional[str] = None, config,
completion_feature: Optional[str] = None,
): ):
from datasets import exceptions, load_dataset from datasets import exceptions, load_dataset
@ -136,9 +174,7 @@ def load_hf_dataset(
train, valid, test = [ train, valid, test = [
( (
create_dataset( create_dataset(dataset[n], tokenizer, config)
dataset[n], tokenizer, prompt_feature, completion_feature
)
if n in dataset.keys() if n in dataset.keys()
else [] else []
) )
@ -154,42 +190,61 @@ def load_hf_dataset(
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
import datasets import datasets
hf_args = args.hf_dataset def create_hf_dataset(dataset_name, config, split, hf_config):
dataset_name = hf_args["name"]
print(f"Loading Hugging Face dataset {dataset_name}.")
text_feature = hf_args.get("text_feature")
prompt_feature = hf_args.get("prompt_feature")
completion_feature = hf_args.get("completion_feature")
def create_hf_dataset(split: str = None):
ds = datasets.load_dataset( ds = datasets.load_dataset(
dataset_name, dataset_name,
split=split, split=split,
**hf_args.get("config", {}), **hf_config,
) )
if prompt_feature and completion_feature: return create_dataset(ds, tokenizer, config)
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
elif text_feature: dataset_collection = args.hf_dataset
return Dataset(ds, tokenizer, text_key=text_feature) if isinstance(dataset_collection, dict):
else: dataset_collection = [dataset_collection]
raise ValueError(
"Specify either a prompt and completion feature or a text " collection = []
"feature for the Hugging Face dataset." for ds in dataset_collection:
ds_name = ds["name"]
print(f"Loading Hugging Face dataset {ds_name}.")
ds["mask_prompt"] = getattr(args, "mask_prompt", False)
config = types.SimpleNamespace(**ds)
hf_config = ds.get("config", {})
if args.train:
train_split = ds.get("train_split", "train[:80%]")
valid_split = ds.get("valid_split", "train[-10%:]")
train = create_hf_dataset(
ds_name,
config,
train_split,
hf_config,
) )
valid = create_hf_dataset(
ds_name,
config,
valid_split,
hf_config,
)
else:
train, valid = [], []
if args.train: if args.test:
train_split = hf_args.get("train_split", "train[:80%]") test_split = ds.get("test_split")
valid_split = hf_args.get("valid_split", "train[-10%:]") test = create_hf_dataset(
train = create_hf_dataset(split=train_split) ds_name,
valid = create_hf_dataset(split=valid_split) config,
else: test_split,
train, valid = [], [] hf_config,
if args.test: )
test = create_hf_dataset(split=hf_args.get("test_split")) else:
else: test = []
test = []
return train, valid, test collection.append((train, valid, test))
if len(collection) == 1:
return collection[0]
# Otherwise concatenate them
return tuple(map(ConcatenatedDataset, zip(*collection)))
def load_dataset(args, tokenizer: PreTrainedTokenizer): def load_dataset(args, tokenizer: PreTrainedTokenizer):
@ -197,18 +252,11 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
train, valid, test = load_custom_hf_dataset(args, tokenizer) train, valid, test = load_custom_hf_dataset(args, tokenizer)
else: else:
data_path = Path(args.data) data_path = Path(args.data)
prompt_feature = getattr(args, "prompt_feature", None)
completion_feature = getattr(args, "completion_feature", None)
if data_path.exists(): if data_path.exists():
train, valid, test = load_local_dataset( train, valid, test = load_local_dataset(data_path, tokenizer, args)
data_path, tokenizer, prompt_feature, completion_feature
)
else: else:
print(f"Loading Hugging Face dataset {args.data}.") print(f"Loading Hugging Face dataset {args.data}.")
train, valid, test = load_hf_dataset( train, valid, test = load_hf_dataset(args.data, tokenizer, args)
args.data, tokenizer, prompt_feature, completion_feature
)
if args.train and len(train) == 0: if args.train and len(train) == 0:
raise ValueError( raise ValueError(

View File

@ -5,13 +5,16 @@ import shutil
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Union from typing import List, Optional, Tuple
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
import numpy as np import numpy as np
from mlx.nn.utils import average_gradients from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizer
from .datasets import CompletionsDataset
def grad_checkpoint(layer): def grad_checkpoint(layer):
@ -63,20 +66,30 @@ class TrainingArgs:
) )
def default_loss(model, inputs, targets, lengths): def default_loss(model, batch, lengths):
inputs = batch[:, :-1]
targets = batch[:, 1:]
logits = model(inputs) logits = model(inputs)
logits = logits.astype(mx.float32) logits = logits.astype(mx.float32)
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] steps = mx.arange(1, targets.shape[1] + 1)
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
ce = nn.losses.cross_entropy(logits, targets) * length_mask ce = nn.losses.cross_entropy(logits, targets) * mask
ntoks = length_mask.sum() ntoks = mask.sum()
ce = ce.sum() / ntoks ce = ce.sum() / ntoks
return ce, ntoks return ce, ntoks
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): def iterate_batches(
dataset,
tokenizer,
batch_size,
max_seq_length,
train=False,
):
# Sort by length: # Sort by length:
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx])) idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
if len(dataset) < batch_size: if len(dataset) < batch_size:
@ -101,6 +114,10 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
indices = np.random.permutation(len(batch_idx)) indices = np.random.permutation(len(batch_idx))
for i in indices: for i in indices:
batch = [dataset[j] for j in batch_idx[i]] batch = [dataset[j] for j in batch_idx[i]]
if len(batch[0]) == 2:
batch, offsets = zip(*batch)
else:
offsets = [0] * len(batch)
lengths = [len(x) for x in batch] lengths = [len(x) for x in batch]
if max(lengths) > max_seq_length: if max(lengths) > max_seq_length:
print( print(
@ -123,8 +140,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
truncated_length # Update lengths to match truncated lengths truncated_length # Update lengths to match truncated lengths
) )
batch = mx.array(batch_arr) batch = mx.array(batch_arr)
yield batch, mx.array(list(zip(offsets, lengths)))
yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
if not train: if not train:
break break
@ -140,8 +156,8 @@ def evaluate(
loss: callable = default_loss, loss: callable = default_loss,
iterate_batches: callable = iterate_batches, iterate_batches: callable = iterate_batches,
): ):
all_losses = 0 all_losses = mx.array(0.0)
ntokens = 0 ntokens = mx.array(0)
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
@ -217,8 +233,8 @@ def train(
n_tokens = 0 n_tokens = 0
steps = 0 steps = 0
trained_tokens = 0 trained_tokens = 0
train_time = 0
# Main training loop # Main training loop
start = time.perf_counter()
for it, batch in zip( for it, batch in zip(
range(1, args.iters + 1), range(1, args.iters + 1),
iterate_batches( iterate_batches(
@ -229,10 +245,11 @@ def train(
train=True, train=True,
), ),
): ):
tic = time.perf_counter()
# Report validation loss if needed, the first validation loss # Report validation loss if needed, the first validation loss
# is always measured before any training. # is always measured before any training.
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters: if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
stop = time.perf_counter() tic = time.perf_counter()
val_loss = evaluate( val_loss = evaluate(
model=model, model=model,
dataset=val_dataset, dataset=val_dataset,
@ -243,7 +260,7 @@ def train(
max_seq_length=args.max_seq_length, max_seq_length=args.max_seq_length,
iterate_batches=iterate_batches, iterate_batches=iterate_batches,
) )
val_time = time.perf_counter() - stop val_time = time.perf_counter() - tic
if rank == 0: if rank == 0:
print( print(
f"Iter {it}: " f"Iter {it}: "
@ -260,24 +277,23 @@ def train(
} }
training_callback.on_val_loss_report(val_info) training_callback.on_val_loss_report(val_info)
start = time.perf_counter() tic = time.perf_counter()
lvalue, toks = step(batch) lvalue, toks = step(batch)
losses += lvalue losses += lvalue
n_tokens += toks n_tokens += toks
steps += 1 steps += 1
mx.eval(state, losses, n_tokens) mx.eval(state, losses, n_tokens)
train_time += time.perf_counter() - tic
# Report training loss if needed # Report training loss if needed
if it % args.steps_per_report == 0 or it == args.iters: if it % args.steps_per_report == 0 or it == args.iters:
stop = time.perf_counter()
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item() train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
train_loss /= steps * mx.distributed.init().size() train_loss /= steps * mx.distributed.init().size()
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item() n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
learning_rate = optimizer.learning_rate.item() learning_rate = optimizer.learning_rate.item()
it_sec = args.steps_per_report / (stop - start) it_sec = args.steps_per_report / train_time
tokens_sec = float(n_tokens) / (stop - start) tokens_sec = float(n_tokens) / train_time
trained_tokens += n_tokens trained_tokens += n_tokens
peak_mem = mx.metal.get_peak_memory() / 1e9 peak_mem = mx.metal.get_peak_memory() / 1e9
if rank == 0: if rank == 0:
@ -306,7 +322,7 @@ def train(
losses = 0 losses = 0
n_tokens = 0 n_tokens = 0
steps = 0 steps = 0
start = time.perf_counter() train_time = 0
# Save adapter weights # Save adapter weights
if it % args.steps_per_save == 0: if it % args.steps_per_save == 0:

View File

@@ -89,11 +89,13 @@ def linear_to_lora_layers(
         "mixtral",
         "nemotron",
         "stablelm",
+        "hunyuan",
         "qwen2",
         "qwen2_moe",
         "phimoe",
         "gemma",
         "gemma2",
+        "granite",
         "helium",
         "starcoder2",
         "cohere",

View File

@@ -13,7 +13,18 @@ import time
 from dataclasses import dataclass
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    NamedTuple,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)

 import mlx.core as mx
 import mlx.nn as nn
@@ -65,6 +76,7 @@ class GenerationResponse:
     Args:
         text (str): The next segment of decoded text. This can be an empty string.
         token (int): The next token.
+        from_draft (bool): Whether the token was generated by the draft model.
         logprobs (mx.array): A vector of log probabilities.
         prompt_tokens (int): The number of tokens in the prompt.
         prompt_tps (float): The prompt processing tokens-per-second.
@@ -77,6 +89,7 @@ class GenerationResponse:
     text: str
     token: int
     logprobs: mx.array
+    from_draft: bool
     prompt_tokens: int
     prompt_tps: float
     generation_tokens: int
@ -338,7 +351,7 @@ def speculative_generate_step(
kv_bits: Optional[int] = None, kv_bits: Optional[int] = None,
kv_group_size: int = 64, kv_group_size: int = 64,
quantized_kv_start: int = 0, quantized_kv_start: int = 0,
) -> Generator[Tuple[mx.array, mx.array], None, None]: ) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
""" """
A generator producing token ids based on the given prompt from the model. A generator producing token ids based on the given prompt from the model.
@ -365,7 +378,8 @@ def speculative_generate_step(
when ``kv_bits`` is non-None. Default: ``0``. when ``kv_bits`` is non-None. Default: ``0``.
Yields: Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities. Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities,
and a bool indicating if the token was generated by the draft model
""" """
y = prompt y = prompt
@ -450,12 +464,12 @@ def speculative_generate_step(
break break
n += 1 n += 1
ntoks += 1 ntoks += 1
yield tn, lpn yield tn, lpn, True
if ntoks == max_tokens: if ntoks == max_tokens:
break break
if ntoks < max_tokens: if ntoks < max_tokens:
ntoks += 1 ntoks += 1
yield tokens[n], logprobs[n] yield tokens[n], logprobs[n], False
if ntoks == max_tokens: if ntoks == max_tokens:
break break
@ -463,7 +477,7 @@ def speculative_generate_step(
y = mx.array([tokens[n]], mx.uint32) y = mx.array([tokens[n]], mx.uint32)
draft_y = y draft_y = y
# If we accpeted all the draft tokens, include the last # If we accepted all the draft tokens, include the last
# draft token in the next draft step since it hasn't been # draft token in the next draft step since it hasn't been
# processed yet by the draft model # processed yet by the draft model
if n == num_draft: if n == num_draft:
@ -518,6 +532,10 @@ def stream_generate(
if draft_model is None: if draft_model is None:
kwargs.pop("num_draft_tokens", None) kwargs.pop("num_draft_tokens", None)
token_generator = generate_step(prompt, model, **kwargs) token_generator = generate_step(prompt, model, **kwargs)
# from_draft always false for non-speculative generation
token_generator = (
(token, logprobs, False) for token, logprobs in token_generator
)
else: else:
kwargs.pop("max_kv_size", None) kwargs.pop("max_kv_size", None)
token_generator = speculative_generate_step( token_generator = speculative_generate_step(
@ -526,7 +544,7 @@ def stream_generate(
with wired_limit(model, [generation_stream]): with wired_limit(model, [generation_stream]):
detokenizer.reset() detokenizer.reset()
tic = time.perf_counter() tic = time.perf_counter()
for n, (token, logprobs) in enumerate(token_generator): for n, (token, logprobs, from_draft) in enumerate(token_generator):
if n == 0: if n == 0:
prompt_time = time.perf_counter() - tic prompt_time = time.perf_counter() - tic
prompt_tps = prompt.size / prompt_time prompt_tps = prompt.size / prompt_time
@ -540,6 +558,7 @@ def stream_generate(
text=detokenizer.last_segment, text=detokenizer.last_segment,
token=token, token=token,
logprobs=logprobs, logprobs=logprobs,
from_draft=from_draft,
prompt_tokens=prompt.size, prompt_tokens=prompt.size,
prompt_tps=prompt_tps, prompt_tps=prompt_tps,
generation_tokens=n + 1, generation_tokens=n + 1,
@ -553,6 +572,7 @@ def stream_generate(
text=detokenizer.last_segment, text=detokenizer.last_segment,
token=token, token=token,
logprobs=logprobs, logprobs=logprobs,
from_draft=from_draft,
prompt_tokens=prompt.size, prompt_tokens=prompt.size,
prompt_tps=prompt_tps, prompt_tps=prompt_tps,
generation_tokens=n + 1, generation_tokens=n + 1,
@ -627,6 +647,7 @@ def load_config(model_path: Path) -> dict:
def load_model( def load_model(
model_path: Path, model_path: Path,
lazy: bool = False, lazy: bool = False,
strict: bool = True,
model_config: dict = {}, model_config: dict = {},
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes, get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
) -> nn.Module: ) -> nn.Module:
@ -638,6 +659,8 @@ def load_model(
lazy (bool): If False eval the model parameters to make sure they are lazy (bool): If False eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False`` when needed. Default: ``False``
strict (bool): Whether or not to raise an exception if weights don't
match. Default: ``True``
model_config (dict, optional): Optional configuration parameters for the model_config (dict, optional): Optional configuration parameters for the
model. Defaults to an empty dictionary. model. Defaults to an empty dictionary.
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional): get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
@ -660,7 +683,7 @@ def load_model(
# Try weight for back-compat # Try weight for back-compat
weight_files = glob.glob(str(model_path / "weight*.safetensors")) weight_files = glob.glob(str(model_path / "weight*.safetensors"))
if not weight_files: if not weight_files and strict:
logging.error(f"No safetensors found in {model_path}") logging.error(f"No safetensors found in {model_path}")
raise FileNotFoundError(f"No safetensors found in {model_path}") raise FileNotFoundError(f"No safetensors found in {model_path}")
@ -694,7 +717,7 @@ def load_model(
class_predicate=class_predicate, class_predicate=class_predicate,
) )
model.load_weights(list(weights.items())) model.load_weights(list(weights.items()), strict=strict)
if not lazy: if not lazy:
mx.eval(model.parameters()) mx.eval(model.parameters())

View File

@@ -78,14 +78,15 @@ class TestDatasets(unittest.TestCase):
         self.assertTrue(isinstance(train, datasets.ChatDataset))

     def test_hf(self):
+        hf_args = {
+            "name": "billsum",
+            "prompt_feature": "text",
+            "completion_feature": "summary",
+            "train_split": "train[:2%]",
+            "valid_split": "train[-2%:]",
+        }
         args = types.SimpleNamespace(
-            hf_dataset={
-                "name": "billsum",
-                "prompt_feature": "text",
-                "completion_feature": "summary",
-                "train_split": "train[:2%]",
-                "valid_split": "train[-2%:]",
-            },
+            hf_dataset=hf_args,
             test=False,
             train=True,
         )
@@ -97,6 +98,16 @@ class TestDatasets(unittest.TestCase):
         self.assertTrue(len(valid[0]) > 0)
         self.assertEqual(len(test), 0)

+        args = types.SimpleNamespace(
+            hf_dataset=[hf_args, hf_args],
+            test=False,
+            train=True,
+        )
+        train_double, valid_double, test_double = datasets.load_dataset(args, tokenizer)
+        self.assertEqual(2 * len(train), len(train_double))
+        self.assertEqual(2 * len(valid), len(valid_double))
+        self.assertEqual(2 * len(test), len(test_double))
+

 if __name__ == "__main__":
     unittest.main()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@@ -1,17 +1,24 @@
 # Copyright © 2024 Apple Inc.

 import unittest
+from typing import List

 from mlx_lm.sample_utils import make_logits_processors
-from mlx_lm.utils import generate, load
+from mlx_lm.utils import (
+    GenerationResponse,
+    generate,
+    load,
+    make_sampler,
+    stream_generate,
+)


 class TestGenerate(unittest.TestCase):

     @classmethod
     def setUpClass(cls):
-        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
-        cls.model, cls.tokenizer = load(HF_MODEL_PATH)
+        cls.HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+        cls.model, cls.tokenizer = load(cls.HF_MODEL_PATH)

     def test_generate(self):
         # Simple test that generation runs
@ -51,6 +58,34 @@ class TestGenerate(unittest.TestCase):
) )
self.assertEqual(len(all_toks), len(init_toks) + 5) self.assertEqual(len(all_toks), len(init_toks) + 5)
def test_stream_generate_speculative(self):
# Use same model as draft model, this is not a speed test
draft_model, _ = load(self.HF_MODEL_PATH)
results: List[GenerationResponse] = []
drafted: List[bool] = []
# make a determinate sampler
sampler = make_sampler(temp=0.0)
for generation_result in stream_generate(
model=self.model,
tokenizer=self.tokenizer,
prompt="hello",
max_tokens=5,
draft_model=draft_model,
num_draft_tokens=2,
sampler=sampler,
):
drafted.append(generation_result.from_draft)
results.append(generation_result)
self.assertEqual(len(results), 5)
# since num_draft_tokens is 2 and draft model is the same, the
# first 2 generations should be drafts, the third should come
# from the target model, and last two should be drafts
self.assertEqual(drafted, [True, True, False, True, True])
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@@ -134,9 +134,7 @@ def find_alignment(
     logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :])
     # consider only the logits associated with predicting text
     sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot]
-    token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(
-        sampled_logits.dtype
-    )
+    token_probs = mx.softmax(sampled_logits, precise=True, axis=-1)
     text_token_probs = mx.take_along_axis(
         token_probs, mx.array(text_tokens)[:, None], axis=1
     ).squeeze(1)
@@ -144,10 +142,11 @@ def find_alignment(

     # heads * tokens * frames
     weights = mx.stack(
-        [cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]
+        [cross_qk[_l][0, _h] for _l, _h in model.alignment_heads.tolist()]
     )
     weights = weights[:, :, : num_frames // 2]
-    weights = mx.softmax(weights * qk_scale, axis=-1)
+    weights = mx.softmax(weights * qk_scale, axis=-1, precise=True)
+    weights = weights.astype(mx.float32)
     mean = mx.mean(weights, axis=-2, keepdims=True)
     std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt()
     weights = (weights - mean) / std

View File

@@ -195,6 +195,8 @@ def transcribe(
         seek_points.append(0)
     if len(seek_points) % 2 == 1:
         seek_points.append(content_frames)
+    else:
+        seek_points[-1] = min(content_frames, seek_points[-1])
     seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))

     punctuation = "\"'“¿([{-\"'.。,!?::”)]}、"

View File

@@ -84,7 +84,7 @@ class MultiHeadAttention(nn.Module):
         w = mx.softmax(qk, axis=-1, precise=True)
         out = (w @ v).transpose(0, 2, 1, 3)
         out = out.reshape(n_batch, n_ctx, n_state)
-        return out, qk.astype(mx.float32)
+        return out, qk


 class ResidualAttentionBlock(nn.Module):