diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py
index 89e6cd00..839089b6 100644
--- a/llms/mlx_lm/_version.py
+++ b/llms/mlx_lm/_version.py
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.21.5"
+__version__ = "0.21.6"
diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py
index e52ad10d..5c0b78db 100644
--- a/llms/mlx_lm/chat.py
+++ b/llms/mlx_lm/chat.py
@@ -65,12 +65,25 @@ def main():
         tokenizer_config={"trust_remote_code": True},
     )
 
-    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
+    def print_help():
+        print("The command list:")
+        print("- 'q' to exit")
+        print("- 'r' to reset the chat")
+        print("- 'h' to display these commands")
+
+    print(f"[INFO] Starting chat session with {args.model}.")
+    print_help()
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
         query = input(">> ")
         if query == "q":
             break
+        if query == "r":
+            prompt_cache = make_prompt_cache(model, args.max_kv_size)
+            continue
+        if query == "h":
+            print_help()
+            continue
         messages = [{"role": "user", "content": query}]
         prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
         for response in stream_generate(
diff --git a/llms/mlx_lm/convert.py b/llms/mlx_lm/convert.py
index 9bac77a5..86a96447 100644
--- a/llms/mlx_lm/convert.py
+++ b/llms/mlx_lm/convert.py
@@ -1,8 +1,27 @@
 # Copyright © 2023-2024 Apple Inc.
 
 import argparse
+from enum import Enum
 
-from .utils import convert
+from .utils import convert, mixed_2_6, mixed_3_6
+
+
+class MixedQuants(Enum):
+    mixed_3_6 = "mixed_3_6"
+    mixed_2_6 = "mixed_2_6"
+
+    @classmethod
+    def recipe_names(cls):
+        return [member.name for member in cls]
+
+
+def quant_args(arg):
+    try:
+        return MixedQuants[arg].value
+    except KeyError:
+        raise argparse.ArgumentTypeError(
+            f"Invalid q-recipe {arg!r}. Choose from: {MixedQuants.recipe_names()}"
+        )
 
 
 def configure_parser() -> argparse.ArgumentParser:
@@ -29,6 +48,12 @@ def configure_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--q-bits", help="Bits per weight for quantization.", type=int, default=4
     )
+    parser.add_argument(
+        "--quant-predicate",
+        help=f"Mixed-bit quantization recipe. Choices: {MixedQuants.recipe_names()}",
+        type=quant_args,
+        required=False,
+    )
     parser.add_argument(
         "--dtype",
         help="Type to save the non-quantized parameters.",
diff --git a/llms/mlx_lm/evaluate.py b/llms/mlx_lm/evaluate.py
index 2f35ade2..cd6de7ec 100644
--- a/llms/mlx_lm/evaluate.py
+++ b/llms/mlx_lm/evaluate.py
@@ -289,17 +289,15 @@ class MLXLM(LM):
         contexts, options = zip(*[req.args for req in requests])
         # contrary to the doc the second element of the tuple contains
         # {'do_sample': False, 'until': ['\n\n'], 'temperature': 0}
-        keys = list(options[0].keys())
-        assert "until" in keys
-        untils = [x["until"] for x in options]
         completions = []
 
-        for context, until in tqdm(zip(contexts, untils), total=len(contexts)):
+        for context, opt in tqdm(zip(contexts, options), total=len(contexts)):
+            until = opt["until"]
             context = self.tokenizer.encode(
                 context, add_special_tokens=not self.use_chat_template
             )
             max_tokens = min(
-                self._max_tokens,
+                opt.get("max_gen_tokens", self._max_tokens),
                 self.tokenizer.model_max_length - len(context),
             )
             text = ""
@@ -334,9 +332,9 @@ def main():
     )
     parser.add_argument(
         "--limit",
-        default=1.0,
+        default=100,
         help="Limit the number of examples per task.",
-        type=float,
+        type=int,
     )
     parser.add_argument("--seed", type=int, default=123, help="Random seed.")
     parser.add_argument(
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index d8f97e5e..e40332dd 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -60,6 +60,11 @@ def setup_arg_parser():
         default=DEFAULT_PROMPT,
         help="Message to be processed by the model ('-' reads from stdin)",
     )
+    parser.add_argument(
+        "--prefill-response",
+        default=None,
+        help="Prefill response to be used for the chat template",
+    )
     parser.add_argument(
         "--max-tokens",
         "-m",
@@ -219,10 +224,14 @@ def main():
 
         messages = []
         messages.append({"role": "user", "content": prompt})
+        has_prefill = args.prefill_response is not None
+        if has_prefill:
+            messages.append({"role": "assistant", "content": args.prefill_response})
         prompt = tokenizer.apply_chat_template(
             messages,
             tokenize=False,
-            add_generation_prompt=True,
+            continue_final_message=has_prefill,
+            add_generation_prompt=not has_prefill,
             **template_kwargs,
         )
 
@@ -233,7 +242,8 @@ def main():
             test_prompt = tokenizer.apply_chat_template(
                 messages,
                 tokenize=False,
-                add_generation_prompt=True,
+                continue_final_message=has_prefill,
+                add_generation_prompt=not has_prefill,
             )
             prompt = prompt[test_prompt.index("<query>") :]
         prompt = tokenizer.encode(prompt, add_special_tokens=False)
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index f2e3c24f..7e82ce7d 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -64,6 +64,9 @@ CONFIG_DEFAULTS = {
     "grad_checkpoint": False,
     "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
+    "mask_prompt": False,
+
+    # ORPO args
     "beta": 0.1,
     "dpo_loss_type": "sigmoid",
     "delta": 50.0,
@@ -106,7 +109,7 @@ def build_parser():
         "--mask-prompt",
         action="store_true",
         help="Mask the prompt in the loss when training",
-        default=False,
+        default=None,
     )
 
     parser.add_argument(
diff --git a/llms/mlx_lm/manage.py b/llms/mlx_lm/manage.py
index 9827f3dc..c06de6b3 100644
--- a/llms/mlx_lm/manage.py
+++ b/llms/mlx_lm/manage.py
@@ -2,7 +2,22 @@ import argparse
 from typing import List, Union
 
 from huggingface_hub import scan_cache_dir
-from transformers.commands.user import tabulate
+
+
+def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
+    """
+    Inspired by:
+    - stackoverflow.com/a/8356620/593036
+    - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
+    """
+    col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
+    row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
+    lines = []
+    lines.append(row_format.format(*headers))
+    lines.append(row_format.format(*["-" * w for w in col_widths]))
+    for row in rows:
+        lines.append(row_format.format(*row))
+    return "\n".join(lines)
 
 
 def ask_for_confirmation(message: str) -> bool:
diff --git a/llms/mlx_lm/models/deepseek_v3.py b/llms/mlx_lm/models/deepseek_v3.py
index 47e17236..5cd40a0d 100644
--- a/llms/mlx_lm/models/deepseek_v3.py
+++ b/llms/mlx_lm/models/deepseek_v3.py
@@ -181,30 +181,37 @@ class DeepseekV3Attention(nn.Module):
             bias=config.attention_bias,
         )
 
-        mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
-        scaling_factor = self.config.rope_scaling["factor"]
-        if mscale_all_dim:
-            mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
-            self.scale = self.scale * mscale * mscale
+        if self.config.rope_scaling is not None:
+            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
+            scaling_factor = self.config.rope_scaling["factor"]
+            if mscale_all_dim:
+                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
+                self.scale = self.scale * mscale * mscale
 
-        rope_kwargs = {
-            key: self.config.rope_scaling[key]
-            for key in [
-                "original_max_position_embeddings",
-                "beta_fast",
-                "beta_slow",
-                "mscale",
-                "mscale_all_dim",
-            ]
-            if key in self.config.rope_scaling
-        }
-        self.rope = DeepseekV3YarnRotaryEmbedding(
-            dim=self.qk_rope_head_dim,
-            max_position_embeddings=self.max_position_embeddings,
-            scaling_factor=scaling_factor,
-            base=self.rope_theta,
-            **rope_kwargs,
-        )
+            rope_kwargs = {
+                key: self.config.rope_scaling[key]
+                for key in [
+                    "original_max_position_embeddings",
+                    "beta_fast",
+                    "beta_slow",
+                    "mscale",
+                    "mscale_all_dim",
+                ]
+                if key in self.config.rope_scaling
+            }
+            self.rope = DeepseekV3YarnRotaryEmbedding(
+                dim=self.qk_rope_head_dim,
+                max_position_embeddings=self.max_position_embeddings,
+                scaling_factor=scaling_factor,
+                base=self.rope_theta,
+                **rope_kwargs,
+            )
+        else:
+            self.rope = nn.RoPE(
+                dims=self.qk_rope_head_dim,
+                base=self.rope_theta,
+                traditional=True,
+            )
 
     def __call__(
         self,
@@ -487,8 +494,12 @@ class Model(nn.Module):
                         ]
                         weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
 
-        # Remove multi-token prediction layer
-        return {k: v for k, v in weights.items() if not k.startswith("model.layers.61")}
+        # Remove multi-token prediction layer and any unused precomputed rotary freqs
+        return {
+            k: v
+            for k, v in weights.items()
+            if not k.startswith("model.layers.61") and "rotary_emb.inv_freq" not in k
+        }
 
     @property
     def layers(self):
diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py
index d1c21e25..63e985de 100644
--- a/llms/mlx_lm/models/phi3.py
+++ b/llms/mlx_lm/models/phi3.py
@@ -23,8 +23,10 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 10000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
+    partial_rotary_factor: float = 1.0
     max_position_embeddings: int = 131072
     original_max_position_embeddings: int = 4096
+    tie_word_embeddings: bool = False
 
     def __post_init__(self):
         if self.num_key_value_heads is None:
@@ -59,9 +61,10 @@ class Attention(nn.Module):
         self.qkv_proj = nn.Linear(dim, op_size, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
 
+        rope_dim = int(head_dim * args.partial_rotary_factor)
         if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
             self.rope = SuScaledRotaryEmbedding(
-                head_dim,
+                rope_dim,
                 base=args.rope_theta,
                 max_position_embeddings=args.max_position_embeddings,
                 original_max_position_embeddings=args.original_max_position_embeddings,
@@ -74,7 +77,7 @@ class Attention(nn.Module):
                 assert isinstance(args.rope_scaling["factor"], float)
                 rope_scale = 1 / args.rope_scaling["factor"]
             self.rope = nn.RoPE(
-                head_dim,
+                rope_dim,
                 traditional=args.rope_traditional,
                 base=args.rope_theta,
                 scale=rope_scale,
@@ -190,7 +193,8 @@ class Model(nn.Module):
         super().__init__()
         self.model_type = args.model_type
         self.model = Phi3Model(args)
-        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
         self.args = args
 
     def __call__(
@@ -200,7 +204,11 @@ class Model(nn.Module):
         cache=None,
     ):
         out = self.model(inputs, mask, cache)
-        return self.lm_head(out)
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return out
 
     @property
     def layers(self):
diff --git a/llms/mlx_lm/models/su_rope.py b/llms/mlx_lm/models/su_rope.py
index 9c414afd..6340c77b 100644
--- a/llms/mlx_lm/models/su_rope.py
+++ b/llms/mlx_lm/models/su_rope.py
@@ -51,11 +51,13 @@ class SuScaledRotaryEmbedding(nn.Module):
             + math.log(max_position_embeddings / original_max_position_embeddings)
            / math.log(original_max_position_embeddings)
         )
+        self.dim = dims
 
     def __call__(self, x, offset: int = 0):
+        x[..., : self.dim] = self.scale * x[..., : self.dim]
         return mx.fast.rope(
-            self.scale * x,
-            x.shape[-1],
+            x,
+            self.dim,
             traditional=False,
             base=None,
             scale=1.0,
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 2d760743..05fac92f 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -191,6 +191,7 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
                         "*.py",
                         "tokenizer.model",
                         "*.tiktoken",
+                        "tiktoken.model",
                         "*.txt",
                         "*.jsonl",
                     ],
@@ -1015,6 +1016,46 @@ def save_config(
         json.dump(config, fid, indent=4)
 
 
+def mixed_quant_predicate_builder(
+    low_bits: int = 4, high_bits: int = 4, group_size: int = 64
+) -> Callable[[str, nn.Module, dict], Union[bool, dict]]:
+    def mixed_quant_predicate(
+        path: str,
+        module: nn.Module,
+        config: dict,
+    ) -> Union[bool, dict]:
+        """Implements mixed quantization predicates with similar choices to, for example, llama.cpp's Q4_K_M.
+        Ref: https://github.com/ggerganov/llama.cpp/blob/917786f43d0f29b7c77a0c56767c0fa4df68b1c5/src/llama.cpp#L5265
+        By Alex Barron: https://gist.github.com/barronalex/84addb8078be21969f1690c1454855f3
+        """
+
+        if not hasattr(module, "to_quantized"):
+            return False
+
+        index = int(path.split(".")[2]) if len(path.split(".")) > 2 else 0
+
+        num_layers = config["num_hidden_layers"]
+        use_more_bits = (
+            index < num_layers // 8
+            or index >= 7 * num_layers // 8
+            or (index - num_layers // 8) % 3 == 2
+        )
+        if "v_proj" in path and use_more_bits:
+            return {"group_size": group_size, "bits": high_bits}
+        if "down_proj" in path and use_more_bits:
+            return {"group_size": group_size, "bits": high_bits}
+        if "lm_head" in path:
+            return {"group_size": group_size, "bits": high_bits}
+
+        return {"group_size": group_size, "bits": low_bits}
+
+    return mixed_quant_predicate
+
+
+mixed_3_6 = mixed_quant_predicate_builder(low_bits=3)
+mixed_2_6 = mixed_quant_predicate_builder(low_bits=2)
+
+
 def convert(
     hf_path: str,
     mlx_path: str = "mlx_model",