From 52c41b5b5abfdd4ee1c35bd362162b1dc7a62138 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 6 Feb 2025 11:10:58 -0800 Subject: [PATCH 1/7] Fix prompt cache for models without chat template (#1250) * fix deepseek sharding (#1242) * fix prompt cache with no chat template --- llms/mlx_lm/cache_prompt.py | 2 +- llms/mlx_lm/generate.py | 2 +- llms/mlx_lm/models/deepseek_v2.py | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py index c18f1bae..fff64f78 100644 --- a/llms/mlx_lm/cache_prompt.py +++ b/llms/mlx_lm/cache_prompt.py @@ -152,7 +152,7 @@ def main(): print("Saving...") metadata = {} metadata["model"] = args.model - metadata["chat_template"] = tokenizer.chat_template + metadata["chat_template"] = json.dumps(tokenizer.chat_template) metadata["tokenizer_config"] = json.dumps(tokenizer_config) save_prompt_cache(args.prompt_cache_file, cache, metadata) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 0d286c75..e7994750 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -199,7 +199,7 @@ def main(): if tokenizer.chat_template is None: tokenizer.chat_template = tokenizer.default_chat_template elif using_cache: - tokenizer.chat_template = metadata["chat_template"] + tokenizer.chat_template = json.loads(metadata["chat_template"]) prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t") prompt = sys.stdin.read() if prompt == "-" else prompt diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index 3581fcbe..f22b2e3f 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -282,12 +282,12 @@ class MoEGate(nn.Module): if self.topk_method == "group_limited_greedy": bsz, seq_len = x.shape[:2] scores = scores.reshape(bsz, seq_len, self.n_group, -1) - group_scores = scores.max(axis=-1) + group_scores = scores.max(axis=-1, keepdims=True) k = self.n_group - self.topk_group - group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k] - batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2)) - seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2)) - scores[batch_idx, seq_idx, group_idx] = 0.0 + group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :] + scores = mx.put_along_axis( + scores, group_idx, mx.array(0.0, scores.dtype), axis=-2 + ) scores = scores.reshape(bsz, seq_len, -1) k = self.top_k From 6120a5f3763788f2444e082875b917925b80afa5 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 7 Feb 2025 10:24:57 -0800 Subject: [PATCH 2/7] Faster DSv2/3 expert score computation (#1257) * fix deepseek sharding (#1242) * compile and use put along axis in deep seek routing function --- llms/mlx_lm/models/deepseek_v3.py | 68 +++++++++++++++++++------------ 1 file changed, 42 insertions(+), 26 deletions(-) diff --git a/llms/mlx_lm/models/deepseek_v3.py b/llms/mlx_lm/models/deepseek_v3.py index 69ee1be0..2df93d9f 100644 --- a/llms/mlx_lm/models/deepseek_v3.py +++ b/llms/mlx_lm/models/deepseek_v3.py @@ -271,6 +271,38 @@ class DeepseekV3MLP(nn.Module): return down_proj +@mx.compile +def group_expert_select( + gates, + e_score_correction_bias, + top_k, + n_group, + topk_group, + routed_scaling_factor, + norm_topk_prob, +): + + k = top_k + scores = mx.sigmoid(gates.astype(mx.float32)) + scores = scores + e_score_correction_bias + scores = mx.unflatten(scores, axis=-1, shape=(n_group, -1)) + group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1, keepdims=True) + k = n_group - topk_group + group_idx 
= mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :] + scores = mx.put_along_axis(scores, group_idx, mx.array(0.0), axis=-2) + scores = mx.flatten(scores, -2, -1) + + k = top_k + inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k] + scores = mx.take_along_axis(scores, inds, axis=-1) + if top_k > 1 and norm_topk_prob: + denominator = scores.sum(axis=-1, keepdims=True) + 1e-20 + scores = scores / denominator + scores = scores * routed_scaling_factor + + return inds, scores + + class MoEGate(nn.Module): def __init__(self, config: ModelArgs): super().__init__() @@ -279,38 +311,22 @@ class MoEGate(nn.Module): self.norm_topk_prob = config.norm_topk_prob self.n_routed_experts = config.n_routed_experts self.routed_scaling_factor = config.routed_scaling_factor - self.topk_method = config.topk_method self.n_group = config.n_group self.topk_group = config.topk_group self.weight = mx.zeros((self.n_routed_experts, config.hidden_size)) self.e_score_correction_bias = mx.zeros((self.n_routed_experts,)) + assert config.topk_method == "noaux_tc", "Unsupported topk method." def __call__(self, x): - gates = x @ self.weight.T - - scores = mx.sigmoid(gates.astype(mx.float32)) - - assert self.topk_method == "noaux_tc", "Unsupported topk method." - bsz, seq_len = x.shape[:2] - scores = scores + self.e_score_correction_bias - scores = scores.reshape(bsz, seq_len, self.n_group, -1) - group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1) - k = self.n_group - self.topk_group - group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k] - batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2)) - seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2)) - scores[batch_idx, seq_idx, group_idx] = 0.0 - scores = scores.reshape(bsz, seq_len, -1) - - k = self.top_k - inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k] - scores = mx.take_along_axis(scores, inds, axis=-1) - if self.top_k > 1 and self.norm_topk_prob: - denominator = scores.sum(axis=-1, keepdims=True) + 1e-20 - scores = scores / denominator - scores = scores * self.routed_scaling_factor - - return inds, scores + return group_expert_select( + x @ self.weight.T, + self.e_score_correction_bias, + self.top_k, + self.n_group, + self.topk_group, + self.routed_scaling_factor, + self.norm_topk_prob, + ) class DeepseekV3MoE(nn.Module): From 31611b62d73448cab451f7d5cf72d33a942ae99b Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 8 Feb 2025 15:46:15 -0800 Subject: [PATCH 3/7] Add IBM granite model (#1265) * add granite * add thinking option --- llms/mlx_lm/generate.py | 17 ++- llms/mlx_lm/models/granite.py | 195 ++++++++++++++++++++++++++++++++++ llms/mlx_lm/tuner/utils.py | 1 + 3 files changed, 211 insertions(+), 2 deletions(-) create mode 100644 llms/mlx_lm/models/granite.py diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index e7994750..d8f97e5e 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -93,6 +93,12 @@ def setup_arg_parser(): action="store_true", help="Use the default chat template", ) + parser.add_argument( + "--chat-template-config", + help="Additional config for `apply_chat_template`. 
Should be a dictionary of" + " string keys to values represented as a JSON decodable string.", + default=None, + ) parser.add_argument( "--verbose", type=str2bool, @@ -149,7 +155,6 @@ def setup_arg_parser(): def main(): parser = setup_arg_parser() args = parser.parse_args() - mx.random.seed(args.seed) # Load the prompt cache and metadata if a cache file is provided @@ -195,6 +200,10 @@ def main(): for eos_token in args.extra_eos_token: tokenizer.add_eos_token(eos_token) + template_kwargs = {} + if args.chat_template_config is not None: + template_kwargs = json.loads(args.chat_template_config) + if args.use_default_chat_template: if tokenizer.chat_template is None: tokenizer.chat_template = tokenizer.default_chat_template @@ -209,8 +218,12 @@ def main(): else: messages = [] messages.append({"role": "user", "content": prompt}) + prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + messages, + tokenize=False, + add_generation_prompt=True, + **template_kwargs, ) # Treat the prompt as a suffix assuming that the prefix is in the diff --git a/llms/mlx_lm/models/granite.py b/llms/mlx_lm/models/granite.py new file mode 100644 index 00000000..43597d99 --- /dev/null +++ b/llms/mlx_lm/models/granite.py @@ -0,0 +1,195 @@ +# Copyright © 2023-2024 Apple Inc. + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .rope_utils import initialize_rope + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + vocab_size: int + logits_scaling: float + attention_multiplier: float + embedding_multiplier: float + residual_multiplier: float + max_position_embeddings: int + num_key_value_heads: int + attention_bias: bool + mlp_bias: bool + rope_theta: float + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = True + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + self.head_dim = head_dim = args.hidden_size // n_heads + + self.scale = args.attention_multiplier + attention_bias = args.attention_bias + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) + + self.rope = initialize_rope( + self.head_dim, + args.rope_theta, + False, + args.rope_scaling, + args.max_position_embeddings, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = 
cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + hidden_dim = args.intermediate_size + if hasattr(args, "mlp_bias"): + mlp_bias = args.mlp_bias + else: + mlp_bias = False + + self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) + self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias) + self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args) + self.mlp = MLP(args) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.residual_multiplier = args.residual_multiplier + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r * self.residual_multiplier + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r * self.residual_multiplier + return out + + +class GraniteModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.embedding_multiplier = args.embedding_multiplier + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ): + h = self.embed_tokens(inputs) * self.embedding_multiplier + + if mask is None: + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, cache=c) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = GraniteModel(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + self.logits_scaling = args.logits_scaling + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ): + out = self.model(inputs, mask, cache) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out / self.logits_scaling + + @property + def layers(self): + return self.model.layers diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index c0e52731..d86e01dd 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -94,6 +94,7 @@ def linear_to_lora_layers( "phimoe", "gemma", "gemma2", + "granite", "helium", "starcoder2", "cohere", From 1503bd4f550886092b156ec897e633b448bd78bc Mon Sep 17 00:00:00 2001 From: 
Awni Hannun Date: Sat, 8 Feb 2025 15:46:47 -0800 Subject: [PATCH 4/7] support hunyuan 7b (#1263) --- llms/mlx_lm/models/hunyuan.py | 37 ++++++++++++++++++++++++++++++----- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/llms/mlx_lm/models/hunyuan.py b/llms/mlx_lm/models/hunyuan.py index f9dc5652..122cebda 100644 --- a/llms/mlx_lm/models/hunyuan.py +++ b/llms/mlx_lm/models/hunyuan.py @@ -76,7 +76,6 @@ class Attention(nn.Module): head_dim = args.hidden_size // n_heads self.scale = head_dim**-0.5 - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias) if kv_proj: self.k_proj = nn.Linear( @@ -107,7 +106,6 @@ class Attention(nn.Module): B, L, D = x.shape queries = self.q_proj(x) - if kv_states is None: keys, values = self.k_proj(x), self.v_proj(x) kv_states = keys, values @@ -198,7 +196,10 @@ class DecoderLayer(nn.Module): super().__init__() self.hidden_size = args.hidden_size self.self_attn = Attention(kv_proj, args) - self.mlp = MoeBlock(args) + if args.num_experts == 1: + self.mlp = MLP(args.hidden_size, args.intermediate_size) + else: + self.mlp = MoeBlock(args) self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) self.post_attention_layernorm = nn.RMSNorm( @@ -231,7 +232,10 @@ class HunYuanModel(nn.Module): assert self.vocab_size > 0 self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) self.layers = [ - DecoderLayer(args=args, kv_proj=(i % args.cla_share_factor) == 0) + DecoderLayer( + args=args, + kv_proj=(not args.use_cla) or (i % args.cla_share_factor) == 0, + ) for i in range(args.num_hidden_layers) ] self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) @@ -251,7 +255,7 @@ class HunYuanModel(nn.Module): cache = [None] * len(self.layers) for i, (layer, c) in enumerate(zip(self.layers, cache)): - if i % self.args.cla_share_factor == 0: + if (not self.args.use_cla) or i % self.args.cla_share_factor == 0: shared_kv_states = None h, shared_kv_states = layer(h, mask, c, shared_kv_states) @@ -275,6 +279,29 @@ class Model(nn.Module): return self.model.embed_tokens.as_linear(out) def sanitize(self, weights): + + if "model.layers.0.mlp.gate_and_up_proj.weight" in weights: + new_weights = {} + D = self.args.hidden_size + n_kv_heads = self.args.num_key_value_heads + n_kv_groups = self.args.num_attention_heads // n_kv_heads + head_dim = D // self.args.num_attention_heads + for k, v in weights.items(): + if "qkv_proj" in k: + v = v.reshape(n_kv_heads, n_kv_groups + 2, head_dim, -1) + splits = v.split([n_kv_groups, n_kv_groups + 1], axis=1) + for k_up, v_new in zip(["q_proj", "k_proj", "v_proj"], splits): + k_new = k.replace("qkv_proj", k_up) + new_weights[k_new] = mx.flatten(v_new, 0, 2) + elif "gate_and_up_proj" in k: + splits = v.split(2, axis=0) + for k_up, v_new in zip(["up_proj", "gate_proj"], splits): + k_new = k.replace("gate_and_up_proj", k_up) + new_weights[k_new] = v_new + else: + new_weights[k] = v + weights = new_weights + if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: return weights for l in range(self.args.num_hidden_layers): From f58c7de9017b54b044703f88787e6c679db9ec7e Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 8 Feb 2025 15:47:00 -0800 Subject: [PATCH 5/7] Some improvements to speedup alignment computation in MLX Whisper (#1259) * some improvements to speedup alignment computation in MLX Whisper * fix alignment --- whisper/mlx_whisper/timing.py | 9 ++++----- whisper/mlx_whisper/whisper.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git 
a/whisper/mlx_whisper/timing.py b/whisper/mlx_whisper/timing.py index 04915deb..07b81186 100644 --- a/whisper/mlx_whisper/timing.py +++ b/whisper/mlx_whisper/timing.py @@ -134,9 +134,7 @@ def find_alignment( logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :]) # consider only the logits associated with predicting text sampled_logits = logits[0][len(tokenizer.sot_sequence) : -2, : tokenizer.eot] - token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype( - sampled_logits.dtype - ) + token_probs = mx.softmax(sampled_logits, precise=True, axis=-1) text_token_probs = mx.take_along_axis( token_probs, mx.array(text_tokens)[:, None], axis=1 ).squeeze(1) @@ -144,10 +142,11 @@ def find_alignment( # heads * tokens * frames weights = mx.stack( - [cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads] + [cross_qk[_l][0, _h] for _l, _h in model.alignment_heads.tolist()] ) weights = weights[:, :, : num_frames // 2] - weights = mx.softmax(weights * qk_scale, axis=-1) + weights = mx.softmax(weights * qk_scale, axis=-1, precise=True) + weights = weights.astype(mx.float32) mean = mx.mean(weights, axis=-2, keepdims=True) std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt() weights = (weights - mean) / std diff --git a/whisper/mlx_whisper/whisper.py b/whisper/mlx_whisper/whisper.py index 1c2b390e..5c85195c 100644 --- a/whisper/mlx_whisper/whisper.py +++ b/whisper/mlx_whisper/whisper.py @@ -84,7 +84,7 @@ class MultiHeadAttention(nn.Module): w = mx.softmax(qk, axis=-1, precise=True) out = (w @ v).transpose(0, 2, 1, 3) out = out.reshape(n_batch, n_ctx, n_state) - return out, qk.astype(mx.float32) + return out, qk class ResidualAttentionBlock(nn.Module): From 1ced1b00ca9c2457fcbf0e54ffcffe58f53fb4fd Mon Sep 17 00:00:00 2001 From: Sri Harsha Pamu Date: Sun, 9 Feb 2025 11:39:11 -0800 Subject: [PATCH 6/7] rm temp argument (#1267) --- llms/mlx_lm/examples/chat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llms/mlx_lm/examples/chat.py b/llms/mlx_lm/examples/chat.py index 4a7020f1..dcd90b67 100644 --- a/llms/mlx_lm/examples/chat.py +++ b/llms/mlx_lm/examples/chat.py @@ -23,7 +23,6 @@ response = generate( tokenizer, prompt=prompt, verbose=True, - temp=0.0, prompt_cache=prompt_cache, ) From 5865899c81d35ea48c6b69071d7fe61a46880d30 Mon Sep 17 00:00:00 2001 From: Chime Ogbuji Date: Sun, 9 Feb 2025 23:12:34 -0500 Subject: [PATCH 7/7] Completion only fine-tuning of instruction models with collections of HF datasets (#1103) - Optional completion only fine-tuning with `--mask-prompt` - Collections of Hugging Face datasets --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/LORA.md | 26 ++++- llms/mlx_lm/lora.py | 9 ++ llms/mlx_lm/tokenizer_utils.py | 6 ++ llms/mlx_lm/tuner/datasets.py | 186 +++++++++++++++++++++------------ llms/mlx_lm/tuner/trainer.py | 32 ++++-- llms/tests/test_datsets.py | 25 +++-- 6 files changed, 199 insertions(+), 85 deletions(-) diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index 9eac9d7f..e863abc4 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -76,6 +76,14 @@ You can specify the output location with `--adapter-path`. You can resume fine-tuning with an existing adapter with `--resume-adapter-file `. +#### Prompt Masking + +The default training computes a loss for every token in the sample. You can +ignore the prompt and compute loss for just the completion by passing +`--mask-prompt`. Note this is only supported for `chat` and `completion` +datasets. 
For `chat` datasets the final message in the message list is +considered the completion. See the [dataset section](#Data) for more details. + ### Evaluate To compute test set perplexity use: @@ -290,11 +298,27 @@ hf_dataset: - Use `prompt_feature` and `completion_feature` to specify keys for a `completions` dataset. Use `text_feature` to specify the key for a `text` - dataset. + dataset. Use `chat_feature` to specify the key for a chat dataset. - To specify the train, valid, or test splits, set the corresponding `{train,valid,test}_split` argument. +You can specify a list of Hugging Face datasets with a list of records each +with the same structure as above. For example: + +```yaml +hf_dataset: + - name: "Open-Orca/OpenOrca" + train_split: "train[:90%]" + valid_split: "train[-10%:]" + prompt_feature: "question" + completion_feature: "response" + - name: "trl-lib/ultrafeedback_binarized" + train_split: "train[:90%]" + valid_split: "train[-10%:]" + chat_feature: "chosen" +``` + - Arguments specified in `config` will be passed as keyword arguments to [`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset). diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 43f508c3..abc5dfa9 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -94,6 +94,14 @@ def build_parser(): choices=["lora", "dora", "full"], help="Type of fine-tuning to perform: lora, dora, or full.", ) + + parser.add_argument( + "--mask-prompt", + action="store_true", + help="Mask the prompt in the loss when training", + default=False, + ) + parser.add_argument( "--num-layers", type=int, @@ -219,6 +227,7 @@ def train_model( build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate ) ) + # Train model train( model=model, diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 1b5bdd77..de9d5324 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -1,5 +1,6 @@ import json from functools import partial +from typing import List from transformers import AutoTokenizer @@ -368,3 +369,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None): detokenizer_class, eos_token_ids=eos_token_ids, ) + + +def no_bos_or_eos(sequence: List, bos: int, eos: int) -> List: + removed_bos = sequence if sequence[0] != bos else sequence[1:] + return removed_bos[:-1] if removed_bos[-1] == eos else removed_bos diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 377e7cae..a6f3bd29 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -1,6 +1,8 @@ +import itertools import json +import types from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from transformers import PreTrainedTokenizer @@ -34,14 +36,24 @@ class ChatDataset: https://platform.openai.com/docs/guides/fine-tuning/example-format """ - def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer): - self._data = [ - tokenizer.apply_chat_template( - d["messages"], - tools=d.get("tools", None), - ) - for d in data - ] + def __init__( + self, + data: List[Dict[str, str]], + tokenizer: PreTrainedTokenizer, + chat_key: str = "messages", + mask_prompt: bool = False, + ): + self._data = [] + for d in data: + messages = d[chat_key] + tools = d.get("tools", None) + tokens = tokenizer.apply_chat_template(messages, tools=tools) + if mask_prompt: + messages = messages[:-1] + offset = 
len(tokenizer.apply_chat_template(messages, tools=tools)) + self._data.append((tokens, offset)) + else: + self._data.append(tokens) def __getitem__(self, idx: int): return self._data[idx] @@ -63,16 +75,36 @@ class CompletionsDataset: tokenizer: PreTrainedTokenizer, prompt_key: str, completion_key: str, + mask_prompt: bool, ): - self._data = [ - tokenizer.apply_chat_template( + self._data = [] + for d in data: + tokens = tokenizer.apply_chat_template( [ {"role": "user", "content": d[prompt_key]}, {"role": "assistant", "content": d[completion_key]}, ], ) - for d in data - ] + if mask_prompt: + offset = len( + tokenizer.apply_chat_template( + [{"role": "user", "content": d[prompt_key]}] + ) + ) + self._data.append((tokens, offset)) + else: + self._data.append(tokens) + + def __getitem__(self, idx: int): + return self._data[idx] + + def __len__(self): + return len(self._data) + + +class ConcatenatedDataset: + def __init__(self, data: List[Any]): + self._data = list(itertools.chain(*data)) def __getitem__(self, idx: int): return self._data[idx] @@ -84,18 +116,26 @@ class CompletionsDataset: def create_dataset( data, tokenizer: PreTrainedTokenizer, - prompt_feature: Optional[str] = None, - completion_feature: Optional[str] = None, + config, ): - prompt_feature = prompt_feature or "prompt" - completion_feature = completion_feature or "completion" + mask_prompt = getattr(config, "mask_prompt", False) + prompt_feature = getattr(config, "prompt_feature", "prompt") + text_feature = getattr(config, "text_feature", "text") + completion_feature = getattr(config, "completion_feature", "completion") + chat_feature = getattr(config, "chat_feature", "messages") sample = data[0] - if "messages" in sample: - return ChatDataset(data, tokenizer) - elif prompt_feature in sample and completion_feature in sample: - return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature) - elif "text" in sample: - return Dataset(data, tokenizer) + if prompt_feature in sample and completion_feature in sample: + return CompletionsDataset( + data, tokenizer, prompt_feature, completion_feature, mask_prompt + ) + elif chat_feature in sample: + return ChatDataset( + data, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt + ) + elif text_feature in sample: + if mask_prompt: + raise ValueError("Prompt masking not supported for text dataset.") + return Dataset(data, tokenizer, text_key=text_feature) else: raise ValueError( "Unsupported data format, check the supported formats here:\n" @@ -106,15 +146,14 @@ def create_dataset( def load_local_dataset( data_path: Path, tokenizer: PreTrainedTokenizer, - prompt_feature: Optional[str] = None, - completion_feature: Optional[str] = None, + config, ): def load_subset(path): if not path.exists(): return [] with open(path, "r") as fid: data = [json.loads(l) for l in fid] - return create_dataset(data, tokenizer, prompt_feature, completion_feature) + return create_dataset(data, tokenizer, config) names = ("train", "valid", "test") train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names] @@ -124,8 +163,7 @@ def load_local_dataset( def load_hf_dataset( data_id: str, tokenizer: PreTrainedTokenizer, - prompt_feature: Optional[str] = None, - completion_feature: Optional[str] = None, + config, ): from datasets import exceptions, load_dataset @@ -136,9 +174,7 @@ def load_hf_dataset( train, valid, test = [ ( - create_dataset( - dataset[n], tokenizer, prompt_feature, completion_feature - ) + create_dataset(dataset[n], tokenizer, config) if n in dataset.keys() 
else [] ) @@ -154,42 +190,61 @@ def load_hf_dataset( def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): import datasets - hf_args = args.hf_dataset - dataset_name = hf_args["name"] - print(f"Loading Hugging Face dataset {dataset_name}.") - text_feature = hf_args.get("text_feature") - prompt_feature = hf_args.get("prompt_feature") - completion_feature = hf_args.get("completion_feature") - - def create_hf_dataset(split: str = None): + def create_hf_dataset(dataset_name, config, split, hf_config): ds = datasets.load_dataset( dataset_name, split=split, - **hf_args.get("config", {}), + **hf_config, ) - if prompt_feature and completion_feature: - return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature) - elif text_feature: - return Dataset(ds, tokenizer, text_key=text_feature) - else: - raise ValueError( - "Specify either a prompt and completion feature or a text " - "feature for the Hugging Face dataset." + return create_dataset(ds, tokenizer, config) + + dataset_collection = args.hf_dataset + if isinstance(dataset_collection, dict): + dataset_collection = [dataset_collection] + + collection = [] + for ds in dataset_collection: + ds_name = ds["name"] + print(f"Loading Hugging Face dataset {ds_name}.") + ds["mask_prompt"] = getattr(args, "mask_prompt", False) + config = types.SimpleNamespace(**ds) + hf_config = ds.get("config", {}) + if args.train: + train_split = ds.get("train_split", "train[:80%]") + valid_split = ds.get("valid_split", "train[-10%:]") + train = create_hf_dataset( + ds_name, + config, + train_split, + hf_config, ) + valid = create_hf_dataset( + ds_name, + config, + valid_split, + hf_config, + ) + else: + train, valid = [], [] - if args.train: - train_split = hf_args.get("train_split", "train[:80%]") - valid_split = hf_args.get("valid_split", "train[-10%:]") - train = create_hf_dataset(split=train_split) - valid = create_hf_dataset(split=valid_split) - else: - train, valid = [], [] - if args.test: - test = create_hf_dataset(split=hf_args.get("test_split")) - else: - test = [] + if args.test: + test_split = ds.get("test_split") + test = create_hf_dataset( + ds_name, + config, + test_split, + hf_config, + ) + else: + test = [] - return train, valid, test + collection.append((train, valid, test)) + + if len(collection) == 1: + return collection[0] + + # Otherwise concatenate them + return tuple(map(ConcatenatedDataset, zip(*collection))) def load_dataset(args, tokenizer: PreTrainedTokenizer): @@ -197,18 +252,11 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer): train, valid, test = load_custom_hf_dataset(args, tokenizer) else: data_path = Path(args.data) - - prompt_feature = getattr(args, "prompt_feature", None) - completion_feature = getattr(args, "completion_feature", None) if data_path.exists(): - train, valid, test = load_local_dataset( - data_path, tokenizer, prompt_feature, completion_feature - ) + train, valid, test = load_local_dataset(data_path, tokenizer, args) else: print(f"Loading Hugging Face dataset {args.data}.") - train, valid, test = load_hf_dataset( - args.data, tokenizer, prompt_feature, completion_feature - ) + train, valid, test = load_hf_dataset(args.data, tokenizer, args) if args.train and len(train) == 0: raise ValueError( diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index bf84d066..d675f9b6 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -5,13 +5,16 @@ import shutil import time from dataclasses import dataclass, field from pathlib import Path -from typing 
import Union +from typing import List, Optional, Tuple import mlx.core as mx import mlx.nn as nn import numpy as np from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten +from transformers import PreTrainedTokenizer + +from .datasets import CompletionsDataset def grad_checkpoint(layer): @@ -63,20 +66,30 @@ class TrainingArgs: ) -def default_loss(model, inputs, targets, lengths): +def default_loss(model, batch, lengths): + inputs = batch[:, :-1] + targets = batch[:, 1:] + logits = model(inputs) logits = logits.astype(mx.float32) - length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] + steps = mx.arange(1, targets.shape[1] + 1) + mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:]) - ce = nn.losses.cross_entropy(logits, targets) * length_mask - ntoks = length_mask.sum() + ce = nn.losses.cross_entropy(logits, targets) * mask + ntoks = mask.sum() ce = ce.sum() / ntoks return ce, ntoks -def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): +def iterate_batches( + dataset, + tokenizer, + batch_size, + max_seq_length, + train=False, +): # Sort by length: idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx])) if len(dataset) < batch_size: @@ -101,6 +114,10 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) indices = np.random.permutation(len(batch_idx)) for i in indices: batch = [dataset[j] for j in batch_idx[i]] + if len(batch[0]) == 2: + batch, offsets = zip(*batch) + else: + offsets = [0] * len(batch) lengths = [len(x) for x in batch] if max(lengths) > max_seq_length: print( @@ -123,8 +140,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) truncated_length # Update lengths to match truncated lengths ) batch = mx.array(batch_arr) - - yield batch[:, :-1], batch[:, 1:], mx.array(lengths) + yield batch, mx.array(list(zip(offsets, lengths))) if not train: break diff --git a/llms/tests/test_datsets.py b/llms/tests/test_datsets.py index dd86d277..5edab8bf 100644 --- a/llms/tests/test_datsets.py +++ b/llms/tests/test_datsets.py @@ -78,14 +78,15 @@ class TestDatasets(unittest.TestCase): self.assertTrue(isinstance(train, datasets.ChatDataset)) def test_hf(self): + hf_args = { + "name": "billsum", + "prompt_feature": "text", + "completion_feature": "summary", + "train_split": "train[:2%]", + "valid_split": "train[-2%:]", + } args = types.SimpleNamespace( - hf_dataset={ - "name": "billsum", - "prompt_feature": "text", - "completion_feature": "summary", - "train_split": "train[:2%]", - "valid_split": "train[-2%:]", - }, + hf_dataset=hf_args, test=False, train=True, ) @@ -97,6 +98,16 @@ class TestDatasets(unittest.TestCase): self.assertTrue(len(valid[0]) > 0) self.assertEqual(len(test), 0) + args = types.SimpleNamespace( + hf_dataset=[hf_args, hf_args], + test=False, + train=True, + ) + train_double, valid_double, test_double = datasets.load_dataset(args, tokenizer) + self.assertEqual(2 * len(train), len(train_double)) + self.assertEqual(2 * len(valid), len(valid_double)) + self.assertEqual(2 * len(test), len(test_double)) + if __name__ == "__main__": unittest.main()
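The trainer change in the last patch replaces the old `(inputs, targets, lengths)` batches with the full token batch plus `(offset, length)` pairs, and `default_loss` turns those pairs into a mask so that, with `--mask-prompt`, only completion tokens contribute to the loss. A minimal sketch of that masking with made-up token ids (not part of the patch):

```python
# Sketch of the completion-only loss mask from the trainer change, using
# made-up token ids: a 5-token prompt followed by a 3-token completion.
import mlx.core as mx

batch = mx.array([[11, 12, 13, 14, 15, 21, 22, 23]])  # prompt tokens, then completion tokens
lengths = mx.array([[5, 8]])                           # (offset, length) as yielded by iterate_batches

inputs = batch[:, :-1]    # what the model is fed
targets = batch[:, 1:]    # next-token targets

# Same masking as the new default_loss: step t predicts token t of the
# sequence, and only steps in [offset, length] are counted.
steps = mx.arange(1, targets.shape[1] + 1)
mask = mx.logical_and(steps >= lengths[:, 0:1], steps <= lengths[:, 1:])
print(targets * mask)     # only the completion tokens 21, 22, 23 survive
```

The offsets themselves come from the dataset classes: `ChatDataset` re-applies the chat template without the final (assistant) message, `CompletionsDataset` applies it to the user turn alone, and the token length of that prefix is stored as the offset.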
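Further back in the series, the DeepSeek v2/v3 routing rewrite (patches 1 and 2) drops the scatter-style fancy indexing in favor of `mx.argpartition` over a group axis plus `mx.put_along_axis`, which is what allows the whole selection to be wrapped in `mx.compile`. The sketch below mirrors the compiled `group_expert_select` from patch 2 with made-up sizes, and omits the correction bias and top-k renormalization:

```python
# Sketch of group-limited expert routing with made-up sizes: 8 experts in
# 4 groups, keep the top 2 groups, then route to top_k = 2 experts.
import mlx.core as mx

n_group, topk_group, top_k = 4, 2, 2
gates = mx.random.normal((1, 1, 8))                    # (batch, seq, experts)

scores = mx.sigmoid(gates.astype(mx.float32))
scores = mx.unflatten(scores, axis=-1, shape=(n_group, -1))  # (..., n_group, experts_per_group)

# Score each group by the sum of its two best experts, then zero out the
# n_group - topk_group weakest groups instead of fancy-index assignment.
group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1, keepdims=True)
k = n_group - topk_group
drop_idx = mx.argpartition(group_scores, kth=k - 1, axis=-2)[..., :k, :]
scores = mx.put_along_axis(scores, drop_idx, mx.array(0.0), axis=-2)
scores = mx.flatten(scores, -2, -1)

# Pick the top_k experts from the surviving groups.
inds = mx.argpartition(-scores, kth=top_k - 1, axis=-1)[..., :top_k]
weights = mx.take_along_axis(scores, inds, axis=-1)
print(inds, weights)
```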
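Finally, the first patch stores the chat template in the prompt-cache metadata via a JSON round-trip, presumably because the cache metadata holds plain strings while `tokenizer.chat_template` is `None` for models without a chat template. A small standalone illustration of why the round-trip is safe:

```python
# Standalone illustration (not from the patch): JSON round-tripping the chat
# template keeps None as None and any template string intact.
import json

for chat_template in (None, "{% for m in messages %}{{ m['content'] }}{% endfor %}"):
    stored = json.dumps(chat_template)   # always a str, e.g. "null"
    restored = json.loads(stored)
    assert restored == chat_template
    print(repr(stored), "->", repr(restored))
```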