mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 01:12:24 +08:00
added adapter additional tokens load on fuse
This commit is contained in:
parent e2ace6fb0f
commit b2ab37238e
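This commit makes the fuse path load any additional tokens defined alongside a fine-tuned adapter (e.g. chain-of-thought markers) into the model and tokenizer before fusing. For orientation, a hypothetical sketch of the adapter configuration the new code path reads: only the `cot`, `additional_tokens`, and `special` keys appear in the diff below; the token strings, the other fields, and the `adapter_config.json` file name are illustrative assumptions.

import json
from pathlib import Path

# Hypothetical adapter configuration illustrating the keys the new
# load_adapters() path consumes. Only "cot", "additional_tokens", and
# "special" come from this diff; everything else is illustrative.
adapter_config = {
    "fine_tune_type": "lora",
    "cot": {
        "additional_tokens": ["<think>", "</think>"],
        "special": True,  # treat the new tokens as special tokens
    },
}

adapter_path = Path("adapters")  # assumed adapter directory
adapter_path.mkdir(exist_ok=True)
with open(adapter_path / "adapter_config.json", "w") as fid:
    json.dump(adapter_config, fid, indent=4)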
@@ -77,7 +77,8 @@ def main() -> None:
     model, config, tokenizer = fetch_from_hub(model_path)
 
     model.freeze()
-    model = load_adapters(model, args.adapter_path)
+    model, tokenizer = load_adapters(model, tokenizer, args.adapter_path)
 
     fused_linears = [
         (n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")
@@ -105,7 +106,7 @@ def main() -> None:
     if args.de_quantize:
         config.pop("quantization", None)
 
-    save_config(config, config_path=save_path / "config.json")
+    save_config(config, tokenizer, config_path=save_path / "config.json")
 
     if args.export_gguf:
         model_type = config["model_type"]
@@ -8,6 +8,8 @@ import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as opt
 from mlx.utils import tree_flatten, tree_unflatten
 
+from ..tokenizer_utils import TokenizerWrapper
+
 from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
 from .dora import DoRAEmbedding, DoRALinear
@@ -159,7 +161,9 @@ def linear_to_lora_layers(
     model.update_modules(tree_unflatten(lora_modules))
 
 
-def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
+def load_adapters(
+    model: nn.Module, tokenizer: TokenizerWrapper, adapter_path: str
+) -> nn.Module:
     """
     Load any fine-tuned adapters / layers.
 
@@ -184,7 +188,21 @@ def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
         use_dora=(fine_tune_type == "dora"),
     )
     model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
-    return model
+    if cot := config.cot:
+        print("Loading additional tokens")
+        if tokens := cot.get("additional_tokens"):
+            from .new_tokens import implement_new_tokens
+
+            special = False
+            if (special_arg := cot.get("special")) and isinstance(special_arg, bool):
+                print("Updating model and tokenizer with new special tokens")
+                special = special_arg
+            else:
+                print("Updating model and tokenizer with new tokens")
+            model, tokenizer = implement_new_tokens(
+                model=model, tokenizer=tokenizer, tokens=tokens, special=special
+            )
+    return model, tokenizer
 
 
 def dequantize(model: nn.Module) -> nn.Module:
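The `implement_new_tokens` helper imported from `.new_tokens` is not shown in this diff. A minimal sketch of what such a helper would plausibly do, assuming it extends the underlying Hugging Face tokenizer and grows the model's token embedding to match; the `model.model.embed_tokens` path, the initialization scale, and the omission of an untied output head are assumptions, not confirmed by this commit.

import mlx.core as mx


def implement_new_tokens(model, tokenizer, tokens, special=False):
    # Extend the wrapped Hugging Face tokenizer (tokenizer._tokenizer is the
    # same attribute the new save_config() reads the vocab size from).
    if special:
        tokenizer._tokenizer.add_special_tokens({"additional_special_tokens": tokens})
    else:
        tokenizer._tokenizer.add_tokens(tokens)

    # Grow the input embedding so ids of the new tokens are valid.
    embed = model.model.embed_tokens  # assumed module path; varies per model
    old_vocab, dims = embed.weight.shape
    new_vocab = len(tokenizer._tokenizer)
    if new_vocab > old_vocab:
        extra = mx.random.normal(shape=(new_vocab - old_vocab, dims), scale=0.02)
        embed.weight = mx.concatenate([embed.weight, extra], axis=0)
        # An untied lm_head would need the same treatment (omitted here).
    return model, tokenizer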
@@ -993,6 +993,7 @@ def quantize_model(
 
 def save_config(
     config: dict,
+    tokenizer: TokenizerWrapper,
     config_path: Union[str, Path],
 ) -> None:
     """Save the model configuration to the ``config_path``.
@@ -1009,6 +1010,9 @@ def save_config(
     # sort the config for better readability
     config = dict(sorted(config.items()))
 
+    if config["vocab_size"] != (cur := len(tokenizer._tokenizer)):
+        config["vocab_size"] = cur
+        print("Updated model`s config.json to match new tokenizer")
     # write the updated config to the config_path (if provided)
     with open(config_path, "w") as fid:
         json.dump(config, fid, indent=4)
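Taken together, the fuse path now keeps the saved config in sync with a tokenizer that may have grown: `save_config` overwrites `vocab_size` with the tokenizer's current length before writing `config.json`. A minimal standalone illustration of that check (plain Python, no mlx dependencies; the example sizes are made up):

import json


def sync_vocab_size(config: dict, tokenizer_len: int) -> dict:
    # Mirrors the new check in save_config(): if the stored vocab_size no
    # longer matches the (possibly extended) tokenizer, update it first.
    if config["vocab_size"] != tokenizer_len:
        config["vocab_size"] = tokenizer_len
        print("Updated model's config.json to match new tokenizer")
    return config


config = {"model_type": "llama", "vocab_size": 32000}
config = sync_vocab_size(config, tokenizer_len=32002)  # e.g. two new tokens added
print(json.dumps(dict(sorted(config.items())), indent=4))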