From b2ab37238ee194552e305309d38eaa126b537289 Mon Sep 17 00:00:00 2001
From: paNikitin <115797306+paNikitin@users.noreply.github.com>
Date: Mon, 24 Feb 2025 12:21:30 +0300
Subject: [PATCH] Load additional adapter tokens when fusing

---
 llms/mlx_lm/fuse.py        |  5 +++--
 llms/mlx_lm/tuner/utils.py | 22 ++++++++++++++++++++--
 llms/mlx_lm/utils.py       |  4 ++++
 3 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/llms/mlx_lm/fuse.py b/llms/mlx_lm/fuse.py
index b0c46a74..47abd717 100644
--- a/llms/mlx_lm/fuse.py
+++ b/llms/mlx_lm/fuse.py
@@ -77,7 +77,8 @@ def main() -> None:
     model, config, tokenizer = fetch_from_hub(model_path)
 
     model.freeze()
-    model = load_adapters(model, args.adapter_path)
+
+    model, tokenizer = load_adapters(model, tokenizer, args.adapter_path)
 
     fused_linears = [
         (n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")
@@ -105,7 +106,7 @@ def main() -> None:
     if args.de_quantize:
         config.pop("quantization", None)
 
-    save_config(config, config_path=save_path / "config.json")
+    save_config(config, config_path=save_path / "config.json", tokenizer=tokenizer)
 
     if args.export_gguf:
         model_type = config["model_type"]
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index f5df11e3..6948df15 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -8,6 +8,8 @@
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as opt
 from mlx.utils import tree_flatten, tree_unflatten
+from ..tokenizer_utils import TokenizerWrapper
+
 from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
 from .dora import DoRAEmbedding, DoRALinear
@@ -159,7 +161,9 @@ def linear_to_lora_layers(
     model.update_modules(tree_unflatten(lora_modules))
 
 
-def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
+def load_adapters(
+    model: nn.Module, tokenizer: TokenizerWrapper, adapter_path: str
+) -> tuple[nn.Module, TokenizerWrapper]:
     """
     Load any fine-tuned adapters / layers.
 
@@ -184,7 +188,21 @@ def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
             use_dora=(fine_tune_type == "dora"),
         )
     model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
-    return model
+    if cot := getattr(config, "cot", None):
+        if tokens := cot.get("additional_tokens"):
+            print("Loading additional tokens")
+            from .new_tokens import implement_new_tokens
+
+            special = False
+            if (special_arg := cot.get("special")) and isinstance(special_arg, bool):
+                print("Updating model and tokenizer with new special tokens")
+                special = special_arg
+            else:
+                print("Updating model and tokenizer with new tokens")
+            model, tokenizer = implement_new_tokens(
+                model=model, tokenizer=tokenizer, tokens=tokens, special=special
+            )
+    return model, tokenizer
 
 
 def dequantize(model: nn.Module) -> nn.Module:
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 1fae76fa..b7bd1cba 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -993,6 +993,7 @@ def quantize_model(
 
 def save_config(
     config: dict,
     config_path: Union[str, Path],
+    tokenizer: Optional[TokenizerWrapper] = None,
 ) -> None:
     """Save the model configuration to the ``config_path``.
@@ -1009,6 +1010,9 @@ def save_config(
     # sort the config for better readability
     config = dict(sorted(config.items()))
+    if tokenizer is not None and config.get("vocab_size") != len(tokenizer._tokenizer):
+        config["vocab_size"] = len(tokenizer._tokenizer)
+        print("Updated model's config.json to match new tokenizer")
 
     # write the updated config to the config_path (if provided)
     with open(config_path, "w") as fid:
         json.dump(config, fid, indent=4)
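
Note: the patch imports implement_new_tokens from llms/mlx_lm/tuner/new_tokens.py, but that module is not among the three files changed above. Its expected input is an adapter_config.json carrying a section shaped like "cot": {"additional_tokens": [...], "special": true}, as read by load_adapters. Below is a minimal sketch of what such a helper might look like; only the function name and keyword signature come from the call site in the patch, while the body, the _grow_embedding helper, and the assumption of an unquantized, Llama-style model exposing model.model.embed_tokens are illustrative guesses, not the author's implementation.

    # llms/mlx_lm/tuner/new_tokens.py -- hypothetical sketch, not part of this patch.
    from typing import List, Tuple

    import mlx.core as mx
    import mlx.nn as nn

    from ..tokenizer_utils import TokenizerWrapper


    def _grow_embedding(emb: nn.Embedding, new_size: int) -> nn.Embedding:
        """Return an embedding table widened to ``new_size`` rows.

        Existing rows are copied; new rows are zero-initialized.
        """
        old_weight = emb.weight
        old_size, dims = old_weight.shape
        if new_size <= old_size:
            return emb
        new_emb = nn.Embedding(new_size, dims)
        new_emb.weight = mx.concatenate(
            [old_weight, mx.zeros((new_size - old_size, dims), dtype=old_weight.dtype)]
        )
        return new_emb


    def implement_new_tokens(
        model: nn.Module,
        tokenizer: TokenizerWrapper,
        tokens: List[str],
        special: bool = False,
    ) -> Tuple[nn.Module, TokenizerWrapper]:
        """Register ``tokens`` with the tokenizer and grow the model's embeddings."""
        # Add the tokens to the underlying Hugging Face tokenizer.
        if special:
            tokenizer._tokenizer.add_special_tokens(
                {"additional_special_tokens": tokens}
            )
        else:
            tokenizer._tokenizer.add_tokens(tokens)

        # Grow the input embedding table to cover the enlarged vocabulary.
        # Assumes an unquantized, Llama-style model exposing
        # ``model.model.embed_tokens``; an untied ``lm_head`` would need the
        # same treatment.
        new_size = len(tokenizer._tokenizer)
        model.model.embed_tokens = _grow_embedding(model.model.embed_tokens, new_size)
        return model, tokenizer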