From db109184b7f23ce3166c6cfd4682b092b4bdfbb6 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 18 Dec 2024 18:46:50 -0800 Subject: [PATCH 01/15] Fix no template prompt + top_k sampling (#1166) * fix no template prompt * add top_k sampling * fix chinese --- llms/mlx_lm/generate.py | 12 +++--------- llms/mlx_lm/sample_utils.py | 34 ++++++++++++++++++++++++++++++++- llms/tests/test_sample_utils.py | 23 +++++++++++++++++++++- 3 files changed, 58 insertions(+), 11 deletions(-) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 84dc63ca..afb1394e 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -1,7 +1,6 @@ # Copyright © 2023-2024 Apple Inc. import argparse -import codecs import json import sys @@ -189,8 +188,8 @@ def main(): elif using_cache: tokenizer.chat_template = metadata["chat_template"] - prompt = codecs.decode(args.prompt, "unicode_escape") - + prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t") + prompt = sys.stdin.read() if prompt == "-" else prompt if not args.ignore_chat_template and ( hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None @@ -199,12 +198,7 @@ def main(): messages = [{"role": "system", "content": args.system_prompt}] else: messages = [] - messages.append( - { - "role": "user", - "content": sys.stdin.read() if prompt == "-" else prompt, - } - ) + messages.append({"role": "user", "content": prompt}) prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py index c77f056a..c48a32cf 100644 --- a/llms/mlx_lm/sample_utils.py +++ b/llms/mlx_lm/sample_utils.py @@ -12,6 +12,7 @@ def make_sampler( top_p: float = 0.0, min_p: float = 0.0, min_tokens_to_keep: int = 1, + top_k: int = -1, ) -> Callable[mx.array, mx.array]: """ Make a sampler function for use with ``generate_step``. @@ -25,6 +26,8 @@ def make_sampler( probability) that a token probability must have to be considered. min_tokens_to_keep (int, optional): Minimum number of tokens that cannot be filtered by min_p sampling. + top_k (int, optional): The top k tokens ranked by probability to constrain + the sampling to. Returns: Callable[mx.array, mx.array]: @@ -36,6 +39,8 @@ def make_sampler( return lambda x: top_p_sampling(x, top_p, temp) elif min_p != 0.0: return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp) + elif top_k > 0: + return lambda x: top_k_sampling(x, top_k, temp) else: return lambda x: categorical_sampling(x, temp) @@ -79,6 +84,33 @@ def make_logits_processors( return logits_processors +@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) +def top_k_sampling( + logprobs: mx.array, + top_k: int, + temperature=1.0, +) -> mx.array: + """ + Sample from only the top K tokens ranked by probability. + + Args: + logprobs: A vector of log probabilities. + top_k (int): Top k tokens to sample from. + """ + vocab_size = logprobs.shape[-1] + if not isinstance(top_k, int) or not (0 < top_k < vocab_size): + raise ValueError( + f"`top_k` has to be an integer in the (0, {vocab_size}] interval," + f" but is {top_k}." 
+ ) + logprobs = logprobs * (1 / temperature) + mask_idx = mx.argpartition(-logprobs, kth=top_k - 1, axis=-1)[..., top_k:] + masked_logprobs = mx.put_along_axis( + logprobs, mask_idx, mx.array(-float("inf"), logprobs.dtype), axis=-1 + ) + return mx.random.categorical(masked_logprobs, axis=-1) + + @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def min_p_sampling( logprobs: mx.array, @@ -87,7 +119,7 @@ def min_p_sampling( temperature=1.0, ) -> mx.array: """ - Apply min-p sampling to the logits. + Apply min-p sampling to the logprobs. Min-p keeps all tokens that are above a minimum probability, scaled by the probability of the most likely token. As a result, the filter is more diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py index ebc90ce8..c45fa443 100644 --- a/llms/tests/test_sample_utils.py +++ b/llms/tests/test_sample_utils.py @@ -1,7 +1,7 @@ import unittest import mlx.core as mx -from mlx_lm.sample_utils import min_p_sampling, top_p_sampling +from mlx_lm.sample_utils import min_p_sampling, top_k_sampling, top_p_sampling class TestSampleUtils(unittest.TestCase): @@ -42,6 +42,27 @@ class TestSampleUtils(unittest.TestCase): token = min_p_sampling(logits, 0.05) self.assertTrue(token in (0, 3)) + def test_top_k_sampling(self): + probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] + logits = mx.log(probs) + + token = top_k_sampling(logits, 1).item() + self.assertEqual(token, 0) + + probs = mx.array([0.5, 0.0, 0.0, 0.5])[None] + tokens = set() + for _ in range(100): + token = top_k_sampling(logits, 2) + tokens.add(token.item()) + self.assertEqual(tokens, {0, 3}) + + # Batch mode works + probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]]) + logits = mx.log(probs) + + tokens = top_k_sampling(logits, 1) + self.assertEqual(tokens.tolist(), [0, 1]) + if __name__ == "__main__": unittest.main() From d4ef909d4ab44d9f8cf89f5baa8a433d76d7d6b1 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Wed, 18 Dec 2024 19:43:52 -0800 Subject: [PATCH 02/15] Length masking for batch inputs (#1173) * length masking * add mask to mlx_lm model interface * remove lengths * fix test: * comment + fix --- llms/mlx_lm/models/base.py | 10 +++++++++- llms/mlx_lm/models/cohere.py | 7 +++++-- llms/mlx_lm/models/cohere2.py | 14 ++++++-------- llms/mlx_lm/models/dbrx.py | 7 +++++-- llms/mlx_lm/models/deepseek.py | 7 +++++-- llms/mlx_lm/models/deepseek_v2.py | 8 ++++++-- llms/mlx_lm/models/exaone.py | 7 +++++-- llms/mlx_lm/models/gemma.py | 7 +++++-- llms/mlx_lm/models/gemma2.py | 7 +++++-- llms/mlx_lm/models/gpt2.py | 7 +++++-- llms/mlx_lm/models/gpt_bigcode.py | 7 +++++-- llms/mlx_lm/models/gpt_neox.py | 7 +++++-- llms/mlx_lm/models/hunyuan.py | 7 +++++-- llms/mlx_lm/models/internlm2.py | 7 +++++-- llms/mlx_lm/models/llama.py | 7 +++++-- llms/mlx_lm/models/minicpm.py | 7 +++++-- llms/mlx_lm/models/mixtral.py | 7 +++++-- llms/mlx_lm/models/nemotron.py | 7 +++++-- llms/mlx_lm/models/olmo.py | 10 +++++++--- llms/mlx_lm/models/olmo2.py | 7 +++++-- llms/mlx_lm/models/openelm.py | 7 +++++-- llms/mlx_lm/models/phi.py | 8 +++++--- llms/mlx_lm/models/phi3.py | 7 +++++-- llms/mlx_lm/models/phi3small.py | 7 +++++-- llms/mlx_lm/models/phimoe.py | 7 +++++-- llms/mlx_lm/models/phixtral.py | 4 +++- llms/mlx_lm/models/plamo.py | 7 +++++-- llms/mlx_lm/models/qwen.py | 3 ++- llms/mlx_lm/models/qwen2.py | 7 +++++-- llms/mlx_lm/models/qwen2_moe.py | 7 +++++-- llms/mlx_lm/models/recurrent_gemma.py | 8 +++++--- llms/mlx_lm/models/stablelm.py | 5 ++++- llms/mlx_lm/models/starcoder2.py | 7 +++++-- 
llms/tests/test_models.py | 25 ++++++++++++++++++++++++- 34 files changed, 191 insertions(+), 72 deletions(-) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index f02f49b1..ad7a4a65 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -23,7 +23,12 @@ class BaseModelArgs: ) -def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None): +def create_causal_mask( + N: int, + offset: int = 0, + window_size: Optional[int] = None, + lengths: Optional[mx.array] = None, +): rinds = mx.arange(offset + N) linds = mx.arange(offset, offset + N) if offset else rinds linds = linds[:, None] @@ -31,6 +36,9 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non mask = linds < rinds if window_size is not None: mask = mask | (linds > rinds + window_size) + if lengths is not None: + lengths = lengths[:, None, None, None] + mask = mask | (rinds >= lengths) return mask * -1e9 diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py index 7e002b0c..b2d16dd7 100644 --- a/llms/mlx_lm/models/cohere.py +++ b/llms/mlx_lm/models/cohere.py @@ -155,11 +155,13 @@ class CohereModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -180,9 +182,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.embed_tokens.as_linear(out) out = out * self.model.args.logit_scale return out diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py index fcb4061b..ec0e9276 100644 --- a/llms/mlx_lm/models/cohere2.py +++ b/llms/mlx_lm/models/cohere2.py @@ -6,7 +6,7 @@ from typing import Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .cache import KVCache, RotatingKVCache @@ -151,16 +151,13 @@ class CohereModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - T = h.shape[1] - if T > 1: - offset = cache[0].offset if cache else 0 - mask = create_causal_mask(T, offset).astype(h.dtype) - else: - mask = None + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -181,9 +178,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.embed_tokens.as_linear(out) out = out * self.model.args.logit_scale return out diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py index 7be274cc..886b5630 100644 --- a/llms/mlx_lm/models/dbrx.py +++ b/llms/mlx_lm/models/dbrx.py @@ -197,11 +197,13 @@ class DBRX(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.wte(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.blocks) @@ -223,9 +225,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = 
self.transformer(inputs, cache) + out = self.transformer(inputs, mask, cache) return self.lm_head(out) @property diff --git a/llms/mlx_lm/models/deepseek.py b/llms/mlx_lm/models/deepseek.py index b7b24dba..ffc30c36 100644 --- a/llms/mlx_lm/models/deepseek.py +++ b/llms/mlx_lm/models/deepseek.py @@ -211,9 +211,11 @@ class DeepseekModel(nn.Module): self, x: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ) -> mx.array: h = self.embed_tokens(x) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -236,8 +238,9 @@ class Model(nn.Module): self, inputs: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ): - out = self.model(inputs, cache) + out = self.model(inputs, cache, mask) return self.lm_head(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index 444813b9..9027da7e 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -370,9 +370,12 @@ class DeepseekV2Model(nn.Module): self, x: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ) -> mx.array: h = self.embed_tokens(x) - mask = create_attention_mask(h, cache) + + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -395,8 +398,9 @@ class Model(nn.Module): self, inputs: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ): - out = self.model(inputs, cache) + out = self.model(inputs, cache, mask) return self.lm_head(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/exaone.py b/llms/mlx_lm/models/exaone.py index eaed5dd8..ee3ed1e8 100644 --- a/llms/mlx_lm/models/exaone.py +++ b/llms/mlx_lm/models/exaone.py @@ -123,10 +123,12 @@ class ExaoneModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.wte(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.h) @@ -149,9 +151,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.transformer(inputs, cache) + out = self.transformer(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.transformer.wte.as_linear(out) else: diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py index 3f384c3f..0860ddeb 100644 --- a/llms/mlx_lm/models/gemma.py +++ b/llms/mlx_lm/models/gemma.py @@ -138,12 +138,14 @@ class GemmaModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) h = h * (self.args.hidden_size**0.5) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -164,9 +166,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.embed_tokens.as_linear(out) return out diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py index 64951ae4..321a58ff 100644 --- a/llms/mlx_lm/models/gemma2.py +++ b/llms/mlx_lm/models/gemma2.py @@ -160,12 +160,14 @@ class GemmaModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) h = 
h * (self.args.hidden_size**0.5) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -187,9 +189,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.embed_tokens.as_linear(out) out = mx.tanh(out / self.final_logit_softcapping) out = out * self.final_logit_softcapping diff --git a/llms/mlx_lm/models/gpt2.py b/llms/mlx_lm/models/gpt2.py index 52076a34..5b277734 100644 --- a/llms/mlx_lm/models/gpt2.py +++ b/llms/mlx_lm/models/gpt2.py @@ -126,6 +126,7 @@ class GPT2Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): _, L = inputs.shape @@ -138,7 +139,8 @@ class GPT2Model(nn.Module): position_ids = mx.array(np.arange(L)) hidden_states += self.wpe(position_ids) - mask = create_attention_mask(hidden_states, cache) + if mask is None: + mask = create_attention_mask(hidden_states, cache) if cache is None: cache = [None] * len(self.h) @@ -159,9 +161,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.wte.as_linear(out) return out diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py index 23e86e20..8415c59e 100644 --- a/llms/mlx_lm/models/gpt_bigcode.py +++ b/llms/mlx_lm/models/gpt_bigcode.py @@ -137,6 +137,7 @@ class GPTBigCodeModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): B, L = inputs.shape @@ -149,7 +150,8 @@ class GPTBigCodeModel(nn.Module): position_ids = mx.array(np.arange(L)) hidden_states += self.wpe(position_ids) - mask = create_attention_mask(hidden_states, cache) + if mask is None: + mask = create_attention_mask(hidden_states, cache) if cache is None: cache = [None] * len(self.h) @@ -172,9 +174,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.transformer(inputs, cache) + out = self.transformer(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.transformer.wte.as_linear(out) else: diff --git a/llms/mlx_lm/models/gpt_neox.py b/llms/mlx_lm/models/gpt_neox.py index ccb0b28b..5e124a67 100644 --- a/llms/mlx_lm/models/gpt_neox.py +++ b/llms/mlx_lm/models/gpt_neox.py @@ -146,13 +146,15 @@ class GPTNeoXModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): _, L = inputs.shape hidden_states = self.embed_in(inputs) - mask = create_attention_mask(hidden_states, cache) + if mask is None: + mask = create_attention_mask(hidden_states, cache) if cache is None: cache = [None] * len(self.h) @@ -176,9 +178,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return out def sanitize(self, weights): diff --git a/llms/mlx_lm/models/hunyuan.py b/llms/mlx_lm/models/hunyuan.py index b098c20d..f9dc5652 100644 --- a/llms/mlx_lm/models/hunyuan.py +++ b/llms/mlx_lm/models/hunyuan.py @@ -239,11 +239,13 @@ class HunYuanModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is 
None: cache = [None] * len(self.layers) @@ -266,9 +268,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return self.model.embed_tokens.as_linear(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py index f5ce057e..28a095e1 100644 --- a/llms/mlx_lm/models/internlm2.py +++ b/llms/mlx_lm/models/internlm2.py @@ -193,11 +193,13 @@ class InternLM2Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.tok_embeddings(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -220,9 +222,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.model.tok_embeddings.as_linear(out) else: diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index 290cb83e..7b452ea4 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -155,11 +155,13 @@ class LlamaModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -182,9 +184,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py index 907beb2a..edddd583 100644 --- a/llms/mlx_lm/models/minicpm.py +++ b/llms/mlx_lm/models/minicpm.py @@ -158,11 +158,13 @@ class MiniCPMModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) * self.args.scale_emb - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -186,9 +188,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if not self.args.tie_word_embeddings: out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base)) diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py index dd94d1f4..0afd1235 100644 --- a/llms/mlx_lm/models/mixtral.py +++ b/llms/mlx_lm/models/mixtral.py @@ -162,11 +162,13 @@ class MixtralModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -188,9 +190,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return self.lm_head(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/nemotron.py b/llms/mlx_lm/models/nemotron.py index f73c0277..eabfac8c 100644 --- 
a/llms/mlx_lm/models/nemotron.py +++ b/llms/mlx_lm/models/nemotron.py @@ -176,11 +176,13 @@ class NemotronModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -203,9 +205,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py index 3627df06..4273b0ec 100644 --- a/llms/mlx_lm/models/olmo.py +++ b/llms/mlx_lm/models/olmo.py @@ -124,11 +124,13 @@ class Transformer(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.wte(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.blocks) @@ -152,9 +154,10 @@ class OlmoModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - return self.transformer(inputs, cache) + return self.transformer(inputs, mask, cache) class Model(nn.Module): @@ -167,9 +170,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - return self.model(inputs, cache) + return self.model(inputs, mask, cache) @property def layers(self): diff --git a/llms/mlx_lm/models/olmo2.py b/llms/mlx_lm/models/olmo2.py index 64d7e116..510ff882 100644 --- a/llms/mlx_lm/models/olmo2.py +++ b/llms/mlx_lm/models/olmo2.py @@ -163,10 +163,12 @@ class LlamaModel(nn.Module): self, inputs: mx.array, cache=None, + mask=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -190,8 +192,9 @@ class Model(nn.Module): self, inputs: mx.array, cache=None, + mask=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, cache, mask) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: diff --git a/llms/mlx_lm/models/openelm.py b/llms/mlx_lm/models/openelm.py index 408802f4..504fe95c 100644 --- a/llms/mlx_lm/models/openelm.py +++ b/llms/mlx_lm/models/openelm.py @@ -178,11 +178,13 @@ class OpenELMModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.token_embeddings(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -205,9 +207,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.transformer(inputs, cache) + out = self.transformer(inputs, mask, cache) if self.args.share_input_output_layers: out = self.transformer.token_embeddings.as_linear(out) else: diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py index 510025ea..e9724691 100644 --- a/llms/mlx_lm/models/phi.py +++ b/llms/mlx_lm/models/phi.py @@ -143,10 +143,11 @@ class PhiModel(nn.Module): config.hidden_size, eps=config.layer_norm_eps ) - def __call__(self, x, cache): + def __call__(self, x, mask, cache): x = self.embed_tokens(x) - mask = create_attention_mask(x, cache) + if mask is None: + mask = 
create_attention_mask(x, cache) if cache is None: cache = [None] * len(self.layers) @@ -167,9 +168,10 @@ class Model(nn.Module): def __call__( self, x: mx.array, + mask: mx.array = None, cache=None, ) -> mx.array: - y = self.model(x, cache) + y = self.model(x, mask, cache) return self.lm_head(y) @property diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py index ee6efc49..d1c21e25 100644 --- a/llms/mlx_lm/models/phi3.py +++ b/llms/mlx_lm/models/phi3.py @@ -168,11 +168,13 @@ class Phi3Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -194,9 +196,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return self.lm_head(out) @property diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py index 53e1a638..cd566eec 100644 --- a/llms/mlx_lm/models/phi3small.py +++ b/llms/mlx_lm/models/phi3small.py @@ -258,13 +258,15 @@ class Phi3Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) if self.mup_embedding_multiplier: h = self.mup_embedding_multiplier * h - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -290,9 +292,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.embed_tokens.as_linear(out) if self.mup_width_multiplier: out = out / self.mup_width_multiplier diff --git a/llms/mlx_lm/models/phimoe.py b/llms/mlx_lm/models/phimoe.py index f42a6dd0..bddcb128 100644 --- a/llms/mlx_lm/models/phimoe.py +++ b/llms/mlx_lm/models/phimoe.py @@ -155,11 +155,13 @@ class PhiMoEModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ) -> mx.array: h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -181,9 +183,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return self.lm_head(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py index 42d647b0..5477c2c0 100644 --- a/llms/mlx_lm/models/phixtral.py +++ b/llms/mlx_lm/models/phixtral.py @@ -175,7 +175,9 @@ class Model(nn.Module): mask: mx.array = None, cache=None, ) -> mx.array: - mask = create_attention_mask(x, cache) + + if mask is None: + mask = create_attention_mask(x, cache) y = self.transformer(x, mask, cache) return self.lm_head(y) diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py index c8e5bf50..9107daad 100644 --- a/llms/mlx_lm/models/plamo.py +++ b/llms/mlx_lm/models/plamo.py @@ -174,10 +174,12 @@ class PlamoModel(nn.Module): self, inputs: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ) -> mx.array: h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache 
is None: cache = [None for _ in range(len(self.layers.layers))] @@ -202,8 +204,9 @@ class Model(nn.Module): self, inputs: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ) -> mx.array: - out = self.model(inputs, cache) + out = self.model(inputs, cache, mask) return self.lm_head(out) @property diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py index 8145a890..ec8a0199 100644 --- a/llms/mlx_lm/models/qwen.py +++ b/llms/mlx_lm/models/qwen.py @@ -123,7 +123,8 @@ class QwenModel(nn.Module): def __call__(self, inputs, mask=None, cache=None): x = self.wte(inputs) - mask = create_attention_mask(x, cache) + if mask is None: + mask = create_attention_mask(x, cache) if cache is None: cache = [None] * len(self.h) diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py index fac59d78..381767c4 100644 --- a/llms/mlx_lm/models/qwen2.py +++ b/llms/mlx_lm/models/qwen2.py @@ -149,11 +149,13 @@ class Qwen2Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -176,9 +178,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py index 167fc5dd..c6aba622 100644 --- a/llms/mlx_lm/models/qwen2_moe.py +++ b/llms/mlx_lm/models/qwen2_moe.py @@ -187,11 +187,13 @@ class Qwen2MoeModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -213,9 +215,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return self.lm_head(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py index 49e4bb8f..ad07d925 100644 --- a/llms/mlx_lm/models/recurrent_gemma.py +++ b/llms/mlx_lm/models/recurrent_gemma.py @@ -389,6 +389,7 @@ class Griffin(nn.Module): def __call__( self, tokens, + mask: mx.array = None, cache=None, ): x = self.embed_tokens(tokens) @@ -402,7 +403,8 @@ class Griffin(nn.Module): if block.temporal_block_type != "recurrent": mask_cache = [cache[i]] - mask = create_attention_mask(x, mask_cache) + if mask is None: + mask = create_attention_mask(x, mask_cache) for i, block in enumerate(self.layers): x = block(x, mask=mask, cache=cache[i]) @@ -418,12 +420,12 @@ class Model(nn.Module): self.model_type = config.model_type self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - def __call__(self, tokens: mx.array, cache=None) -> mx.array: + def __call__(self, tokens: mx.array, mask: mx.array = None, cache=None) -> mx.array: """ Args: tokens: Sequence of input tokens. 
""" - logits = self.model(tokens, cache=cache) + logits = self.model(tokens, mask=mask, cache=cache) if "lm_head" in self: logits = self.lm_head(logits) else: diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py index 482bb324..0bbc2ca4 100644 --- a/llms/mlx_lm/models/stablelm.py +++ b/llms/mlx_lm/models/stablelm.py @@ -199,7 +199,10 @@ class Model(nn.Module): mask: mx.array = None, cache=None, ) -> mx.array: - mask = create_attention_mask(x, cache) + + if mask is None: + mask = create_attention_mask(x, cache) + y = self.model(x, mask, cache) return self.lm_head(y) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index d7e626f2..71c397f6 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -125,11 +125,13 @@ class Starcoder2Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -152,9 +154,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 3097c522..7b4376bb 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -5,6 +5,7 @@ import mlx.core as mx import mlx.nn as nn from mlx.utils import tree_map from mlx_lm.models import rope_utils +from mlx_lm.models.base import create_causal_mask from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache @@ -128,6 +129,22 @@ class TestModels(unittest.TestCase): self.assertEqual(cache.offset, 22) self.assertTrue(mx.allclose(x, k[..., -2:, :])) + def test_causal_mask_lengths(self): + mx.random.seed(8) + B, N_q, T_q, N_kv, T_kv, D = (4, 8, 3, 2, 3, 2) + lengths = mx.array([1, 2, 3, 1]) + q = mx.random.uniform(shape=(B, N_q, T_q, D)) + k = mx.random.uniform(shape=(B, N_kv, T_kv, D)) + v = k + mask = create_causal_mask(T_q, 0, lengths=lengths) + + out1 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) + q[1, :, 2:] = mx.ones_like(q[1, :, 2:]) + k[1, :, 2:] = mx.ones_like(k[1, :, 2:]) + v[1, :, 2:] = mx.ones_like(v[1, :, 2:]) + out2 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) + self.assertTrue(mx.allclose(out1[1, :, :2], out2[1, :, :2])) + def test_rope(self): rope = rope_utils.initialize_rope(32, base=100, traditional=False) self.assertTrue(isinstance(rope, nn.RoPE)) @@ -162,10 +179,16 @@ class TestModels(unittest.TestCase): self.assertEqual(outputs.dtype, t) cache = make_prompt_cache(model) - outputs = model(inputs, cache) + outputs = model(inputs, cache=cache) self.assertEqual(outputs.shape, (1, 2, vocab_size)) self.assertEqual(outputs.dtype, t) + if model_type != "mamba": + mask = create_causal_mask(inputs.shape[1], 0).astype(t) + outputs = model(inputs, mask=mask) + self.assertEqual(outputs.shape, (1, 2, vocab_size)) + self.assertEqual(outputs.dtype, t) + outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache) self.assertEqual(outputs.shape, (1, 1, vocab_size)) self.assertEqual(outputs.dtype, t) From 3a58c361096e5be7a927e7719c5ef66bace9a8ab Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Wed, 1 Jan 2025 16:25:57 +0100 Subject: [PATCH 03/15] Improvements 
to mlx_lm.manage (#1178) * improvements to manage. Default value is N and size added to deletion confirmation. * Fixing case for no case * nits --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/manage.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/llms/mlx_lm/manage.py b/llms/mlx_lm/manage.py index bb5c3a09..9827f3dc 100644 --- a/llms/mlx_lm/manage.py +++ b/llms/mlx_lm/manage.py @@ -6,19 +6,18 @@ from transformers.commands.user import tabulate def ask_for_confirmation(message: str) -> bool: + """Ask user for confirmation with Y/N prompt. + Returns True for Y/yes, False for N/no/empty.""" y = ("y", "yes", "1") - n = ("n", "no", "0") - all_values = y + n + ("",) - full_message = f"{message} (Y/n) " + n = ("n", "no", "0", "") + full_message = f"{message} (y/n) " while True: answer = input(full_message).lower() - if answer == "": - return False if answer in y: return True if answer in n: return False - print(f"Invalid input. Must be one of {all_values}") + print(f"Invalid input. Must be one of: yes/no/y/n or empty for no") def main(): @@ -43,9 +42,7 @@ def main(): args = parser.parse_args() if args.scan: - print( - "Scanning Hugging Face cache for models with" f'pattern "{args.pattern}".' - ) + print(f'Scanning Hugging Face cache for models with pattern "{args.pattern}".') hf_cache_info = scan_cache_dir() print( tabulate( @@ -86,35 +83,41 @@ def main(): if args.pattern in repo.repo_id ] if repos: + print("\nFound the following models:") print( tabulate( rows=[ [ repo.repo_id, + repo.size_on_disk_str, # Added size information str(repo.repo_path), ] for repo in repos ], headers=[ "REPO ID", + "SIZE", # Added size header "LOCAL PATH", ], ) ) - confirmed = ask_for_confirmation(f"Confirm deletion ?") + confirmed = ask_for_confirmation( + "\nAre you sure you want to delete these models?" + ) if confirmed: for model_info in repos: + print(f"\nDeleting {model_info.repo_id}...") for revision in sorted( model_info.revisions, key=lambda revision: revision.commit_hash ): strategy = hf_cache_info.delete_revisions(revision.commit_hash) strategy.execute() - print("Model(s) deleted.") + print("\nModel(s) deleted successfully.") else: - print("Deletion is cancelled. 
Do nothing.") + print("\nDeletion cancelled - no changes made.") else: - print(f"No models found.") + print(f'No models found matching pattern "{args.pattern}"') if __name__ == "__main__": From c4833a2f55c4553f71b16a412a6eb6d2f1427380 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 3 Jan 2025 10:50:59 -0800 Subject: [PATCH 04/15] fix encoding with special tokens + chat template (#1189) --- llms/README.md | 4 +- llms/mlx_lm/cache_prompt.py | 20 ++---- llms/mlx_lm/chat.py | 4 +- llms/mlx_lm/evaluate.py | 28 ++++++--- llms/mlx_lm/examples/chat.py | 8 +-- llms/mlx_lm/examples/generate_response.py | 2 +- llms/mlx_lm/generate.py | 9 +-- llms/mlx_lm/lora.py | 2 + llms/mlx_lm/server.py | 6 +- llms/mlx_lm/tuner/datasets.py | 77 ++++++++++++----------- llms/mlx_lm/tuner/trainer.py | 8 +-- llms/mlx_lm/utils.py | 19 +++--- llms/tests/test_datsets.py | 5 +- 13 files changed, 95 insertions(+), 97 deletions(-) diff --git a/llms/README.md b/llms/README.md index 4fff4207..e943ed69 100644 --- a/llms/README.md +++ b/llms/README.md @@ -58,7 +58,7 @@ prompt = "Write a story about Einstein" messages = [{"role": "user", "content": prompt}] prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + messages, add_generation_prompt=True ) text = generate(model, tokenizer, prompt=prompt, verbose=True) @@ -115,7 +115,7 @@ prompt = "Write a story about Einstein" messages = [{"role": "user", "content": prompt}] prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + messages, add_generation_prompt=True ) for response in stream_generate(model, tokenizer, prompt, max_tokens=512): diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py index 9d7d1603..c18f1bae 100644 --- a/llms/mlx_lm/cache_prompt.py +++ b/llms/mlx_lm/cache_prompt.py @@ -110,29 +110,17 @@ def main(): if tokenizer.chat_template is None: tokenizer.chat_template = tokenizer.default_chat_template - if not args.ignore_chat_template and ( - hasattr(tokenizer, "apply_chat_template") - and tokenizer.chat_template is not None - ): + if not args.ignore_chat_template and tokenizer.chat_template is not None: messages = [{"role": "user", "content": args.prompt}] prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + messages, add_generation_prompt=False, continue_final_message=True ) - # Treat the prompt as a prefix assuming that the suffix will be - # provided at generation time. 
- test_prompt = tokenizer.apply_chat_template( - [{"role": "user", "content": ""}], - tokenize=False, - add_generation_prompt=True, - ) - n = len(test_prompt) - test_prompt.index("") - len("") - prompt = prompt[:-n] else: - prompt = args.prompt + prompt = tokenizer.encode(args.prompt) cache = make_prompt_cache(model, args.max_kv_size) - y = mx.array(tokenizer.encode(prompt)) + y = mx.array(prompt) # Process the prompt start = time.time() diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py index 5a8245ef..e52ad10d 100644 --- a/llms/mlx_lm/chat.py +++ b/llms/mlx_lm/chat.py @@ -72,9 +72,7 @@ def main(): if query == "q": break messages = [{"role": "user", "content": query}] - prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) + prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) for response in stream_generate( model, tokenizer, diff --git a/llms/mlx_lm/evaluate.py b/llms/mlx_lm/evaluate.py index c4b15748..bf7bf4d4 100644 --- a/llms/mlx_lm/evaluate.py +++ b/llms/mlx_lm/evaluate.py @@ -1,4 +1,8 @@ -# Adapted from a PyTorch implementation by David Grangier +# Copyright © 2024 Apple Inc. + +""" +Adapted from a PyTorch implementation by David Grangier +""" import argparse import json @@ -6,7 +10,7 @@ import logging import os from importlib.metadata import version from pathlib import Path -from typing import Optional +from typing import Optional, Union import lm_eval import mlx.core as mx @@ -277,19 +281,19 @@ class MLXLM(LM): assert "until" in keys untils = [x["until"] for x in options] completions = [] + for context, until in tqdm(zip(contexts, untils), total=len(contexts)): - if ( - hasattr(self._tokenizer, "apply_chat_template") - and self._tokenizer.chat_template is not None - ): + if self._tokenizer.chat_template is not None: messages = [{"role": "user", "content": context}] context = self._tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + messages, add_generation_prompt=True ) + else: + context = self._tokenizer.encode(context) max_tokens = min( self._max_tokens, - self._tokenizer.model_max_length - len(self._tokenizer.encode(context)), + self._tokenizer.model_max_length - len(context), ) text = "" for response in stream_generate( @@ -321,6 +325,12 @@ def main(): type=int, help="Maximum nunber of tokens to generate. Defaults to the model's max context length.", ) + parser.add_argument( + "--limit", + default=1.0, + help="Limit the number of examples per task.", + type=float, + ) parser.add_argument("--seed", type=int, default=123, help="Random seed.") args = parser.parse_args() @@ -338,10 +348,12 @@ def main(): model=lm, tasks=args.tasks, num_fewshot=args.num_shots, + limit=args.limit, random_seed=args.seed, numpy_random_seed=args.seed, torch_random_seed=args.seed, fewshot_random_seed=args.seed, + apply_chat_template=True, ) model_name = args.model.replace("/", "_") diff --git a/llms/mlx_lm/examples/chat.py b/llms/mlx_lm/examples/chat.py index c7512b3c..4a7020f1 100644 --- a/llms/mlx_lm/examples/chat.py +++ b/llms/mlx_lm/examples/chat.py @@ -15,9 +15,7 @@ prompt_cache = make_prompt_cache(model) # User turn prompt = "Hi my name is ." messages = [{"role": "user", "content": prompt}] -prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True -) +prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) # Assistant response response = generate( @@ -32,9 +30,7 @@ response = generate( # User turn prompt = "What's my name?" 
messages = [{"role": "user", "content": prompt}] -prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True -) +prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) # Assistant response response = generate( diff --git a/llms/mlx_lm/examples/generate_response.py b/llms/mlx_lm/examples/generate_response.py index e6535b47..41eaf1da 100644 --- a/llms/mlx_lm/examples/generate_response.py +++ b/llms/mlx_lm/examples/generate_response.py @@ -14,7 +14,7 @@ conversation = [{"role": "user", "content": prompt}] # Transform the prompt into the chat template prompt = tokenizer.apply_chat_template( - conversation=conversation, tokenize=False, add_generation_prompt=True + conversation=conversation, add_generation_prompt=True ) # Specify the maximum number of tokens diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index afb1394e..1ea66384 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -190,10 +190,7 @@ def main(): prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t") prompt = sys.stdin.read() if prompt == "-" else prompt - if not args.ignore_chat_template and ( - hasattr(tokenizer, "apply_chat_template") - and tokenizer.chat_template is not None - ): + if not args.ignore_chat_template and tokenizer.chat_template is not None: if args.system_prompt is not None: messages = [{"role": "system", "content": args.system_prompt}] else: @@ -214,6 +211,10 @@ def main(): ) prompt = prompt[test_prompt.index("") :] + prompt = tokenizer.encode(prompt, add_special_tokens=False) + else: + prompt = tokenizer.encode(prompt) + sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep) response = generate( model, diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index c96e75a7..6fb86917 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -2,6 +2,7 @@ import argparse import math +import os import re import types from pathlib import Path @@ -271,6 +272,7 @@ def run(args, training_callback: TrainingCallback = None): def main(): + os.environ["TOKENIZERS_PARALLELISM"] = "true" parser = build_parser() args = parser.parse_args() config = args.config diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index c12513ff..4523e3ae 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -590,14 +590,10 @@ class APIHandler(BaseHTTPRequestHandler): # Determine response type self.request_id = f"chatcmpl-{uuid.uuid4()}" self.object_type = "chat.completion.chunk" if self.stream else "chat.completion" - if ( - hasattr(self.tokenizer, "apply_chat_template") - and self.tokenizer.chat_template - ): + if self.tokenizer.chat_template: prompt = self.tokenizer.apply_chat_template( body["messages"], body.get("tools", None), - tokenize=True, add_generation_prompt=True, ) else: diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 20b32eff..fa848f47 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -10,41 +10,47 @@ class Dataset: Light-weight wrapper to hold a dataset. 
""" - def __init__(self, data: List[Dict[str, str]], text_key: str = "text"): - self._text_key = text_key - self._data = data + def __init__( + self, + data: List[Dict[str, str]], + tokenizer: PreTrainedTokenizer, + text_key: str = "text", + ): + self._data = [tokenizer.encode(d[text_key]) for d in data] + for d in self._data: + if d[-1] != tokenizer.eos_token_id: + d.append(tokenizer.eos_token_id) def __getitem__(self, idx: int): - return self._data[idx][self._text_key] + return self._data[idx] def __len__(self): - if self._data is None: - return 0 return len(self._data) -class ChatDataset(Dataset): +class ChatDataset: """ A dataset for chat data in the format of {"messages": [...]} https://platform.openai.com/docs/guides/fine-tuning/example-format """ def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer): - super().__init__(data) - self._tokenizer = tokenizer + self._data = [ + tokenizer.apply_chat_template( + d["messages"], + tools=d.get("tools", None), + ) + for d in data + ] def __getitem__(self, idx: int): - messages = self._data[idx]["messages"] - text = self._tokenizer.apply_chat_template( - messages, - tools=self._data[idx].get("tools", None), - tokenize=False, - add_generation_prompt=True, - ) - return text + return self._data[idx] + + def __len__(self): + return len(self._data) -class CompletionsDataset(Dataset): +class CompletionsDataset: """ A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...} or using user-provided keys for prompt and completion values @@ -58,25 +64,24 @@ class CompletionsDataset(Dataset): prompt_key: str = "prompt", completion_key: str = "completion", ): - super().__init__(data) - self._tokenizer = tokenizer - self._prompt_key = prompt_key - self._completion_key = completion_key + self._data = [ + tokenizer.apply_chat_template( + [ + {"role": "user", "content": d[prompt_key]}, + {"role": "assistant", "content": d[completion_key]}, + ], + ) + for d in data + ] def __getitem__(self, idx: int): - data = self._data[idx] - text = self._tokenizer.apply_chat_template( - [ - {"role": "user", "content": data[self._prompt_key]}, - {"role": "assistant", "content": data[self._completion_key]}, - ], - tokenize=False, - add_generation_prompt=True, - ) - return text + return self._data[idx] + + def __len__(self): + return len(self._data) -def create_dataset(data, tokenizer: PreTrainedTokenizer = None): +def create_dataset(data, tokenizer: PreTrainedTokenizer): sample = data[0] if "messages" in sample: @@ -84,7 +89,7 @@ def create_dataset(data, tokenizer: PreTrainedTokenizer = None): elif "prompt" in sample and "completion" in sample: return CompletionsDataset(data, tokenizer) elif "text" in sample: - return Dataset(data) + return Dataset(data, tokenizer) else: raise ValueError( "Unsupported data format, check the supported formats here:\n" @@ -143,7 +148,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): if prompt_feature and completion_feature: return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature) elif text_feature: - return Dataset(train_ds, text_key=text_feature) + return Dataset(train_ds, tokenizer, text_key=text_feature) else: raise ValueError( "Specify either a prompt and completion feature or a text " @@ -166,7 +171,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): def load_dataset(args, tokenizer: PreTrainedTokenizer): - if getattr(args, "hf_dataset", None) is not None: + if getattr(args, "hf_dataset", False): train, valid, test = 
load_custom_hf_dataset(args, tokenizer) else: data_path = Path(args.data) diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 21b1af18..a76b8336 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -100,14 +100,8 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) while True: indices = np.random.permutation(len(batch_idx)) for i in indices: - # Encode batch - batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]] - for b in batch: - if b[-1] != tokenizer.eos_token_id: - b.append(tokenizer.eos_token_id) - + batch = [dataset[j] for j in batch_idx[i]] lengths = [len(x) for x in batch] - if max(lengths) > max_seq_length: print( f"[WARNING] Some sequences are longer than {max_seq_length} tokens. " diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 4d69115e..0c35d07f 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -353,9 +353,13 @@ def stream_generate( tokenizer = TokenizerWrapper(tokenizer) if not isinstance(prompt, mx.array): - prompt = mx.array( - prompt if isinstance(prompt, list) else tokenizer.encode(prompt) - ) + if isinstance(prompt, str): + # Try to infer if special tokens are needed + add_special_tokens = tokenizer.bos_token is None or not prompt.startswith( + tokenizer.bos_token + ) + prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens) + prompt = mx.array(prompt) detokenizer = tokenizer.detokenizer @@ -401,7 +405,7 @@ def stream_generate( def generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], - prompt: str, + prompt: Union[str, List[int]], verbose: bool = False, formatter: Optional[Callable] = None, **kwargs, @@ -412,7 +416,7 @@ def generate( Args: model (nn.Module): The language model. tokenizer (PreTrainedTokenizer): The tokenizer. - prompt (str): The string prompt. + prompt (Union[str, List[int]]): The input prompt string or integer tokens. verbose (bool): If ``True``, print tokens and timing information. Default: ``False``. kwargs: The remaining options get passed to :func:`stream_generate`. 
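A rough usage sketch of the updated prompt handling (the model repo and prompt text below are placeholders, not part of this patch): with the default tokenize=True, apply_chat_template now returns token ids, which generate() accepts directly, and the new top_k option from the first patch in this series can be passed through make_sampler.

    from mlx_lm import load, generate
    from mlx_lm.sample_utils import make_sampler

    # Placeholder model repo, for illustration only.
    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

    messages = [{"role": "user", "content": "Write a haiku about the sea."}]
    # tokenize defaults to True, so this returns a list of token ids.
    prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

    # Sample from the 40 most probable tokens at temperature 0.8.
    sampler = make_sampler(temp=0.8, top_k=40)

    text = generate(model, tokenizer, prompt=prompt, sampler=sampler, max_tokens=128)
    print(text)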
@@ -425,7 +429,6 @@ def generate( ) if verbose: print("=" * 10) - print("Prompt:", prompt) text = "" for response in stream_generate(model, tokenizer, prompt, **kwargs): @@ -654,10 +657,10 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str): prompt="hello" - if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None: + if tokenizer.chat_template is not None: messages = [{{"role": "user", "content": prompt}}] prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + messages, add_generation_prompt=True ) response = generate(model, tokenizer, prompt=prompt, verbose=True) diff --git a/llms/tests/test_datsets.py b/llms/tests/test_datsets.py index 240bfb4a..dd86d277 100644 --- a/llms/tests/test_datsets.py +++ b/llms/tests/test_datsets.py @@ -36,7 +36,8 @@ class TestDatasets(unittest.TestCase): data = {"text": "This is an example for the model."} self.save_data(4 * [data]) args = types.SimpleNamespace(train=True, test=False, data=self.test_dir) - train, valid, test = datasets.load_dataset(args, None) + tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH) + train, valid, test = datasets.load_dataset(args, tokenizer) self.assertEqual(len(train), 4) self.assertEqual(len(valid), 4) self.assertEqual(len(test), 0) @@ -82,6 +83,8 @@ class TestDatasets(unittest.TestCase): "name": "billsum", "prompt_feature": "text", "completion_feature": "summary", + "train_split": "train[:2%]", + "valid_split": "train[-2%:]", }, test=False, train=True, From 25ec2d8c4496be68acf7e0c9ea1ae4269e1a2a19 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sun, 5 Jan 2025 22:26:05 -0800 Subject: [PATCH 05/15] Change the eos-token argument for mlx_lm.generate (#1176) --- llms/mlx_lm/generate.py | 9 +++++---- llms/mlx_lm/tokenizer_utils.py | 12 ++++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 1ea66384..3301edae 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -43,10 +43,11 @@ def setup_arg_parser(): help="Optional path for the trained adapter weights and config.", ) parser.add_argument( - "--eos-token", + "--extra-eos-token", type=str, default=None, - help="End of sequence token for tokenizer", + nargs="+", + help="Add tokens in the list of eos tokens that stop generation.", ) parser.add_argument( "--system-prompt", @@ -161,8 +162,6 @@ def main(): {} if not using_cache else json.loads(metadata["tokenizer_config"]) ) tokenizer_config["trust_remote_code"] = True - if args.eos_token is not None: - tokenizer_config["eos_token"] = args.eos_token model_path = args.model if using_cache: @@ -181,6 +180,8 @@ def main(): adapter_path=args.adapter_path, tokenizer_config=tokenizer_config, ) + for eos_token in args.extra_eos_token: + tokenizer.add_eos_token(eos_token) if args.use_default_chat_template: if tokenizer.chat_template is None: diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index ca3d6c06..1b5bdd77 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -266,6 +266,18 @@ class TokenizerWrapper: else {tokenizer.eos_token_id} ) + def add_eos_token(self, token: str): + token_id = None + try: + token_id = int(token) + except ValueError: + token_id = self._tokenizer.convert_tokens_to_ids(token) + + if token_id is None: + raise ValueError(f"'{token}' is not a token for this tokenizer") + + self._eos_token_ids.add(token_id) + def __getattr__(self, attr): if attr == "detokenizer": return 
self._detokenizer From f2619f507c7dcde70410cc2cbb1d4715476d79ee Mon Sep 17 00:00:00 2001 From: Chime Ogbuji Date: Mon, 6 Jan 2025 10:58:43 -0500 Subject: [PATCH 06/15] Add support for fewshot and apply chat template lm_eval functionality (#1180) * Add support for multiturn fewshot examples and chat templates Added two new arguments to the evaluation script: `--fewshot-as-multiturn` and `--apply-chat-template` which correspond to lm_eval options of similar names and are very often used to ensure apples-to-apples comparisons of lm_evaluation results * Add HF overrides for methods needed by added options * don't add duplicate bos --------- Co-authored-by: Awni Hannun --- .circleci/config.yml | 2 +- llms/mlx_lm/evaluate.py | 59 +++++++++++++++++++++++++++++------------ llms/setup.py | 4 +-- 3 files changed, 45 insertions(+), 20 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index cecd2d57..8367281e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -32,7 +32,7 @@ jobs: pip install --upgrade pip pip install unittest-xml-reporting cd llms/ - pip install -e ".[testing]" + pip install -e ".[test]" - run: name: Run Python tests command: | diff --git a/llms/mlx_lm/evaluate.py b/llms/mlx_lm/evaluate.py index bf7bf4d4..ca5e83bb 100644 --- a/llms/mlx_lm/evaluate.py +++ b/llms/mlx_lm/evaluate.py @@ -77,15 +77,19 @@ class MLXLM(LM): path_or_hf_repo: str, batch_size: int = 16, max_tokens: Optional[int] = None, + use_chat_template: Optional[bool] = None, ) -> None: super().__init__() self._batch_size = batch_size - self._model, self._tokenizer = load(path_or_hf_repo) - self._max_tokens = max_tokens or self._tokenizer.model_max_length + self._model, self.tokenizer = load(path_or_hf_repo) + self._max_tokens = max_tokens or self.tokenizer.model_max_length + self.use_chat_template = use_chat_template or ( + self.tokenizer.chat_template is not None + ) def _score_fn(self, inputs, tokenize=True, step_size=32): if tokenize: - inputs = self._tokenizer.encode(inputs) + inputs = self._tokenize(inputs) inputs = _pad_inputs(inputs, self._max_tokens, truncate=False) inputs = mx.array(inputs) inputs, targets = inputs[..., :-1], inputs[..., 1:] @@ -149,7 +153,12 @@ class MLXLM(LM): return results def _tokenize(self, texts): - return [tuple(self._tokenizer.encode(t)) for t in texts] + return [ + tuple( + self.tokenizer.encode(t, add_special_tokens=not self.use_chat_template) + ) + for t in texts + ] def loglikelihood(self, requests) -> list[tuple[float, bool]]: """Compute log-likelihood of generating a continuation from a context. @@ -221,6 +230,9 @@ class MLXLM(LM): ) return [(r[0], r[1] == r[2]) for r in results] + tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name + apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template + def loglikelihood_rolling(self, requests) -> list[float]: """Compute full log-likelihood of a string, with no truncation, for perplexity computation - We will use the full max context length of the model. 
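A hedged sketch of driving the new evaluation options from Python instead of the CLI (the model repo and task name are placeholders, not taken from this patch):

    import lm_eval
    from mlx_lm.evaluate import MLXLM

    # Placeholder model repo and task, for illustration only.
    lm = MLXLM(
        "mlx-community/Llama-3.2-1B-Instruct-4bit",
        batch_size=16,
        use_chat_template=True,
    )

    results = lm_eval.simple_evaluate(
        model=lm,
        tasks=["arc_easy"],
        num_fewshot=5,
        # Options wired through by this patch.
        fewshot_as_multiturn=True,
        apply_chat_template=lm.use_chat_template,
    )
    print(results["results"])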
@@ -283,21 +295,14 @@ class MLXLM(LM): completions = [] for context, until in tqdm(zip(contexts, untils), total=len(contexts)): - if self._tokenizer.chat_template is not None: - messages = [{"role": "user", "content": context}] - context = self._tokenizer.apply_chat_template( - messages, add_generation_prompt=True - ) - else: - context = self._tokenizer.encode(context) - + context = self._tokenize(context) max_tokens = min( self._max_tokens, - self._tokenizer.model_max_length - len(context), + self.tokenizer.model_max_length - len(context), ) text = "" for response in stream_generate( - self._model, self._tokenizer, prompt=context, max_tokens=max_tokens + self._model, self.tokenizer, prompt=context, max_tokens=max_tokens ): text += response.text if any(u in text for u in until): @@ -332,6 +337,21 @@ def main(): type=float, ) parser.add_argument("--seed", type=int, default=123, help="Random seed.") + parser.add_argument( + "--fewshot-as-multiturn", + action="store_true", + help="Whether to provide the fewshot examples as a multiturn " + "conversation or a single user turn.", + default=False, + ) + parser.add_argument( + "--apply-chat-template", + action=argparse.BooleanOptionalAction, + help="Specifies whether to apply a chat template to the prompt. If " + "the model has a chat template, this defaults to `True`, " + "otherwise `False`.", + default=None, + ) args = parser.parse_args() output_dir = Path(args.output_dir) @@ -342,18 +362,23 @@ def main(): mx.random.seed(args.seed) - lm = MLXLM(args.model, batch_size=args.batch_size, max_tokens=args.max_tokens) - + lm = MLXLM( + args.model, + batch_size=args.batch_size, + max_tokens=args.max_tokens, + use_chat_template=args.apply_chat_template, + ) results = lm_eval.simple_evaluate( model=lm, tasks=args.tasks, + fewshot_as_multiturn=args.fewshot_as_multiturn, + apply_chat_template=lm.use_chat_template, num_fewshot=args.num_shots, limit=args.limit, random_seed=args.seed, numpy_random_seed=args.seed, torch_random_seed=args.seed, fewshot_random_seed=args.seed, - apply_chat_template=True, ) model_name = args.model.replace("/", "_") diff --git a/llms/setup.py b/llms/setup.py index b88dcd33..e6fddbae 100644 --- a/llms/setup.py +++ b/llms/setup.py @@ -27,8 +27,8 @@ setup( packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"], python_requires=">=3.8", extras_require={ - "testing": ["datasets"], - "evaluation": ["lm-eval"], + "test": ["datasets"], + "evaluate": ["lm-eval", "tqdm"], }, entry_points={ "console_scripts": [ From 9183fe8b6d6b5e86cac0f47b54675f272c9f3591 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 6 Jan 2025 10:12:07 -0800 Subject: [PATCH 07/15] fix (#1192) --- llms/mlx_lm/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 3301edae..26481d6b 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -45,7 +45,7 @@ def setup_arg_parser(): parser.add_argument( "--extra-eos-token", type=str, - default=None, + default=(), nargs="+", help="Add tokens in the list of eos tokens that stop generation.", ) From b8f0cacfa8dd08aaca7025351a7afddd481ca490 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 7 Jan 2025 18:18:31 +0100 Subject: [PATCH 08/15] Use upload_large_folder (#1193) --- llms/mlx_lm/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 0c35d07f..ad79349e 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -673,12 +673,10 @@ def upload_to_hub(path: 
str, upload_repo: str, hf_path: str): api = HfApi() api.create_repo(repo_id=upload_repo, exist_ok=True) - api.upload_folder( + api.upload_large_folder( folder_path=path, repo_id=upload_repo, repo_type="model", - multi_commits=True, - multi_commits_verbose=True, ) print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.") From 40b88eff488d82b8d8739de6d60f59c1f0789a14 Mon Sep 17 00:00:00 2001 From: Jarrett <2613089+jjaareet@users.noreply.github.com> Date: Thu, 9 Jan 2025 12:33:54 -0700 Subject: [PATCH 09/15] fix(lora): config yaml & arg default merge bug (#1196) --- llms/mlx_lm/lora.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 6fb86917..4d050bd5 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -58,6 +58,8 @@ CONFIG_DEFAULTS = { "test": False, "test_batches": 500, "max_seq_length": 2048, + "config": None, + "grad_checkpoint": False, "lr_schedule": None, "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}, } @@ -67,6 +69,7 @@ def build_parser(): parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.") parser.add_argument( "--model", + type=str, help="The path to the local model directory or Hugging Face repo.", ) @@ -75,7 +78,6 @@ def build_parser(): "--train", action="store_true", help="Do training", - default=None, ) parser.add_argument( "--data", @@ -89,7 +91,6 @@ def build_parser(): "--fine-tune-type", type=str, choices=["lora", "dora", "full"], - default="lora", help="Type of fine-tuning to perform: lora, dora, or full.", ) parser.add_argument( @@ -134,7 +135,6 @@ def build_parser(): "--test", action="store_true", help="Evaluate on the test set after training", - default=None, ) parser.add_argument( "--test-batches", @@ -149,16 +149,15 @@ def build_parser(): parser.add_argument( "-c", "--config", - default=None, + type=str, help="A YAML configuration file with the training options", ) parser.add_argument( "--grad-checkpoint", action="store_true", help="Use gradient checkpointing to reduce memory use.", - default=None, ) - parser.add_argument("--seed", type=int, default=None, help="The PRNG seed") + parser.add_argument("--seed", type=int, help="The PRNG seed") return parser From 5cae0a60e6acb3599483a9304aebbc89e0bff1c4 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 9 Jan 2025 15:55:53 -0800 Subject: [PATCH 10/15] deepseek v3 model with pipeline parallelism (#1191) * deepseekv3 * use upload_large_file instead of deprecated multi comit * add pipeline generation and example * comment * get fp16 working * use mlx==0.22 --- llms/mlx_lm/_version.py | 2 +- llms/mlx_lm/examples/pipeline_generate.py | 75 ++++ llms/mlx_lm/models/deepseek_v3.py | 460 ++++++++++++++++++++++ llms/mlx_lm/requirements.txt | 2 +- llms/mlx_lm/utils.py | 4 +- llms/tests/test_models.py | 37 ++ llms/tests/test_utils_load_model.py | 2 +- 7 files changed, 577 insertions(+), 5 deletions(-) create mode 100644 llms/mlx_lm/examples/pipeline_generate.py create mode 100644 llms/mlx_lm/models/deepseek_v3.py diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py index 3af2d5fd..b2f98e6f 100644 --- a/llms/mlx_lm/_version.py +++ b/llms/mlx_lm/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. 
-__version__ = "0.20.4" +__version__ = "0.21.0" diff --git a/llms/mlx_lm/examples/pipeline_generate.py b/llms/mlx_lm/examples/pipeline_generate.py new file mode 100644 index 00000000..b98e757b --- /dev/null +++ b/llms/mlx_lm/examples/pipeline_generate.py @@ -0,0 +1,75 @@ +# Copyright © 2024 Apple Inc. + +""" +Run with: + +``` +/path/to/mpirun \ + -np 2 \ + --hostfile /path/to/hosts.txt \ + python /path/to/pipeline_generate.py --prompt "hello world" +``` + +Make sure you can run MLX over MPI on two hosts. For more information see the +documentation: + +https://ml-explore.github.io/mlx/build/html/usage/distributed.html). +""" + +import argparse + +import mlx.core as mx +from mlx_lm import load, stream_generate + +parser = argparse.ArgumentParser(description="LLM pipelined inference example") +parser.add_argument( + "--prompt", + "-p", + default="Write a quicksort in C++.", + help="Message to be processed by the model ('-' reads from stdin)", +) +parser.add_argument( + "--max-tokens", + "-m", + type=int, + default=256, + help="Maximum number of tokens to generate", +) +args = parser.parse_args() + +model_repo = "mlx-community/DeepSeek-V3-3bit" + +model, tokenizer = load(model_repo, lazy=True) + +messages = [{"role": "user", "content": args.prompt}] +prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) + +group = mx.distributed.init() +rank = group.rank() +model.model.pipeline(group) +mx.eval(model.parameters()) + +# Synchronize processes before generation to avoid timeout if downloading +# model for the first time. +mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu)) + + +def rprint(*args, **kwargs): + if rank == 0: + print(*args, **kwargs) + + +for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens): + rprint(response.text, end="", flush=True) + +rprint() +rprint("=" * 10) +rprint( + f"Prompt: {response.prompt_tokens} tokens, " + f"{response.prompt_tps:.3f} tokens-per-sec" +) +rprint( + f"Generation: {response.generation_tokens} tokens, " + f"{response.generation_tps:.3f} tokens-per-sec" +) +rprint(f"Peak memory: {response.peak_memory:.3f} GB") diff --git a/llms/mlx_lm/models/deepseek_v3.py b/llms/mlx_lm/models/deepseek_v3.py new file mode 100644 index 00000000..f95949f9 --- /dev/null +++ b/llms/mlx_lm/models/deepseek_v3.py @@ -0,0 +1,460 @@ +# Copyright © 2024 Apple Inc. 
+ +import math +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .switch_layers import SwitchGLU + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str = "deepseek_v3" + vocab_size: int = 102400 + hidden_size: int = 4096 + intermediate_size: int = 11008 + moe_intermediate_size: int = 1407 + num_hidden_layers: int = 30 + num_attention_heads: int = 32 + num_key_value_heads: int = 32 + n_shared_experts: Optional[int] = None + n_routed_experts: Optional[int] = None + routed_scaling_factor: float = 1.0 + kv_lora_rank: int = 512 + q_lora_rank: int = 1536 + qk_rope_head_dim: int = 64 + v_head_dim: int = 128 + qk_nope_head_dim: int = 128 + topk_method: str = "noaux_tc" + scoring_func: str = "sigmoid" + norm_topk_prob: bool = True + n_group: Optional[int] = None + topk_group: Optional[int] = None + num_experts_per_tok: Optional[int] = None + moe_layer_freq: int = 1 + first_k_dense_replace: int = 0 + max_position_embeddings: int = 2048 + rms_norm_eps: float = 1e-6 + rope_theta: float = 10000.0 + rope_scaling: Dict = None + attention_bias: bool = False + + +def yarn_find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 +): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +def yarn_find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_linear_ramp_mask(min_val, max_val, dim): + if min_val == max_val: + max_val += 0.001 # Prevent singularity + + linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val) + return mx.clip(linear_func, 0, 1) + + +class DeepseekV3YarnRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + super().__init__() + self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale( + scaling_factor, mscale_all_dim + ) + freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim) + freq_inter = scaling_factor * base ** ( + mx.arange(0, dim, 2, dtype=mx.float32) / dim + ) + low, high = yarn_find_correction_range( + beta_fast, + beta_slow, + dim, + base, + original_max_position_embeddings, + ) + freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) + self._freqs = (freq_inter * freq_extra) / ( + freq_inter * freq_mask + freq_extra * (1 - freq_mask) + ) + + def __call__(self, x, offset=0): + if self.mscale != 1.0: + x = self.mscale * x + return mx.fast.rope( + x, + x.shape[-1], + traditional=True, + base=None, + scale=1.0, + offset=offset, + freqs=self._freqs, + ) + + +class DeepseekV3Attention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + 
self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.scale = self.q_head_dim**-0.5 + + if self.q_lora_rank is None: + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.q_head_dim, bias=False + ) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, self.q_lora_rank, bias=config.attention_bias + ) + self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank) + self.q_b_proj = nn.Linear( + self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False + ) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank) + self.kv_b_proj = nn.Linear( + self.kv_lora_rank, + self.num_heads + * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + ) + + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scale = self.scale * mscale * mscale + + rope_kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rope = DeepseekV3YarnRotaryEmbedding( + dim=self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **rope_kwargs, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + if self.q_lora_rank is None: + q = self.q_proj(x) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x))) + + q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3) + q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1) + compressed_kv = self.kv_a_proj_with_mqa(x) + compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1) + k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3) + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3) + + k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1) + + if cache is not None: + q_pe = self.rope(q_pe, cache.offset) + k_pe = self.rope(k_pe, cache.offset) + k_pe = mx.repeat(k_pe, self.num_heads, axis=1) + keys, values = cache.update_and_fetch( + mx.concatenate([k_nope, k_pe], axis=-1), values + ) + else: + q_pe = self.rope(q_pe) + k_pe = self.rope(k_pe) + k_pe = mx.repeat(k_pe, self.num_heads, axis=1) + keys = mx.concatenate([k_nope, k_pe], axis=-1) + + queries = mx.concatenate([q_nope, q_pe], axis=-1) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class DeepseekV3MLP(nn.Module): + def __init__( + self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None 
else hidden_size + self.intermediate_size = ( + config.intermediate_size if intermediate_size is None else intermediate_size + ) + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + def __call__(self, x): + down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class MoEGate(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group + self.weight = mx.zeros((self.n_routed_experts, config.hidden_size)) + self.e_score_correction_bias = mx.zeros((self.n_routed_experts,)) + + def __call__(self, x): + gates = x @ self.weight.T + + scores = mx.sigmoid(gates.astype(mx.float32)) + + assert self.topk_method == "noaux_tc", "Unsupported topk method." + bsz, seq_len = x.shape[:2] + scores = scores + self.e_score_correction_bias + scores = scores.reshape(bsz, seq_len, self.n_group, -1) + group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1) + k = self.n_group - self.topk_group + group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k] + batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2)) + seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2)) + scores[batch_idx, seq_idx, group_idx] = 0.0 + scores = scores.reshape(bsz, seq_len, -1) + + k = self.top_k + inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k] + scores = mx.take_along_axis(scores, inds, axis=-1) + if self.top_k > 1 and self.norm_topk_prob: + denominator = scores.sum(axis=-1, keepdims=True) + 1e-20 + scores = scores / denominator + scores = scores * self.routed_scaling_factor + + return inds, scores + + +class DeepseekV3MoE(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + self.switch_mlp = SwitchGLU( + config.hidden_size, config.moe_intermediate_size, config.n_routed_experts + ) + + self.gate = MoEGate(config) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV3MLP( + config=config, intermediate_size=intermediate_size + ) + + def __call__(self, x): + inds, scores = self.gate(x) + y = self.switch_mlp(x, inds) + y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(x) + + return y + + +class DeepseekV3DecoderLayer(nn.Module): + def __init__(self, config: ModelArgs, layer_idx: int): + super().__init__() + self.self_attn = DeepseekV3Attention(config) + self.mlp = ( + DeepseekV3MoE(config) + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + else DeepseekV3MLP(config) + ) + self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + r = 
self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + # Protect against overflow for fp16 + if out.dtype == mx.float16: + out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000) + return out + + +class DeepseekV3Model(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = [ + DeepseekV3DecoderLayer(config, idx) + for idx in range(config.num_hidden_layers) + ] + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pipeline_rank = 0 + self.pipeline_size = 1 + + def pipeline(self, group): + # Split layers in reverse so rank=0 gets the last layers and + # rank=pipeline_size-1 gets the first + self.pipeline_rank = group.rank() + self.pipeline_size = group.size() + layers_per_rank = ( + len(self.layers) + self.pipeline_size - 1 + ) // self.pipeline_size + start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank + self.layers = self.layers[start : start + layers_per_rank] + + def __call__( + self, + x: mx.array, + cache: Optional[Any] = None, + mask: Optional[mx.array] = None, + ) -> mx.array: + h = self.embed_tokens(x) + + pipeline_rank = self.pipeline_rank + pipeline_size = self.pipeline_size + if mask is None: + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + # Receive from the previous process in the pipeline + if pipeline_rank < pipeline_size - 1: + h = mx.distributed.recv_like(h, (pipeline_rank + 1)) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) + + # Send to the next process in the pipeline + if pipeline_rank != 0: + h = mx.distributed.send(h, (pipeline_rank - 1) % pipeline_size) + + # Broadcast h while keeping it in the graph + h = mx.distributed.all_gather(h)[: h.shape[0]] + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.args = config + self.model_type = config.model_type + self.model = DeepseekV3Model(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache: Optional[Any] = None, + mask: Optional[mx.array] = None, + ): + out = self.model(inputs, cache, mask) + return self.lm_head(out) + + def sanitize(self, weights): + for l in range(self.args.num_hidden_layers): + prefix = f"model.layers.{l}" + for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]: + for k in ["weight", "scales", "biases"]: + if f"{prefix}.mlp.experts.0.{m}.{k}" in weights: + to_join = [ + weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") + for e in range(self.args.n_routed_experts) + ] + weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join) + + # Remove multi-token prediction layer + return {k: v for k, v in weights.items() if not k.startswith("model.layers.61")} + + @property + def layers(self): + return self.model.layers diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt index 48012863..72e1ef89 100644 --- a/llms/mlx_lm/requirements.txt +++ b/llms/mlx_lm/requirements.txt @@ -1,4 +1,4 @@ -mlx>=0.19.2 +mlx>=0.22.0 numpy transformers[sentencepiece]>=4.39.3 protobuf diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index ad79349e..0e06b5a0 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -561,7 +561,7 @@ def load( Defaults to an empty dictionary. 
adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers to the model. Default: ``None``. - lazy (bool): If False eval the model parameters to make sure they are + lazy (bool): If ``False`` eval the model parameters to make sure they are loaded in memory before returning, otherwise they will be loaded when needed. Default: ``False`` Returns: @@ -655,7 +655,7 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str): model, tokenizer = load("{upload_repo}") - prompt="hello" + prompt = "hello" if tokenizer.chat_template is not None: messages = [{{"role": "user", "content": prompt}}] diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 7b4376bb..118ec6f2 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -682,6 +682,43 @@ class TestModels(unittest.TestCase): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_deepseek_v3(self): + from mlx_lm.models import deepseek_v3 + + args = deepseek_v3.ModelArgs( + model_type="deepseek_v3", + vocab_size=1024, + hidden_size=128, + intermediate_size=256, + moe_intermediate_size=256, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + n_routed_experts=4, + n_group=2, + topk_group=1, + num_experts_per_tok=2, + n_shared_experts=1, + kv_lora_rank=4, + q_lora_rank=4, + qk_rope_head_dim=32, + v_head_dim=16, + qk_nope_head_dim=32, + rope_scaling={ + "beta_fast": 32, + "beta_slow": 1, + "factor": 40, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 4096, + "type": "yarn", + }, + ) + model = deepseek_v3.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + def test_gemma2(self): from mlx_lm.models import gemma2 diff --git a/llms/tests/test_utils_load_model.py b/llms/tests/test_utils_load_model.py index 5821f9e9..8da19afb 100644 --- a/llms/tests/test_utils_load_model.py +++ b/llms/tests/test_utils_load_model.py @@ -17,7 +17,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase): self.config = args self.custom_attribute = "This is a custom model" - def load_weights(self, weights): + def load_weights(self, weights, **kwargs): self.qwenWeights = weights class CustomQwenConfig: From 93c5cfd7819cac681bd35f8c928f752d72da8334 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 10 Jan 2025 15:27:08 -0800 Subject: [PATCH 11/15] Add a speculative decoding generator (#1155) * add a speculative decoding generator * fix * fixes * optional kwarg pop --- llms/mlx_lm/generate.py | 21 +++- llms/mlx_lm/utils.py | 209 ++++++++++++++++++++++++++++++++++------ 2 files changed, 198 insertions(+), 32 deletions(-) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 26481d6b..0d286c75 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -131,6 +131,18 @@ def setup_arg_parser(): type=int, default=DEFAULT_QUANTIZED_KV_START, ) + parser.add_argument( + "--draft-model", + type=str, + help="A model to be used for speculative decoding.", + default=None, + ) + parser.add_argument( + "--num-draft-tokens", + type=int, + help="Number of tokens to draft when using speculative decoding.", + default=2, + ) return parser @@ -211,11 +223,16 @@ def main(): add_generation_prompt=True, ) prompt = prompt[test_prompt.index("") :] - prompt = tokenizer.encode(prompt, add_special_tokens=False) else: prompt = tokenizer.encode(prompt) + if args.draft_model is not None: + draft_model, draft_tokenizer = load(args.draft_model) + if draft_tokenizer.vocab_size != 
tokenizer.vocab_size: + raise ValueError("Draft model tokenizer does not match model tokenizer.") + else: + draft_model = None sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep) response = generate( model, @@ -229,6 +246,8 @@ def main(): kv_bits=args.kv_bits, kv_group_size=args.kv_group_size, quantized_kv_start=args.quantized_kv_start, + draft_model=draft_model, + num_draft_tokens=args.num_draft_tokens, ) if not args.verbose: print(response) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 0e06b5a0..2fc0446b 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -2,6 +2,7 @@ import contextlib import copy +import functools import glob import importlib import json @@ -207,12 +208,6 @@ def generate_step( kv_group_size: int = 64, quantized_kv_start: int = 0, prompt_progress_callback: Optional[Callable[int, int]] = None, - temp: Optional[float] = None, - repetition_penalty: Optional[float] = None, - repetition_context_size: Optional[int] = None, - top_p: Optional[float] = None, - min_p: Optional[float] = None, - min_tokens_to_keep: Optional[int] = None, ) -> Generator[Tuple[mx.array, mx.array], None, None]: """ A generator producing token ids based on the given prompt from the model. @@ -256,25 +251,17 @@ def generate_step( elif len(prompt_cache) != len(model.layers): raise ValueError("Wrong number of layers in the prompt cache.") - if temp is not None or top_p is not None or min_tokens_to_keep is not None: - print( - "[Warning] Specifying sampling arguments to ``generate_step`` is " - "deprecated. Pass in a ``sampler`` instead." - ) - if repetition_penalty is not None: - print( - "[Warning] Specifying ``repetition_penalty`` is deprecated. " - "Pass in ``logits_processors`` instead." - ) - - sampler = sampler or make_sampler( - temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1 - ) - logits_processors = logits_processors or make_logits_processors( - None, repetition_penalty, repetition_context_size or 20 - ) prompt_progress_callback = prompt_progress_callback or (lambda *_: None) + quantize_cache_fn = functools.partial( + maybe_quantize_kv_cache, + quantized_kv_start=quantized_kv_start, + kv_group_size=kv_group_size, + kv_bits=kv_bits, + ) + + sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) + def _step(y): with mx.stream(generation_stream): logits = model(y[None], cache=prompt_cache) @@ -287,9 +274,7 @@ def generate_step( for processor in logits_processors: logits = processor(tokens, logits) - maybe_quantize_kv_cache( - prompt_cache, quantized_kv_start, kv_group_size, kv_bits - ) + quantize_cache_fn(prompt_cache) logprobs = logits - mx.logsumexp(logits, keepdims=True) y = sampler(logprobs) @@ -300,9 +285,7 @@ def generate_step( prompt_processed_tokens = 0 while y.size > prefill_step_size: model(y[:prefill_step_size][None], cache=prompt_cache) - maybe_quantize_kv_cache( - prompt_cache, quantized_kv_start, kv_group_size, kv_bits - ) + quantize_cache_fn(prompt_cache) mx.eval([c.state for c in prompt_cache]) prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens) prompt_processed_tokens += prefill_step_size @@ -329,10 +312,162 @@ def generate_step( n += 1 +def speculative_generate_step( + prompt: mx.array, + model: nn.Module, + draft_model: nn.Module, + *, + num_draft_tokens=2, + max_tokens: int = 256, + sampler: Optional[Callable[mx.array, mx.array]] = None, + logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, + prompt_cache: Optional[Any] = None, + prefill_step_size: int = 
512, + kv_bits: Optional[int] = None, + kv_group_size: int = 64, + quantized_kv_start: int = 0, +) -> Generator[Tuple[mx.array, mx.array], None, None]: + """ + A generator producing token ids based on the given prompt from the model. + + Args: + prompt (mx.array): The input prompt. + model (nn.Module): The model to use for generation. + draft_model (nn.Module): The draft model for speculative decoding. + num_draft_tokens (int, optional): The number of draft tokens for + speculative decoding. Default: ``2``. + max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite + generator. Default: ``256``. + sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a + token from a vector of log probabilities. Default: ``None``. + logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): + A list of functions that take tokens and logits and return the processed + logits. Default: ``None``. + prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if + provided, the cache will be updated in place. The cache must be trimmable. + prefill_step_size (int): Step size for processing the prompt. + kv_bits (int, optional): Number of bits to use for KV cache quantization. + None implies no cache quantization. Default: ``None``. + kv_group_size (int): Group size for KV cache quantization. Default: ``64``. + quantized_kv_start (int): Step to begin using a quantized KV cache. + when ``kv_bits`` is non-None. Default: ``0``. + + Yields: + Tuple[mx.array, mx.array]: One token and a vector of log probabilities. + """ + + y = prompt + tokens = None + + # Create the KV cache for generation + if prompt_cache is None: + model_cache = cache.make_prompt_cache(model) + draft_cache = cache.make_prompt_cache(draft_model) + elif len(prompt_cache) != (len(model.layers) + len(draft_model.layers)): + raise ValueError("Wrong number of layers in the prompt cache.") + else: + model_cache = prompt_cache[: len(model.layers)] + draft_cache = prompt_cache[len(model.layers) :] + + sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) + + quantize_cache_fn = functools.partial( + maybe_quantize_kv_cache, + quantized_kv_start=quantized_kv_start, + kv_group_size=kv_group_size, + kv_bits=kv_bits, + ) + + def _step(model, cache, y, n_predict=1): + with mx.stream(generation_stream): + logits = model(y[None], cache=cache) + logits = logits[:, -n_predict:, :] + + quantize_cache_fn(cache) + + logprobs = logits - mx.logsumexp(logits, keepdims=True) + y = sampler(logprobs).squeeze(0) + return y, logprobs.squeeze(0) + + def _prefill(model, cache, y): + while y.size > prefill_step_size: + model(y[:prefill_step_size][None], cache=cache) + quantize_cache_fn(cache) + mx.eval([c.state for c in cache]) + y = y[prefill_step_size:] + mx.metal.clear_cache() + return y + + def _rewind_cache(num_draft, num_accept): + cache.trim_prompt_cache(model_cache, num_draft - num_accept) + cache.trim_prompt_cache(draft_cache, max(num_draft - num_accept - 1, 0)) + + def _draft_generate(y, num_draft): + if num_draft == 0: + return mx.array([], mx.uint32) + ys = [] + for _ in range(num_draft): + y, _ = _step(draft_model, draft_cache, y) + mx.async_eval(y) + ys.append(y) + return mx.concatenate(ys) + + with mx.stream(generation_stream): + draft_y = _prefill(draft_model, draft_cache, y) + y = _prefill(model, model_cache, y) + + ntoks = 0 + # Set these so the finally block doesn't raise + num_draft = 0 + n = 0 + try: + while True: + num_draft = min(max_tokens - ntoks, num_draft_tokens) + draft_tokens = 
_draft_generate(draft_y, num_draft) + y = mx.concatenate([y, draft_tokens]) + + tokens, logprobs = _step(model, model_cache, y, num_draft + 1) + mx.eval(tokens, draft_tokens) + draft_tokens = draft_tokens.tolist() + tokens = tokens.tolist() + n = 0 + while n < num_draft: + tn, dtn, lpn = tokens[n], draft_tokens[n], logprobs[n] + if tn != dtn: + break + n += 1 + ntoks += 1 + yield tn, lpn + if ntoks == max_tokens: + break + if ntoks < max_tokens: + ntoks += 1 + yield tokens[n], logprobs[n] + + if ntoks == max_tokens: + break + + y = mx.array([tokens[n]], mx.uint32) + draft_y = y + + # If we accpeted all the draft tokens, include the last + # draft token in the next draft step since it hasn't been + # processed yet by the draft model + if n == num_draft: + draft_y = mx.concatenate( + [mx.array(draft_tokens[-1:], mx.uint32), draft_y] + ) + + _rewind_cache(num_draft, n) + finally: + _rewind_cache(num_draft, n) + + def stream_generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], prompt: Union[str, mx.array, List[int]], + draft_model: Optional[nn.Module] = None, **kwargs, ) -> Generator[GenerationResponse, None, None]: """ @@ -341,7 +476,11 @@ def stream_generate( Args: model (nn.Module): The model to use for generation. tokenizer (PreTrainedTokenizer): The tokenizer. - prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens. + prompt (Union[str, mx.array, List[int]]): The input prompt string or + integer tokens. + draft_model (Optional[nn.Module]): An optional draft model. If provided + then speculative decoding is used. The draft model must use the same + tokenizer as the main model. Default: ``None``. kwargs: The remaining options get passed to :func:`generate_step`. See :func:`generate_step` for more details. 
@@ -363,10 +502,18 @@ def stream_generate( detokenizer = tokenizer.detokenizer + if draft_model is None: + kwargs.pop("num_draft_tokens", None) + token_generator = generate_step(prompt, model, **kwargs) + else: + kwargs.pop("max_kv_size", None) + token_generator = speculative_generate_step( + prompt, model, draft_model, **kwargs + ) with wired_limit(model, [generation_stream]): detokenizer.reset() tic = time.perf_counter() - for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)): + for n, (token, logprobs) in enumerate(token_generator): if n == 0: prompt_time = time.perf_counter() - tic prompt_tps = prompt.size / prompt_time From 514502da22f0dc4c1ac439bdf78c07d5ec41acf7 Mon Sep 17 00:00:00 2001 From: "Xingjun.Wang" Date: Sat, 11 Jan 2025 07:29:34 +0800 Subject: [PATCH 12/15] Support snapshot_download for ModelScope (#1194) * add MLX_USE_MODELSCOPE env * update * update snapshot_download * update * remove modelscope dependency and add import check * update * nits * fix --------- Co-authored-by: wangxingjun778 Co-authored-by: Awni Hannun --- llms/mlx_lm/utils.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 2fc0446b..b9037295 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -7,6 +7,7 @@ import glob import importlib import json import logging +import os import shutil import time from dataclasses import dataclass @@ -16,7 +17,17 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, import mlx.core as mx import mlx.nn as nn -from huggingface_hub import snapshot_download + +if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true": + try: + from modelscope import snapshot_download + except ImportError: + raise ImportError( + "Please run `pip install modelscope` to activate the ModelScope." + ) +else: + from huggingface_hub import snapshot_download + from mlx.utils import tree_flatten, tree_reduce from transformers import PreTrainedTokenizer @@ -154,11 +165,12 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path Path: The path to the model. 
""" model_path = Path(path_or_hf_repo) + if not model_path.exists(): try: model_path = Path( snapshot_download( - repo_id=path_or_hf_repo, + path_or_hf_repo, revision=revision, allow_patterns=[ "*.json", From bf2da36fc640e6bfab933ac8c10d76c86fcdb288 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 12 Jan 2025 21:58:08 +0100 Subject: [PATCH 13/15] Fix Cohere2: mask shape error (long context) (#1202) * fix mask shape error (long context) * Update llms/mlx_lm/models/cohere2.py Co-authored-by: Awni Hannun * revert layer_idx * black formatting * Update cohere2.py * format --------- Co-authored-by: Awni Hannun Co-authored-by: Awni Hannun --- llms/mlx_lm/models/cohere2.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py index ec0e9276..19bfa6b6 100644 --- a/llms/mlx_lm/models/cohere2.py +++ b/llms/mlx_lm/models/cohere2.py @@ -156,12 +156,13 @@ class CohereModel(nn.Module): ): h = self.embed_tokens(inputs) - if mask is None: - mask = create_attention_mask(h, cache) - if cache is None: cache = [None] * len(self.layers) + if mask is None: + j = self.args.sliding_window_pattern + mask = create_attention_mask(h, cache[j - 1 : j]) + for layer, c in zip(self.layers, cache): h = layer(h, mask, c) From 0228c46434157adaa48b44f9a227d2bb93354dc3 Mon Sep 17 00:00:00 2001 From: Chime Ogbuji Date: Mon, 13 Jan 2025 13:01:18 -0500 Subject: [PATCH 14/15] Custom local dataset features (#1085) * Generalize prompt_feature and completion_feature for use in local datasets to facilitate compatibility with many other training dataset formats. * Persist configured prompt/completion key * rebase + nits --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/LORA.md | 17 +++++++++-- llms/mlx_lm/tuner/datasets.py | 55 ++++++++++++++++++++++++++--------- 2 files changed, 56 insertions(+), 16 deletions(-) diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index 15676360..9eac9d7f 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -241,14 +241,25 @@ Refer to the documentation for the model you are fine-tuning for more details. {"prompt": "What is the capital of France?", "completion": "Paris."} ``` +For the `completions` data format, a different key can be used for the prompt +and completion by specifying the following in the YAML config: + +```yaml +prompt_feature: "input" +completion_feature: "output" +``` + +Here, `"input"` is the expected key instead of the default `"prompt"`, and +`"output"` is the expected key instead of `"completion"`. + `text`: ```jsonl {"text": "This is an example for the model."} ``` -Note, the format is automatically determined by the dataset. Note also, keys in -each line not expected by the loader will be ignored. +Note, the format is automatically determined by the dataset. Note also, keys +in each line not expected by the loader will be ignored. > [!NOTE] > Each example in the datasets must be on a single line. Do not put more than @@ -270,7 +281,7 @@ Otherwise, provide a mapping of keys in the dataset to the features MLX LM expects. Use a YAML config to specify the Hugging Face dataset arguments. 
For example: -``` +```yaml hf_dataset: name: "billsum" prompt_feature: "text" diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index fa848f47..1b09c7e2 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Optional from transformers import PreTrainedTokenizer @@ -61,8 +61,8 @@ class CompletionsDataset: self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer, - prompt_key: str = "prompt", - completion_key: str = "completion", + prompt_key: str, + completion_key: str, ): self._data = [ tokenizer.apply_chat_template( @@ -81,13 +81,19 @@ class CompletionsDataset: return len(self._data) -def create_dataset(data, tokenizer: PreTrainedTokenizer): +def create_dataset( + data, + tokenizer: PreTrainedTokenizer, + prompt_feature: Optional[str] = None, + completion_feature: Optional[str] = None, +): + prompt_feature = prompt_feature or "prompt" + completion_feature = completion_feature or "completion" sample = data[0] - if "messages" in sample: return ChatDataset(data, tokenizer) - elif "prompt" in sample and "completion" in sample: - return CompletionsDataset(data, tokenizer) + elif prompt_feature in sample and completion_feature in sample: + return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature) elif "text" in sample: return Dataset(data, tokenizer) else: @@ -97,20 +103,30 @@ def create_dataset(data, tokenizer: PreTrainedTokenizer): ) -def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer): +def load_local_dataset( + data_path: Path, + tokenizer: PreTrainedTokenizer, + prompt_feature: Optional[str] = None, + completion_feature: Optional[str] = None, +): def load_subset(path): if not path.exists(): return [] with open(path, "r") as fid: data = [json.loads(l) for l in fid] - return create_dataset(data, tokenizer) + return create_dataset(data, tokenizer, prompt_feature, completion_feature) names = ("train", "valid", "test") train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names] return train, valid, test -def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer): +def load_hf_dataset( + data_id: str, + tokenizer: PreTrainedTokenizer, + prompt_feature: Optional[str] = None, + completion_feature: Optional[str] = None, +): from datasets import exceptions, load_dataset try: @@ -119,7 +135,13 @@ def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer): names = ("train", "valid", "test") train, valid, test = [ - create_dataset(dataset[n], tokenizer) if n in dataset.keys() else [] + ( + create_dataset( + dataset[n], tokenizer, prompt_feature, completion_feature + ) + if n in dataset.keys() + else [] + ) for n in names ] @@ -175,11 +197,18 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer): train, valid, test = load_custom_hf_dataset(args, tokenizer) else: data_path = Path(args.data) + + prompt_feature = getattr(args, "prompt_feature", None) + completion_feature = getattr(args, "completion_feature", None) if data_path.exists(): - train, valid, test = load_local_dataset(data_path, tokenizer) + train, valid, test = load_local_dataset( + data_path, tokenizer, prompt_feature, completion_feature + ) else: print(f"Loading Hugging Face dataset {args.data}.") - train, valid, test = load_hf_dataset(args.data, tokenizer) + train, valid, test = load_hf_dataset( + args.data, tokenizer, prompt_feature, completion_feature + ) if args.train and len(train) == 0: raise 
ValueError( From c117af83b8cbec15523bd0d69e7a57f01237ca89 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 13 Jan 2025 10:22:32 -0800 Subject: [PATCH 15/15] fix gpt bigcode (#1204) --- llms/mlx_lm/models/gpt_bigcode.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py index 8415c59e..1d9794b6 100644 --- a/llms/mlx_lm/models/gpt_bigcode.py +++ b/llms/mlx_lm/models/gpt_bigcode.py @@ -145,16 +145,16 @@ class GPTBigCodeModel(nn.Module): hidden_states = self.wte(inputs) mask = None - if hidden_states.shape[1] > 1: - - position_ids = mx.array(np.arange(L)) - hidden_states += self.wpe(position_ids) - - if mask is None: - mask = create_attention_mask(hidden_states, cache) + if mask is not None and hidden_states.shape[1] > 1: + mask = create_attention_mask(hidden_states, cache) if cache is None: cache = [None] * len(self.h) + position_ids = mx.array(np.arange(L)) + else: + position_ids = mx.array(np.arange(cache[0].offset, cache[0].offset + L)) + + hidden_states += self.wpe(position_ids) for layer, c in zip(self.h, cache): hidden_states = layer(hidden_states, mask, cache=c)
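
A few usage sketches for features introduced in the patches above. First, the `--extra-eos-token` flag (patches 05 and 07) is backed by the new `TokenizerWrapper.add_eos_token`, which accepts either a token string or an integer id passed as a string. A minimal sketch from Python; the repo id and the `<|im_end|>` token are placeholders and must exist in the loaded tokenizer for the calls to succeed:

```python
from mlx_lm import load, generate

# Placeholder repo id.
model, tokenizer = load("mlx-community/Some-Chat-Model-4bit")

# Register an additional stop token. The argument may be a token string or an
# integer id given as a string; a ValueError is raised if it is neither.
tokenizer.add_eos_token("<|im_end|>")

print(generate(model, tokenizer, prompt="hello"))
```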
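The `--apply-chat-template` and `--fewshot-as-multiturn` options added in patch 06 are passed straight through to `lm_eval.simple_evaluate`, so the same evaluation can be scripted without the CLI. A sketch under the assumption that `lm-eval` is installed (the `evaluate` extra); the repo id, task, and batch size below are arbitrary placeholders:

```python
import lm_eval
from mlx_lm.evaluate import MLXLM

# Placeholder model repo; batch size and task are arbitrary choices.
lm = MLXLM("mlx-community/Some-Model-4bit", batch_size=8, use_chat_template=True)

results = lm_eval.simple_evaluate(
    model=lm,
    tasks=["arc_easy"],
    num_fewshot=5,
    fewshot_as_multiturn=True,
    apply_chat_template=lm.use_chat_template,
)
print(results)
```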
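Speculative decoding (patch 11) can likewise be driven from Python via the `draft_model` and `num_draft_tokens` keyword arguments of `generate`/`stream_generate`. A sketch with placeholder repo ids; the draft model must share the main model's tokenizer (vocab sizes are checked in `mlx_lm.generate`):

```python
from mlx_lm import load, generate

# Placeholder repo ids; pick any main/draft pair that shares a tokenizer.
model, tokenizer = load("mlx-community/Main-Model-4bit")
draft_model, _ = load("mlx-community/Small-Draft-Model-4bit")

messages = [{"role": "user", "content": "Write a quicksort in C++."}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

# The draft model proposes `num_draft_tokens` tokens per step; the main model
# verifies them in one pass and keeps the longest accepted prefix, rewinding
# both caches past any rejected tail.
text = generate(
    model,
    tokenizer,
    prompt=prompt,
    draft_model=draft_model,
    num_draft_tokens=2,
    verbose=True,
)
```

The `--draft-model` and `--num-draft-tokens` arguments expose the same path from the command line.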
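Finally, patch 12 routes `snapshot_download` through ModelScope when `MLXLM_USE_MODELSCOPE` is set to `true` (note the code reads `MLXLM_USE_MODELSCOPE`, not the `MLX_USE_MODELSCOPE` named in the commit message) and `modelscope` is installed. Since the flag is read when `mlx_lm.utils` is first imported, it must be set before the import; a sketch with a placeholder repo id:

```python
import os

# Set before importing mlx_lm; the flag is evaluated at import time.
os.environ["MLXLM_USE_MODELSCOPE"] = "true"

from mlx_lm import load, generate  # noqa: E402

# Placeholder repo id; downloads now go through modelscope.snapshot_download.
model, tokenizer = load("mlx-community/Some-Model-4bit")
print(generate(model, tokenizer, prompt="hello"))
```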