diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 012a96ad..7a1bf9bf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,10 +1,10 @@ repos: - repo: https://github.com/psf/black-pre-commit-mirror - rev: 24.8.0 + rev: 25.1.0 hooks: - id: black - repo: https://github.com/pycqa/isort - rev: 5.13.2 + rev: 6.0.0 hooks: - id: isort args: diff --git a/README.md b/README.md index 88888ad0..e47bd598 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ Some more useful examples are listed below. ### Hugging Face -Note: You can now directly download a few converted checkpoints from the [MLX +You can directly use or download converted checkpoints from the [MLX Community](https://huggingface.co/mlx-community) organization on Hugging Face. We encourage you to join the community and [contribute new models](https://github.com/ml-explore/mlx-examples/issues/155). diff --git a/llms/README.md b/llms/README.md index e943ed69..4f7451c1 100644 --- a/llms/README.md +++ b/llms/README.md @@ -164,7 +164,7 @@ mlx_lm.convert \ ``` Models can also be converted and quantized directly in the -[mlx-my-repo]https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging +[mlx-my-repo](https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging Face Space. ### Long Prompts and Generations diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index d90a64f4..15b21b10 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -101,6 +101,14 @@ You can specify the output location with `--adapter-path`. You can resume fine-tuning with an existing adapter with `--resume-adapter-file `. +#### Prompt Masking + +The default training computes a loss for every token in the sample. You can +ignore the prompt and compute loss for just the completion by passing +`--mask-prompt`. Note this is only supported for `chat` and `completion` +datasets. For `chat` datasets the final message in the message list is +considered the completion. See the [dataset section](#Data) for more details. + ### Evaluate To compute test set perplexity use: @@ -315,11 +323,27 @@ hf_dataset: - Use `prompt_feature` and `completion_feature` to specify keys for a `completions` dataset. Use `text_feature` to specify the key for a `text` - dataset. + dataset. Use `chat_feature` to specify the key for a chat dataset. - To specify the train, valid, or test splits, set the corresponding `{train,valid,test}_split` argument. +You can specify a list of Hugging Face datasets with a list of records each +with the same structure as above. For example: + +```yaml +hf_dataset: + - name: "Open-Orca/OpenOrca" + train_split: "train[:90%]" + valid_split: "train[-10%:]" + prompt_feature: "question" + completion_feature: "response" + - name: "trl-lib/ultrafeedback_binarized" + train_split: "train[:90%]" + valid_split: "train[-10%:]" + chat_feature: "chosen" +``` + - Arguments specified in `config` will be passed as keyword arguments to [`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset). 
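For reference, a minimal sketch of how the prompt offset introduced by `--mask-prompt` can translate into a loss mask; it mirrors the masked loss the trainer computes later in this diff. The function name and the `(prompt_offset, total_length)` packing are illustrative assumptions, not part of the documented CLI:

```python
# Sketch only: compute cross-entropy over completion tokens, assuming each
# sample carries (prompt_offset, total_length) alongside its token ids, as in
# this diff's trainer changes. `masked_completion_loss` is an illustrative name.
import mlx.core as mx
import mlx.nn as nn


def masked_completion_loss(model, batch, lengths):
    # batch: (B, L) token ids; lengths: (B, 2) with columns (prompt_offset, total_length)
    inputs, targets = batch[:, :-1], batch[:, 1:]
    logits = model(inputs).astype(mx.float32)

    # Target position i predicts the (i+1)-th token of the sample; keep it only
    # if it falls after the prompt and before any padding.
    steps = mx.arange(1, targets.shape[1] + 1)
    mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])

    ce = nn.losses.cross_entropy(logits, targets) * mask
    ntoks = mask.sum()
    return ce.sum() / ntoks, ntoks
```
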
diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py index c18f1bae..fff64f78 100644 --- a/llms/mlx_lm/cache_prompt.py +++ b/llms/mlx_lm/cache_prompt.py @@ -152,7 +152,7 @@ def main(): print("Saving...") metadata = {} metadata["model"] = args.model - metadata["chat_template"] = tokenizer.chat_template + metadata["chat_template"] = json.dumps(tokenizer.chat_template) metadata["tokenizer_config"] = json.dumps(tokenizer_config) save_prompt_cache(args.prompt_cache_file, cache, metadata) diff --git a/llms/mlx_lm/evaluate.py b/llms/mlx_lm/evaluate.py index ca5e83bb..2f35ade2 100644 --- a/llms/mlx_lm/evaluate.py +++ b/llms/mlx_lm/evaluate.py @@ -295,7 +295,9 @@ class MLXLM(LM): completions = [] for context, until in tqdm(zip(contexts, untils), total=len(contexts)): - context = self._tokenize(context) + context = self.tokenizer.encode( + context, add_special_tokens=not self.use_chat_template + ) max_tokens = min( self._max_tokens, self.tokenizer.model_max_length - len(context), diff --git a/llms/mlx_lm/examples/chat.py b/llms/mlx_lm/examples/chat.py index 4a7020f1..dcd90b67 100644 --- a/llms/mlx_lm/examples/chat.py +++ b/llms/mlx_lm/examples/chat.py @@ -23,7 +23,6 @@ response = generate( tokenizer, prompt=prompt, verbose=True, - temp=0.0, prompt_cache=prompt_cache, ) diff --git a/llms/mlx_lm/examples/pipeline_generate.py b/llms/mlx_lm/examples/pipeline_generate.py index 2970b986..1e4fb445 100644 --- a/llms/mlx_lm/examples/pipeline_generate.py +++ b/llms/mlx_lm/examples/pipeline_generate.py @@ -4,10 +4,11 @@ Run with: ``` -/path/to/mpirun \ - -np 2 \ +mlx.launch \ --hostfile /path/to/hosts.txt \ - python /path/to/pipeline_generate.py --prompt "hello world" + --backend mpi \ + /path/to/pipeline_generate.py \ + --prompt "hello world" ``` Make sure you can run MLX over MPI on two hosts. For more information see the @@ -17,62 +18,110 @@ https://ml-explore.github.io/mlx/build/html/usage/distributed.html). """ import argparse +import json +from pathlib import Path import mlx.core as mx +from huggingface_hub import snapshot_download +from mlx.utils import tree_flatten from mlx_lm import load, stream_generate - -parser = argparse.ArgumentParser(description="LLM pipelined inference example") -parser.add_argument( - "--model", - default="mlx-community/DeepSeek-R1-3bit", - help="HF repo or path to local model.", -) -parser.add_argument( - "--prompt", - "-p", - default="Write a quicksort in C++.", - help="Message to be processed by the model ('-' reads from stdin)", -) -parser.add_argument( - "--max-tokens", - "-m", - type=int, - default=256, - help="Maximum number of tokens to generate", -) -args = parser.parse_args() - -model, tokenizer = load(args.model, lazy=True) - -messages = [{"role": "user", "content": args.prompt}] -prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) - -group = mx.distributed.init() -rank = group.rank() -model.model.pipeline(group) -mx.eval(model.parameters()) - -# Synchronize processes before generation to avoid timeout if downloading -# model for the first time. 
-mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu)) +from mlx_lm.utils import load_model, load_tokenizer -def rprint(*args, **kwargs): - if rank == 0: - print(*args, **kwargs) +def download(repo: str, allow_patterns: list[str]) -> Path: + return Path( + snapshot_download( + repo, + allow_patterns=allow_patterns, + ) + ) -for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens): - rprint(response.text, end="", flush=True) +def shard_and_load(repo): + # Get model path with everything but weight safetensors + model_path = download( + args.model, + allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"], + ) -rprint() -rprint("=" * 10) -rprint( - f"Prompt: {response.prompt_tokens} tokens, " - f"{response.prompt_tps:.3f} tokens-per-sec" -) -rprint( - f"Generation: {response.generation_tokens} tokens, " - f"{response.generation_tps:.3f} tokens-per-sec" -) -rprint(f"Peak memory: {response.peak_memory:.3f} GB") + # Lazy load and shard model to figure out + # which weights we need + model, _ = load_model(model_path, lazy=True, strict=False) + + group = mx.distributed.init(backend="mpi") + rank = group.rank() + model.model.pipeline(group) + + # Figure out which files we need for the local shard + with open(model_path / "model.safetensors.index.json", "r") as fid: + weight_index = json.load(fid)["weight_map"] + + local_files = set() + for k, _ in tree_flatten(model.parameters()): + local_files.add(weight_index[k]) + + # Download weights for local shard + download(args.model, allow_patterns=local_files) + + # Load and shard the model, and load the weights + tokenizer = load_tokenizer(model_path) + model, _ = load_model(model_path, lazy=True, strict=False) + model.model.pipeline(group) + mx.eval(model.parameters()) + + # Synchronize processes before generation to avoid timeout if downloading + # model for the first time. 
+ mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu)) + return model, tokenizer + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="LLM pipelined inference example") + parser.add_argument( + "--model", + default="mlx-community/DeepSeek-R1-3bit", + help="HF repo or path to local model.", + ) + parser.add_argument( + "--prompt", + "-p", + default="Write a quicksort in C++.", + help="Message to be processed by the model ('-' reads from stdin)", + ) + parser.add_argument( + "--max-tokens", + "-m", + type=int, + default=256, + help="Maximum number of tokens to generate", + ) + args = parser.parse_args() + + group = mx.distributed.init(backend="mpi") + rank = group.rank() + + def rprint(*args, **kwargs): + if rank == 0: + print(*args, **kwargs) + + model, tokenizer = shard_and_load(args.model) + + messages = [{"role": "user", "content": args.prompt}] + prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) + + for response in stream_generate( + model, tokenizer, prompt, max_tokens=args.max_tokens + ): + rprint(response.text, end="", flush=True) + + rprint() + rprint("=" * 10) + rprint( + f"Prompt: {response.prompt_tokens} tokens, " + f"{response.prompt_tps:.3f} tokens-per-sec" + ) + rprint( + f"Generation: {response.generation_tokens} tokens, " + f"{response.generation_tps:.3f} tokens-per-sec" + ) + rprint(f"Peak memory: {response.peak_memory:.3f} GB") diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 0d286c75..d8f97e5e 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -93,6 +93,12 @@ def setup_arg_parser(): action="store_true", help="Use the default chat template", ) + parser.add_argument( + "--chat-template-config", + help="Additional config for `apply_chat_template`. 
Should be a dictionary of" + " string keys to values represented as a JSON decodable string.", + default=None, + ) parser.add_argument( "--verbose", type=str2bool, @@ -149,7 +155,6 @@ def setup_arg_parser(): def main(): parser = setup_arg_parser() args = parser.parse_args() - mx.random.seed(args.seed) # Load the prompt cache and metadata if a cache file is provided @@ -195,11 +200,15 @@ def main(): for eos_token in args.extra_eos_token: tokenizer.add_eos_token(eos_token) + template_kwargs = {} + if args.chat_template_config is not None: + template_kwargs = json.loads(args.chat_template_config) + if args.use_default_chat_template: if tokenizer.chat_template is None: tokenizer.chat_template = tokenizer.default_chat_template elif using_cache: - tokenizer.chat_template = metadata["chat_template"] + tokenizer.chat_template = json.loads(metadata["chat_template"]) prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t") prompt = sys.stdin.read() if prompt == "-" else prompt @@ -209,8 +218,12 @@ def main(): else: messages = [] messages.append({"role": "user", "content": prompt}) + prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + messages, + tokenize=False, + add_generation_prompt=True, + **template_kwargs, ) # Treat the prompt as a suffix assuming that the prefix is in the diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 43f508c3..abc5dfa9 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -94,6 +94,14 @@ def build_parser(): choices=["lora", "dora", "full"], help="Type of fine-tuning to perform: lora, dora, or full.", ) + + parser.add_argument( + "--mask-prompt", + action="store_true", + help="Mask the prompt in the loss when training", + default=False, + ) + parser.add_argument( "--num-layers", type=int, @@ -219,6 +227,7 @@ def train_model( build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate ) ) + # Train model train( model=model, diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index 9027da7e..7a5bdeb1 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -282,12 +282,12 @@ class MoEGate(nn.Module): if self.topk_method == "group_limited_greedy": bsz, seq_len = x.shape[:2] scores = scores.reshape(bsz, seq_len, self.n_group, -1) - group_scores = scores.max(axis=-1) + group_scores = scores.max(axis=-1, keepdims=True) k = self.n_group - self.topk_group - group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k] - batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2)) - seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2)) - scores[batch_idx, seq_idx, group_idx] = 0.0 + group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :] + scores = mx.put_along_axis( + scores, group_idx, mx.array(0.0, scores.dtype), axis=-2 + ) scores = scores.reshape(bsz, seq_len, -1) k = self.top_k @@ -364,8 +364,32 @@ class DeepseekV2Model(nn.Module): DeepseekV2DecoderLayer(config, idx) for idx in range(config.num_hidden_layers) ] + self.start_idx = 0 + self.end_idx = len(self.layers) + self.num_layers = self.end_idx + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pipeline_rank = 0 + self.pipeline_size = 1 + + def pipeline(self, group): + # Split layers in reverse so rank=0 gets the last layers and + # rank=pipeline_size-1 gets the first + self.pipeline_rank = group.rank() + self.pipeline_size = group.size() + layers_per_rank = len(self.layers) // self.pipeline_size + extra = len(self.layers) - 
layers_per_rank * self.pipeline_size + if self.pipeline_rank < extra: + layers_per_rank += 1 + + self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank + self.end_idx = self.start_idx + layers_per_rank + self.num_layers = layers_per_rank + self.layers = self.layers[: self.end_idx] + self.layers[: self.start_idx] = [None] * self.start_idx + self.num_layers = len(self.layers) - self.start_idx + def __call__( self, x: mx.array, @@ -374,14 +398,31 @@ class DeepseekV2Model(nn.Module): ) -> mx.array: h = self.embed_tokens(x) + pipeline_rank = self.pipeline_rank + pipeline_size = self.pipeline_size + # Hack to avoid time-outs during prompt-processing + dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu if mask is None: mask = create_attention_mask(h, cache) if cache is None: - cache = [None] * len(self.layers) + cache = [None] * self.num_layers - for layer, c in zip(self.layers, cache): - h = layer(h, mask, c) + # Receive from the previous process in the pipeline + if pipeline_rank < pipeline_size - 1: + h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream) + + for i in range(self.num_layers): + h = self.layers[self.start_idx + i](h, mask, cache[i]) + + # Send to the next process in the pipeline + if pipeline_rank != 0: + h = mx.distributed.send( + h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream + ) + + # Broadcast h while keeping it in the graph + h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]] return self.norm(h) @@ -418,4 +459,4 @@ class Model(nn.Module): @property def layers(self): - return self.model.layers + return self.model.layers[self.model.start_idx : self.model.end_idx] diff --git a/llms/mlx_lm/models/deepseek_v3.py b/llms/mlx_lm/models/deepseek_v3.py index 96ce4f85..47e17236 100644 --- a/llms/mlx_lm/models/deepseek_v3.py +++ b/llms/mlx_lm/models/deepseek_v3.py @@ -271,6 +271,38 @@ class DeepseekV3MLP(nn.Module): return down_proj +@mx.compile +def group_expert_select( + gates, + e_score_correction_bias, + top_k, + n_group, + topk_group, + routed_scaling_factor, + norm_topk_prob, +): + + k = top_k + scores = mx.sigmoid(gates.astype(mx.float32)) + scores = scores + e_score_correction_bias + scores = mx.unflatten(scores, axis=-1, shape=(n_group, -1)) + group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1, keepdims=True) + k = n_group - topk_group + group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :] + scores = mx.put_along_axis(scores, group_idx, mx.array(0.0), axis=-2) + scores = mx.flatten(scores, -2, -1) + + k = top_k + inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k] + scores = mx.take_along_axis(scores, inds, axis=-1) + if top_k > 1 and norm_topk_prob: + denominator = scores.sum(axis=-1, keepdims=True) + 1e-20 + scores = scores / denominator + scores = scores * routed_scaling_factor + + return inds, scores + + class MoEGate(nn.Module): def __init__(self, config: ModelArgs): super().__init__() @@ -279,38 +311,22 @@ class MoEGate(nn.Module): self.norm_topk_prob = config.norm_topk_prob self.n_routed_experts = config.n_routed_experts self.routed_scaling_factor = config.routed_scaling_factor - self.topk_method = config.topk_method self.n_group = config.n_group self.topk_group = config.topk_group self.weight = mx.zeros((self.n_routed_experts, config.hidden_size)) self.e_score_correction_bias = mx.zeros((self.n_routed_experts,)) + assert config.topk_method == "noaux_tc", "Unsupported topk method." 
def __call__(self, x): - gates = x @ self.weight.T - - scores = mx.sigmoid(gates.astype(mx.float32)) - - assert self.topk_method == "noaux_tc", "Unsupported topk method." - bsz, seq_len = x.shape[:2] - scores = scores + self.e_score_correction_bias - scores = scores.reshape(bsz, seq_len, self.n_group, -1) - group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1) - k = self.n_group - self.topk_group - group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k] - batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2)) - seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2)) - scores[batch_idx, seq_idx, group_idx] = 0.0 - scores = scores.reshape(bsz, seq_len, -1) - - k = self.top_k - inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k] - scores = mx.take_along_axis(scores, inds, axis=-1) - if self.top_k > 1 and self.norm_topk_prob: - denominator = scores.sum(axis=-1, keepdims=True) + 1e-20 - scores = scores / denominator - scores = scores * self.routed_scaling_factor - - return inds, scores + return group_expert_select( + x @ self.weight.T, + self.e_score_correction_bias, + self.top_k, + self.n_group, + self.topk_group, + self.routed_scaling_factor, + self.norm_topk_prob, + ) class DeepseekV3MoE(nn.Module): @@ -381,6 +397,10 @@ class DeepseekV3Model(nn.Module): DeepseekV3DecoderLayer(config, idx) for idx in range(config.num_hidden_layers) ] + self.start_idx = 0 + self.end_idx = len(self.layers) + self.num_layers = self.end_idx + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pipeline_rank = 0 self.pipeline_size = 1 @@ -390,11 +410,15 @@ class DeepseekV3Model(nn.Module): # rank=pipeline_size-1 gets the first self.pipeline_rank = group.rank() self.pipeline_size = group.size() - layers_per_rank = ( - len(self.layers) + self.pipeline_size - 1 - ) // self.pipeline_size - start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank - self.layers = self.layers[start : start + layers_per_rank] + layers_per_rank = len(self.layers) // self.pipeline_size + extra = len(self.layers) - layers_per_rank * self.pipeline_size + if self.pipeline_rank < extra: + layers_per_rank += 1 + self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank + self.end_idx = self.start_idx + layers_per_rank + self.layers = self.layers[: self.end_idx] + self.layers[: self.start_idx] = [None] * self.start_idx + self.num_layers = len(self.layers) - self.start_idx def __call__( self, @@ -412,15 +436,15 @@ class DeepseekV3Model(nn.Module): mask = create_attention_mask(h, cache) if cache is None: - cache = [None] * len(self.layers) + cache = [None] * self.num_layers # Receive from the previous process in the pipeline if pipeline_rank < pipeline_size - 1: h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream) - for layer, c in zip(self.layers, cache): - h = layer(h, mask, c) + for i in range(self.num_layers): + h = self.layers[self.start_idx + i](h, mask, cache[i]) # Send to the next process in the pipeline if pipeline_rank != 0: @@ -468,4 +492,4 @@ class Model(nn.Module): @property def layers(self): - return self.model.layers + return self.model.layers[self.model.start_idx : self.model.end_idx] diff --git a/llms/mlx_lm/models/granite.py b/llms/mlx_lm/models/granite.py new file mode 100644 index 00000000..43597d99 --- /dev/null +++ b/llms/mlx_lm/models/granite.py @@ -0,0 +1,195 @@ +# Copyright © 2023-2024 Apple Inc. 
+ +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .rope_utils import initialize_rope + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + vocab_size: int + logits_scaling: float + attention_multiplier: float + embedding_multiplier: float + residual_multiplier: float + max_position_embeddings: int + num_key_value_heads: int + attention_bias: bool + mlp_bias: bool + rope_theta: float + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = True + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + self.head_dim = head_dim = args.hidden_size // n_heads + + self.scale = args.attention_multiplier + attention_bias = args.attention_bias + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) + + self.rope = initialize_rope( + self.head_dim, + args.rope_theta, + False, + args.rope_scaling, + args.max_position_embeddings, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + hidden_dim = args.intermediate_size + if hasattr(args, "mlp_bias"): + mlp_bias = args.mlp_bias + else: + mlp_bias = False + + self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) + self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias) + self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args) + self.mlp = MLP(args) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.residual_multiplier = args.residual_multiplier + + 
def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r * self.residual_multiplier + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r * self.residual_multiplier + return out + + +class GraniteModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.embedding_multiplier = args.embedding_multiplier + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ): + h = self.embed_tokens(inputs) * self.embedding_multiplier + + if mask is None: + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, cache=c) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = GraniteModel(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + self.logits_scaling = args.logits_scaling + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ): + out = self.model(inputs, mask, cache) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out / self.logits_scaling + + @property + def layers(self): + return self.model.layers diff --git a/llms/mlx_lm/models/helium.py b/llms/mlx_lm/models/helium.py index 6ca46a72..ff551bca 100644 --- a/llms/mlx_lm/models/helium.py +++ b/llms/mlx_lm/models/helium.py @@ -1,3 +1,5 @@ +# Copyright © 2025 Apple Inc. 
+ from dataclasses import dataclass from typing import Any, Optional, Tuple diff --git a/llms/mlx_lm/models/hunyuan.py b/llms/mlx_lm/models/hunyuan.py index f9dc5652..122cebda 100644 --- a/llms/mlx_lm/models/hunyuan.py +++ b/llms/mlx_lm/models/hunyuan.py @@ -76,7 +76,6 @@ class Attention(nn.Module): head_dim = args.hidden_size // n_heads self.scale = head_dim**-0.5 - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias) if kv_proj: self.k_proj = nn.Linear( @@ -107,7 +106,6 @@ class Attention(nn.Module): B, L, D = x.shape queries = self.q_proj(x) - if kv_states is None: keys, values = self.k_proj(x), self.v_proj(x) kv_states = keys, values @@ -198,7 +196,10 @@ class DecoderLayer(nn.Module): super().__init__() self.hidden_size = args.hidden_size self.self_attn = Attention(kv_proj, args) - self.mlp = MoeBlock(args) + if args.num_experts == 1: + self.mlp = MLP(args.hidden_size, args.intermediate_size) + else: + self.mlp = MoeBlock(args) self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) self.post_attention_layernorm = nn.RMSNorm( @@ -231,7 +232,10 @@ class HunYuanModel(nn.Module): assert self.vocab_size > 0 self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) self.layers = [ - DecoderLayer(args=args, kv_proj=(i % args.cla_share_factor) == 0) + DecoderLayer( + args=args, + kv_proj=(not args.use_cla) or (i % args.cla_share_factor) == 0, + ) for i in range(args.num_hidden_layers) ] self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) @@ -251,7 +255,7 @@ class HunYuanModel(nn.Module): cache = [None] * len(self.layers) for i, (layer, c) in enumerate(zip(self.layers, cache)): - if i % self.args.cla_share_factor == 0: + if (not self.args.use_cla) or i % self.args.cla_share_factor == 0: shared_kv_states = None h, shared_kv_states = layer(h, mask, c, shared_kv_states) @@ -275,6 +279,29 @@ class Model(nn.Module): return self.model.embed_tokens.as_linear(out) def sanitize(self, weights): + + if "model.layers.0.mlp.gate_and_up_proj.weight" in weights: + new_weights = {} + D = self.args.hidden_size + n_kv_heads = self.args.num_key_value_heads + n_kv_groups = self.args.num_attention_heads // n_kv_heads + head_dim = D // self.args.num_attention_heads + for k, v in weights.items(): + if "qkv_proj" in k: + v = v.reshape(n_kv_heads, n_kv_groups + 2, head_dim, -1) + splits = v.split([n_kv_groups, n_kv_groups + 1], axis=1) + for k_up, v_new in zip(["q_proj", "k_proj", "v_proj"], splits): + k_new = k.replace("qkv_proj", k_up) + new_weights[k_new] = mx.flatten(v_new, 0, 2) + elif "gate_and_up_proj" in k: + splits = v.split(2, axis=0) + for k_up, v_new in zip(["up_proj", "gate_proj"], splits): + k_new = k.replace("gate_and_up_proj", k_up) + new_weights[k_new] = v_new + else: + new_weights[k] = v + weights = new_weights + if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: return weights for l in range(self.args.num_hidden_layers): diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index f2414660..93cc616e 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -1,4 +1,4 @@ -# Copyright © 2024 Apple Inc. +# Copyright © 2024-2025 Apple Inc. 
import math from dataclasses import dataclass @@ -123,17 +123,16 @@ class MambaBlock(nn.Module): self.intermediate_size, self.hidden_size, bias=args.use_bias ) - def ssm_step(self, x, state=None): - A = -mx.exp(self.A_log) + def ssm_step(self, x, A, state=None): D = self.D deltaBC = self.x_proj(x) - delta, B, C = mx.split( - deltaBC, - indices_or_sections=[ - self.time_step_rank, - self.time_step_rank + self.ssm_state_size, - ], - axis=-1, + delta, B, C = map( + self.mixer_norm if self.use_bcdt_rms else lambda x: x, + mx.split( + deltaBC, + [self.time_step_rank, self.time_step_rank + self.ssm_state_size], + axis=-1, + ), ) if self.use_bcdt_rms: delta, B, C = map(self.mixer_norm, (delta, B, C)) @@ -145,25 +144,40 @@ class MambaBlock(nn.Module): y = y + D * x return y, new_state - def __call__(self, x, cache): + def _process_sequence(self, x, conv_cache, state_cache): B, T, D = x.shape - if cache is None: - cache = [None, None] + xz = self.in_proj(x) + x, z = xz.split(indices_or_sections=2, axis=-1) + + conv_out, new_conv_cache = self.conv1d(x, conv_cache) + x = nn.silu(conv_out) + + A = -mx.exp(self.A_log) outputs = [] + current_state = state_cache + y = [] for t in range(T): - xt = x[:, t, :] - xz = self.in_proj(xt) - x_t, z_t = xz.split(indices_or_sections=2, axis=1) - conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0]) - x_t = conv_out.squeeze(1) - x_t = nn.silu(x_t) - y_t, cache[1] = self.ssm_step(x_t, cache[1]) - z_t = nn.silu(z_t) - output_t = y_t * z_t - output_t = self.out_proj(output_t) - outputs.append(output_t) - output = mx.stack(outputs, axis=1) + y_t, current_state = self.ssm_step(x[:, t], A, current_state) + y.append(y_t) + y = mx.stack(y, axis=1) + z = self.out_proj(nn.silu(z) * y) + return z, (new_conv_cache, current_state) + + def __call__(self, x, cache): + if cache is None: + conv_cache, state_cache = None, None + else: + conv_cache, state_cache = cache[0], cache[1] + + output, (new_conv_cache, new_state_cache) = self._process_sequence( + x, conv_cache, state_cache + ) + + if isinstance(cache, MambaCache): + cache[0] = new_conv_cache + cache[1] = new_state_cache + return output diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py index edddd583..7140c577 100644 --- a/llms/mlx_lm/models/minicpm.py +++ b/llms/mlx_lm/models/minicpm.py @@ -1,4 +1,4 @@ -# Copyright © 2023-2024 Apple Inc. +# Copyright © 2023-2025 Apple Inc. 
from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, Union diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 1b5bdd77..de9d5324 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -1,5 +1,6 @@ import json from functools import partial +from typing import List from transformers import AutoTokenizer @@ -368,3 +369,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None): detokenizer_class, eos_token_ids=eos_token_ids, ) + + +def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List: + removed_bos = sequence if sequence[0] != bos else sequence[1:] + return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 377e7cae..a6f3bd29 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -1,6 +1,8 @@ +import itertools import json +import types from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from transformers import PreTrainedTokenizer @@ -34,14 +36,24 @@ class ChatDataset: https://platform.openai.com/docs/guides/fine-tuning/example-format """ - def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer): - self._data = [ - tokenizer.apply_chat_template( - d["messages"], - tools=d.get("tools", None), - ) - for d in data - ] + def __init__( + self, + data: List[Dict[str, str]], + tokenizer: PreTrainedTokenizer, + chat_key: str = "messages", + mask_prompt: bool = False, + ): + self._data = [] + for d in data: + messages = d[chat_key] + tools = d.get("tools", None) + tokens = tokenizer.apply_chat_template(messages, tools=tools) + if mask_prompt: + messages = messages[:-1] + offset = len(tokenizer.apply_chat_template(messages, tools=tools)) + self._data.append((tokens, offset)) + else: + self._data.append(tokens) def __getitem__(self, idx: int): return self._data[idx] @@ -63,16 +75,36 @@ class CompletionsDataset: tokenizer: PreTrainedTokenizer, prompt_key: str, completion_key: str, + mask_prompt: bool, ): - self._data = [ - tokenizer.apply_chat_template( + self._data = [] + for d in data: + tokens = tokenizer.apply_chat_template( [ {"role": "user", "content": d[prompt_key]}, {"role": "assistant", "content": d[completion_key]}, ], ) - for d in data - ] + if mask_prompt: + offset = len( + tokenizer.apply_chat_template( + [{"role": "user", "content": d[prompt_key]}] + ) + ) + self._data.append((tokens, offset)) + else: + self._data.append(tokens) + + def __getitem__(self, idx: int): + return self._data[idx] + + def __len__(self): + return len(self._data) + + +class ConcatenatedDataset: + def __init__(self, data: List[Any]): + self._data = list(itertools.chain(*data)) def __getitem__(self, idx: int): return self._data[idx] @@ -84,18 +116,26 @@ class CompletionsDataset: def create_dataset( data, tokenizer: PreTrainedTokenizer, - prompt_feature: Optional[str] = None, - completion_feature: Optional[str] = None, + config, ): - prompt_feature = prompt_feature or "prompt" - completion_feature = completion_feature or "completion" + mask_prompt = getattr(config, "mask_prompt", False) + prompt_feature = getattr(config, "prompt_feature", "prompt") + text_feature = getattr(config, "text_feature", "text") + completion_feature = getattr(config, "completion_feature", "completion") + chat_feature = getattr(config, "chat_feature", "messages") sample = data[0] - if "messages" in sample: - return 
ChatDataset(data, tokenizer) - elif prompt_feature in sample and completion_feature in sample: - return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature) - elif "text" in sample: - return Dataset(data, tokenizer) + if prompt_feature in sample and completion_feature in sample: + return CompletionsDataset( + data, tokenizer, prompt_feature, completion_feature, mask_prompt + ) + elif chat_feature in sample: + return ChatDataset( + data, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt + ) + elif text_feature in sample: + if mask_prompt: + raise ValueError("Prompt masking not supported for text dataset.") + return Dataset(data, tokenizer, text_key=text_feature) else: raise ValueError( "Unsupported data format, check the supported formats here:\n" @@ -106,15 +146,14 @@ def create_dataset( def load_local_dataset( data_path: Path, tokenizer: PreTrainedTokenizer, - prompt_feature: Optional[str] = None, - completion_feature: Optional[str] = None, + config, ): def load_subset(path): if not path.exists(): return [] with open(path, "r") as fid: data = [json.loads(l) for l in fid] - return create_dataset(data, tokenizer, prompt_feature, completion_feature) + return create_dataset(data, tokenizer, config) names = ("train", "valid", "test") train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names] @@ -124,8 +163,7 @@ def load_local_dataset( def load_hf_dataset( data_id: str, tokenizer: PreTrainedTokenizer, - prompt_feature: Optional[str] = None, - completion_feature: Optional[str] = None, + config, ): from datasets import exceptions, load_dataset @@ -136,9 +174,7 @@ def load_hf_dataset( train, valid, test = [ ( - create_dataset( - dataset[n], tokenizer, prompt_feature, completion_feature - ) + create_dataset(dataset[n], tokenizer, config) if n in dataset.keys() else [] ) @@ -154,42 +190,61 @@ def load_hf_dataset( def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): import datasets - hf_args = args.hf_dataset - dataset_name = hf_args["name"] - print(f"Loading Hugging Face dataset {dataset_name}.") - text_feature = hf_args.get("text_feature") - prompt_feature = hf_args.get("prompt_feature") - completion_feature = hf_args.get("completion_feature") - - def create_hf_dataset(split: str = None): + def create_hf_dataset(dataset_name, config, split, hf_config): ds = datasets.load_dataset( dataset_name, split=split, - **hf_args.get("config", {}), + **hf_config, ) - if prompt_feature and completion_feature: - return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature) - elif text_feature: - return Dataset(ds, tokenizer, text_key=text_feature) - else: - raise ValueError( - "Specify either a prompt and completion feature or a text " - "feature for the Hugging Face dataset." 
+ return create_dataset(ds, tokenizer, config) + + dataset_collection = args.hf_dataset + if isinstance(dataset_collection, dict): + dataset_collection = [dataset_collection] + + collection = [] + for ds in dataset_collection: + ds_name = ds["name"] + print(f"Loading Hugging Face dataset {ds_name}.") + ds["mask_prompt"] = getattr(args, "mask_prompt", False) + config = types.SimpleNamespace(**ds) + hf_config = ds.get("config", {}) + if args.train: + train_split = ds.get("train_split", "train[:80%]") + valid_split = ds.get("valid_split", "train[-10%:]") + train = create_hf_dataset( + ds_name, + config, + train_split, + hf_config, ) + valid = create_hf_dataset( + ds_name, + config, + valid_split, + hf_config, + ) + else: + train, valid = [], [] - if args.train: - train_split = hf_args.get("train_split", "train[:80%]") - valid_split = hf_args.get("valid_split", "train[-10%:]") - train = create_hf_dataset(split=train_split) - valid = create_hf_dataset(split=valid_split) - else: - train, valid = [], [] - if args.test: - test = create_hf_dataset(split=hf_args.get("test_split")) - else: - test = [] + if args.test: + test_split = ds.get("test_split") + test = create_hf_dataset( + ds_name, + config, + test_split, + hf_config, + ) + else: + test = [] - return train, valid, test + collection.append((train, valid, test)) + + if len(collection) == 1: + return collection[0] + + # Otherwise concatenate them + return tuple(map(ConcatenatedDataset, zip(*collection))) def load_dataset(args, tokenizer: PreTrainedTokenizer): @@ -197,18 +252,11 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer): train, valid, test = load_custom_hf_dataset(args, tokenizer) else: data_path = Path(args.data) - - prompt_feature = getattr(args, "prompt_feature", None) - completion_feature = getattr(args, "completion_feature", None) if data_path.exists(): - train, valid, test = load_local_dataset( - data_path, tokenizer, prompt_feature, completion_feature - ) + train, valid, test = load_local_dataset(data_path, tokenizer, args) else: print(f"Loading Hugging Face dataset {args.data}.") - train, valid, test = load_hf_dataset( - args.data, tokenizer, prompt_feature, completion_feature - ) + train, valid, test = load_hf_dataset(args.data, tokenizer, args) if args.train and len(train) == 0: raise ValueError( diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 63ca58bb..64e26af8 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -5,13 +5,16 @@ import shutil import time from dataclasses import dataclass, field from pathlib import Path -from typing import Union +from typing import List, Optional, Tuple import mlx.core as mx import mlx.nn as nn import numpy as np from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten +from transformers import PreTrainedTokenizer + +from .datasets import CompletionsDataset def grad_checkpoint(layer): @@ -63,20 +66,30 @@ class TrainingArgs: ) -def default_loss(model, inputs, targets, lengths): +def default_loss(model, batch, lengths): + inputs = batch[:, :-1] + targets = batch[:, 1:] + logits = model(inputs) logits = logits.astype(mx.float32) - length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] + steps = mx.arange(1, targets.shape[1] + 1) + mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:]) - ce = nn.losses.cross_entropy(logits, targets) * length_mask - ntoks = length_mask.sum() + ce = nn.losses.cross_entropy(logits, targets) * mask + ntoks = mask.sum() ce = ce.sum() / ntoks return ce, 
ntoks -def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): +def iterate_batches( + dataset, + tokenizer, + batch_size, + max_seq_length, + train=False, +): # Sort by length: idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx])) if len(dataset) < batch_size: @@ -101,6 +114,10 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) indices = np.random.permutation(len(batch_idx)) for i in indices: batch = [dataset[j] for j in batch_idx[i]] + if len(batch[0]) == 2: + batch, offsets = zip(*batch) + else: + offsets = [0] * len(batch) lengths = [len(x) for x in batch] if max(lengths) > max_seq_length: print( @@ -123,8 +140,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) truncated_length # Update lengths to match truncated lengths ) batch = mx.array(batch_arr) - - yield batch[:, :-1], batch[:, 1:], mx.array(lengths) + yield batch, mx.array(list(zip(offsets, lengths))) if not train: break @@ -140,8 +156,8 @@ def evaluate( loss: callable = default_loss, iterate_batches: callable = iterate_batches, ): - all_losses = 0 - ntokens = 0 + all_losses = mx.array(0.0) + ntokens = mx.array(0) index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) @@ -217,8 +233,8 @@ def train( n_tokens = 0 steps = 0 trained_tokens = 0 + train_time = 0 # Main training loop - start = time.perf_counter() for it, batch in zip( range(1, args.iters + 1), iterate_batches( @@ -229,10 +245,11 @@ def train( train=True, ), ): + tic = time.perf_counter() # Report validation loss if needed, the first validation loss # is always measured before any training. if it == 1 or it % args.steps_per_eval == 0 or it == args.iters: - stop = time.perf_counter() + tic = time.perf_counter() val_loss = evaluate( model=model, dataset=val_dataset, @@ -243,7 +260,7 @@ def train( max_seq_length=args.max_seq_length, iterate_batches=iterate_batches, ) - val_time = time.perf_counter() - stop + val_time = time.perf_counter() - tic if rank == 0: print( f"Iter {it}: " @@ -260,24 +277,23 @@ def train( } training_callback.on_val_loss_report(val_info) - start = time.perf_counter() + tic = time.perf_counter() lvalue, toks = step(batch) losses += lvalue n_tokens += toks steps += 1 mx.eval(state, losses, n_tokens) + train_time += time.perf_counter() - tic # Report training loss if needed if it % args.steps_per_report == 0 or it == args.iters: - stop = time.perf_counter() - train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item() train_loss /= steps * mx.distributed.init().size() n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item() learning_rate = optimizer.learning_rate.item() - it_sec = args.steps_per_report / (stop - start) - tokens_sec = float(n_tokens) / (stop - start) + it_sec = args.steps_per_report / train_time + tokens_sec = float(n_tokens) / train_time trained_tokens += n_tokens peak_mem = mx.metal.get_peak_memory() / 1e9 if rank == 0: @@ -306,7 +322,7 @@ def train( losses = 0 n_tokens = 0 steps = 0 - start = time.perf_counter() + train_time = 0 # Save adapter weights if it % args.steps_per_save == 0: diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index c6c82462..7c08b001 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -89,11 +89,13 @@ def linear_to_lora_layers( "mixtral", "nemotron", "stablelm", + "hunyuan", "qwen2", "qwen2_moe", "phimoe", "gemma", "gemma2", + "granite", "helium", "starcoder2", "cohere", diff --git a/llms/mlx_lm/utils.py 
b/llms/mlx_lm/utils.py index 0150f1b7..64813123 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -13,7 +13,18 @@ import time from dataclasses import dataclass from pathlib import Path from textwrap import dedent -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union +from typing import ( + Any, + Callable, + Dict, + Generator, + List, + NamedTuple, + Optional, + Tuple, + Type, + Union, +) import mlx.core as mx import mlx.nn as nn @@ -65,6 +76,7 @@ class GenerationResponse: Args: text (str): The next segment of decoded text. This can be an empty string. token (int): The next token. + from_draft (bool): Whether the token was generated by the draft model. logprobs (mx.array): A vector of log probabilities. prompt_tokens (int): The number of tokens in the prompt. prompt_tps (float): The prompt processing tokens-per-second. @@ -77,6 +89,7 @@ class GenerationResponse: text: str token: int logprobs: mx.array + from_draft: bool prompt_tokens: int prompt_tps: float generation_tokens: int @@ -338,7 +351,7 @@ def speculative_generate_step( kv_bits: Optional[int] = None, kv_group_size: int = 64, quantized_kv_start: int = 0, -) -> Generator[Tuple[mx.array, mx.array], None, None]: +) -> Generator[Tuple[mx.array, mx.array, bool], None, None]: """ A generator producing token ids based on the given prompt from the model. @@ -365,7 +378,8 @@ def speculative_generate_step( when ``kv_bits`` is non-None. Default: ``0``. Yields: - Tuple[mx.array, mx.array]: One token and a vector of log probabilities. + Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities, + and a bool indicating if the token was generated by the draft model """ y = prompt @@ -450,12 +464,12 @@ def speculative_generate_step( break n += 1 ntoks += 1 - yield tn, lpn + yield tn, lpn, True if ntoks == max_tokens: break if ntoks < max_tokens: ntoks += 1 - yield tokens[n], logprobs[n] + yield tokens[n], logprobs[n], False if ntoks == max_tokens: break @@ -463,7 +477,7 @@ def speculative_generate_step( y = mx.array([tokens[n]], mx.uint32) draft_y = y - # If we accpeted all the draft tokens, include the last + # If we accepted all the draft tokens, include the last # draft token in the next draft step since it hasn't been # processed yet by the draft model if n == num_draft: @@ -518,6 +532,10 @@ def stream_generate( if draft_model is None: kwargs.pop("num_draft_tokens", None) token_generator = generate_step(prompt, model, **kwargs) + # from_draft always false for non-speculative generation + token_generator = ( + (token, logprobs, False) for token, logprobs in token_generator + ) else: kwargs.pop("max_kv_size", None) token_generator = speculative_generate_step( @@ -526,7 +544,7 @@ def stream_generate( with wired_limit(model, [generation_stream]): detokenizer.reset() tic = time.perf_counter() - for n, (token, logprobs) in enumerate(token_generator): + for n, (token, logprobs, from_draft) in enumerate(token_generator): if n == 0: prompt_time = time.perf_counter() - tic prompt_tps = prompt.size / prompt_time @@ -540,6 +558,7 @@ def stream_generate( text=detokenizer.last_segment, token=token, logprobs=logprobs, + from_draft=from_draft, prompt_tokens=prompt.size, prompt_tps=prompt_tps, generation_tokens=n + 1, @@ -553,6 +572,7 @@ def stream_generate( text=detokenizer.last_segment, token=token, logprobs=logprobs, + from_draft=from_draft, prompt_tokens=prompt.size, prompt_tps=prompt_tps, generation_tokens=n + 1, @@ -627,6 +647,7 @@ def load_config(model_path: Path) -> dict: def load_model( 
model_path: Path, lazy: bool = False, + strict: bool = True, model_config: dict = {}, get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes, ) -> nn.Module: @@ -638,6 +659,8 @@ def load_model( lazy (bool): If False eval the model parameters to make sure they are loaded in memory before returning, otherwise they will be loaded when needed. Default: ``False`` + strict (bool): Whether or not to raise an exception if weights don't + match. Default: ``True`` model_config (dict, optional): Optional configuration parameters for the model. Defaults to an empty dictionary. get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional): @@ -660,7 +683,7 @@ def load_model( # Try weight for back-compat weight_files = glob.glob(str(model_path / "weight*.safetensors")) - if not weight_files: + if not weight_files and strict: logging.error(f"No safetensors found in {model_path}") raise FileNotFoundError(f"No safetensors found in {model_path}") @@ -694,7 +717,7 @@ def load_model( class_predicate=class_predicate, ) - model.load_weights(list(weights.items())) + model.load_weights(list(weights.items()), strict=strict) if not lazy: mx.eval(model.parameters()) diff --git a/llms/tests/test_datsets.py b/llms/tests/test_datsets.py index dd86d277..5edab8bf 100644 --- a/llms/tests/test_datsets.py +++ b/llms/tests/test_datsets.py @@ -78,14 +78,15 @@ class TestDatasets(unittest.TestCase): self.assertTrue(isinstance(train, datasets.ChatDataset)) def test_hf(self): + hf_args = { + "name": "billsum", + "prompt_feature": "text", + "completion_feature": "summary", + "train_split": "train[:2%]", + "valid_split": "train[-2%:]", + } args = types.SimpleNamespace( - hf_dataset={ - "name": "billsum", - "prompt_feature": "text", - "completion_feature": "summary", - "train_split": "train[:2%]", - "valid_split": "train[-2%:]", - }, + hf_dataset=hf_args, test=False, train=True, ) @@ -97,6 +98,16 @@ class TestDatasets(unittest.TestCase): self.assertTrue(len(valid[0]) > 0) self.assertEqual(len(test), 0) + args = types.SimpleNamespace( + hf_dataset=[hf_args, hf_args], + test=False, + train=True, + ) + train_double, valid_double, test_double = datasets.load_dataset(args, tokenizer) + self.assertEqual(2 * len(train), len(train_double)) + self.assertEqual(2 * len(valid), len(valid_double)) + self.assertEqual(2 * len(test), len(test_double)) + if __name__ == "__main__": unittest.main() diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py index f2345394..7445a9b9 100644 --- a/llms/tests/test_generate.py +++ b/llms/tests/test_generate.py @@ -1,17 +1,24 @@ # Copyright © 2024 Apple Inc. 
import unittest +from typing import List from mlx_lm.sample_utils import make_logits_processors -from mlx_lm.utils import generate, load +from mlx_lm.utils import ( + GenerationResponse, + generate, + load, + make_sampler, + stream_generate, +) class TestGenerate(unittest.TestCase): @classmethod def setUpClass(cls): - HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" - cls.model, cls.tokenizer = load(HF_MODEL_PATH) + cls.HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" + cls.model, cls.tokenizer = load(cls.HF_MODEL_PATH) def test_generate(self): # Simple test that generation runs @@ -51,6 +58,34 @@ class TestGenerate(unittest.TestCase): ) self.assertEqual(len(all_toks), len(init_toks) + 5) + def test_stream_generate_speculative(self): + # Use same model as draft model, this is not a speed test + draft_model, _ = load(self.HF_MODEL_PATH) + + results: List[GenerationResponse] = [] + drafted: List[bool] = [] + + # make a determinate sampler + sampler = make_sampler(temp=0.0) + + for generation_result in stream_generate( + model=self.model, + tokenizer=self.tokenizer, + prompt="hello", + max_tokens=5, + draft_model=draft_model, + num_draft_tokens=2, + sampler=sampler, + ): + drafted.append(generation_result.from_draft) + results.append(generation_result) + + self.assertEqual(len(results), 5) + # since num_draft_tokens is 2 and draft model is the same, the + # first 2 generations should be drafts, the third should come + # from the target model, and last two should be drafts + self.assertEqual(drafted, [True, True, False, True, True]) + if __name__ == "__main__": unittest.main() diff --git a/whisper/mlx_whisper/timing.py b/whisper/mlx_whisper/timing.py index 04915deb..07b81186 100644 --- a/whisper/mlx_whisper/timing.py +++ b/whisper/mlx_whisper/timing.py @@ -134,9 +134,7 @@ def find_alignment( logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :]) # consider only the logits associated with predicting text sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot] - token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype( - sampled_logits.dtype - ) + token_probs = mx.softmax(sampled_logits, precise=True, axis=-1) text_token_probs = mx.take_along_axis( token_probs, mx.array(text_tokens)[:, None], axis=1 ).squeeze(1) @@ -144,10 +142,11 @@ def find_alignment( # heads * tokens * frames weights = mx.stack( - [cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads] + [cross_qk[_l][0, _h] for _l, _h in model.alignment_heads.tolist()] ) weights = weights[:, :, : num_frames // 2] - weights = mx.softmax(weights * qk_scale, axis=-1) + weights = mx.softmax(weights * qk_scale, axis=-1, precise=True) + weights = weights.astype(mx.float32) mean = mx.mean(weights, axis=-2, keepdims=True) std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt() weights = (weights - mean) / std diff --git a/whisper/mlx_whisper/transcribe.py b/whisper/mlx_whisper/transcribe.py index 7057679b..e9c2751f 100644 --- a/whisper/mlx_whisper/transcribe.py +++ b/whisper/mlx_whisper/transcribe.py @@ -195,6 +195,8 @@ def transcribe( seek_points.append(0) if len(seek_points) % 2 == 1: seek_points.append(content_frames) + else: + seek_points[-1] = min(content_frames, seek_points[-1]) seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2])) punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、" diff --git a/whisper/mlx_whisper/whisper.py b/whisper/mlx_whisper/whisper.py index 1c2b390e..5c85195c 100644 --- a/whisper/mlx_whisper/whisper.py 
+++ b/whisper/mlx_whisper/whisper.py @@ -84,7 +84,7 @@ class MultiHeadAttention(nn.Module): w = mx.softmax(qk, axis=-1, precise=True) out = (w @ v).transpose(0, 2, 1, 3) out = out.reshape(n_batch, n_ctx, n_state) - return out, qk.astype(mx.float32) + return out, qk class ResidualAttentionBlock(nn.Module):
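
The new `from_draft` field on `GenerationResponse` makes it possible to see which accepted tokens were proposed by the draft model during speculative decoding. A minimal usage sketch, assuming the same Qwen checkpoint used in the tests (reusing the target model as its own draft is only for illustration, not a speed-up):

```python
# Sketch: count draft-accepted tokens via the new `from_draft` flag.
from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
draft_model, _ = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")  # illustrative draft

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Write a haiku about autumn."}],
    add_generation_prompt=True,
)

drafted = total = 0
for response in stream_generate(
    model,
    tokenizer,
    prompt,
    max_tokens=64,
    draft_model=draft_model,
    num_draft_tokens=2,
):
    drafted += int(response.from_draft)
    total += 1
    print(response.text, end="", flush=True)

print(f"\n{drafted}/{total} tokens were proposed by the draft model")
```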