added adapter additional tokens load on fuse

paNikitin 2025-02-24 12:21:30 +03:00
parent e2ace6fb0f
commit b2ab37238e
3 changed files with 27 additions and 4 deletions

View File

@@ -77,7 +77,8 @@ def main() -> None:
     model, config, tokenizer = fetch_from_hub(model_path)

     model.freeze()
-    model = load_adapters(model, args.adapter_path)
+    model, tokenizer = load_adapters(model, tokenizer, args.adapter_path)

     fused_linears = [
         (n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")
@@ -105,7 +106,7 @@ def main() -> None:
     if args.de_quantize:
         config.pop("quantization", None)

-    save_config(config, config_path=save_path / "config.json")
+    save_config(config, tokenizer, config_path=save_path / "config.json")

     if args.export_gguf:
         model_type = config["model_type"]

View File

@@ -8,6 +8,8 @@ import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as opt
 from mlx.utils import tree_flatten, tree_unflatten

+from ..tokenizer_utils import TokenizerWrapper
+
 from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
 from .dora import DoRAEmbedding, DoRALinear
@@ -159,7 +161,9 @@ def linear_to_lora_layers(
     model.update_modules(tree_unflatten(lora_modules))


-def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
+def load_adapters(
+    model: nn.Module, tokenizer: TokenizerWrapper, adapter_path: str
+) -> nn.Module:
     """
     Load any fine-tuned adapters / layers.
@@ -184,7 +188,21 @@ def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
         use_dora=(fine_tune_type == "dora"),
     )
     model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
-    return model
+    if cot := config.cot:
+        print("Loading additional tokens")
+        if tokens := cot.get("additional_tokens"):
+            from .new_tokens import implement_new_tokens
+
+            special = False
+            if (special_arg := cot.get("special")) and isinstance(special_arg, bool):
+                print("Updating model and tokenizer with new special tokens")
+                special = special_arg
+            else:
+                print("Updating model and tokenizer with new tokens")
+            model, tokenizer = implement_new_tokens(
+                model=model, tokenizer=tokenizer, tokens=tokens, special=special
+            )
+    return model, tokenizer


 def dequantize(model: nn.Module) -> nn.Module:
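
The .new_tokens module imported in this hunk is not part of the commit, so implement_new_tokens itself is not visible here; from the call site it presumably registers the extra tokens with the wrapped Hugging Face tokenizer and grows the model's embedding table to match. The sketch below is a hypothetical illustration of that idea only (the cot config fragment, the attribute paths and the initialisation scheme are all assumptions, not the committed code). Note also that config.cot appears to be an unguarded attribute access, so adapter configs without a cot entry would likely fail here; a getattr(config, "cot", None) fallback would keep older adapters working.

    # Hypothetical sketch only: tuner/new_tokens.py is not shown in this commit,
    # so this is an assumption about what implement_new_tokens might do, matching
    # the keyword arguments used at the call site above.
    #
    # Assumed adapter_config.json fragment that the new branch reads:
    #   "cot": {"additional_tokens": ["<thought>", "</thought>"], "special": true}
    import mlx.core as mx


    def implement_new_tokens(model, tokenizer, tokens, special=False):
        # Register the tokens with the wrapped Hugging Face tokenizer.
        if special:
            tokenizer._tokenizer.add_special_tokens(
                {"additional_special_tokens": tokens}
            )
        else:
            tokenizer._tokenizer.add_tokens(tokens)

        # Grow the input embedding so the new token ids have rows to map to.
        # The attribute path assumes a llama-style model; tied or separate output
        # heads (and quantized embeddings) would need the same treatment.
        embed = model.model.embed_tokens
        old = embed.weight
        new_rows = len(tokenizer._tokenizer) - old.shape[0]
        if new_rows > 0:
            extra = 0.02 * mx.random.normal((new_rows, old.shape[1]), dtype=old.dtype)
            embed.weight = mx.concatenate([old, extra], axis=0)
        return model, tokenizer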

View File

@@ -993,6 +993,7 @@ def quantize_model(
 def save_config(
     config: dict,
+    tokenizer: TokenizerWrapper,
     config_path: Union[str, Path],
 ) -> None:
     """Save the model configuration to the ``config_path``.
@@ -1009,6 +1010,9 @@ def save_config(
     # sort the config for better readability
     config = dict(sorted(config.items()))

+    if config["vocab_size"] != (cur := len(tokenizer._tokenizer)):
+        config["vocab_size"] = cur
+        print("Updated model`s config.json to match new tokenizer")
     # write the updated config to the config_path (if provided)
     with open(config_path, "w") as fid:
         json.dump(config, fid, indent=4)
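
Combined with the fuse.py change, this keeps the written config.json consistent with the tokenizer saved next to it: if adapter loading added tokens, vocab_size is bumped to the new tokenizer length before the file is written. A toy illustration of the check with stand-in values (in save_config the right-hand side is len(tokenizer._tokenizer)):

    # Stand-in values only, to show the effect of the new check in save_config.
    config = {"model_type": "llama", "vocab_size": 32000}
    tokenizer_len = 32002  # pretend two additional tokens were added

    if config["vocab_size"] != (cur := tokenizer_len):
        config["vocab_size"] = cur
        print("Updated model`s config.json to match new tokenizer")

    assert config["vocab_size"] == 32002

Since the check indexes config["vocab_size"] directly and relies on the wrapper's private _tokenizer attribute, configs that lack a vocab_size key would fail here; a config.get("vocab_size") guard may be worth considering.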