Mirror of https://github.com/ml-explore/mlx-examples.git

Commit b2ab37238e (parent e2ace6fb0f): add loading of adapter additional tokens on fuse
@@ -77,7 +77,8 @@ def main() -> None:
     model, config, tokenizer = fetch_from_hub(model_path)
 
     model.freeze()
-    model = load_adapters(model, args.adapter_path)
+
+    model, tokenizer = load_adapters(model, tokenizer, args.adapter_path)
 
     fused_linears = [
         (n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")
@@ -105,7 +106,7 @@ def main() -> None:
     if args.de_quantize:
         config.pop("quantization", None)
 
-    save_config(config, config_path=save_path / "config.json")
+    save_config(config, tokenizer, config_path=save_path / "config.json")
 
     if args.export_gguf:
         model_type = config["model_type"]
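Taken together, the two fuse-script hunks above thread the tokenizer through the fuse flow: load_adapters now receives and returns it, and save_config receives it when the fused config is written. Below is a minimal sketch of the resulting sequence; the placeholder paths, the import locations, and the intermediate fuse step are assumptions drawn from the surrounding script, not part of this diff.

from pathlib import Path

from mlx.utils import tree_unflatten
from mlx_lm.tuner.utils import load_adapters
from mlx_lm.utils import fetch_from_hub, save_config

model_path = Path("path/to/base-model")    # placeholder paths, not from the diff
adapter_path = Path("path/to/adapters")
save_path = Path("path/to/fused-model")

model, config, tokenizer = fetch_from_hub(model_path)
model.freeze()

# The adapter may now carry additional tokens; both model and tokenizer come back updated.
model, tokenizer = load_adapters(model, tokenizer, adapter_path)

# Fuse LoRA/DoRA layers back into plain linear layers.
fused = [(n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")]
if fused:
    model.update_modules(tree_unflatten(fused))

# The tokenizer is passed along so vocab_size in config.json can be kept in sync.
save_config(config, tokenizer, config_path=save_path / "config.json")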
@@ -8,6 +8,8 @@ import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as opt
 from mlx.utils import tree_flatten, tree_unflatten
 
+from ..tokenizer_utils import TokenizerWrapper
+
 from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
 from .dora import DoRAEmbedding, DoRALinear
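The new import pulls in TokenizerWrapper, mlx_lm's wrapper around the underlying Hugging Face tokenizer; here it is only used as the type of the new tokenizer parameter. A small hedged illustration of where such a wrapper normally comes from (the model id is a placeholder):

from mlx_lm import load

# load() returns the model together with a TokenizerWrapper; the wrapped Hugging Face
# tokenizer is reachable via the `_tokenizer` attribute, which is what the save_config
# change further down reads the current vocab size from.
model, tokenizer = load("some-org/some-model")  # placeholder model id
print(len(tokenizer._tokenizer))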
@@ -159,7 +161,9 @@ def linear_to_lora_layers(
     model.update_modules(tree_unflatten(lora_modules))
 
 
-def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
+def load_adapters(
+    model: nn.Module, tokenizer: TokenizerWrapper, adapter_path: str
+) -> nn.Module:
     """
     Load any fine-tuned adapters / layers.
 
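With the changed signature, callers must pass the tokenizer in and unpack both return values. A minimal usage sketch, assuming the in-repo import paths and a placeholder model id:

from mlx_lm import load
from mlx_lm.tuner.utils import load_adapters

model, tokenizer = load("some-org/some-model")  # placeholder model id
model.freeze()

# Both the model and the tokenizer may come back modified if the adapter
# config requests additional tokens.
model, tokenizer = load_adapters(model, tokenizer, "path/to/adapters")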
@@ -184,7 +188,21 @@ def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
             use_dora=(fine_tune_type == "dora"),
         )
     model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
-    return model
+    if cot := config.cot:
+        print("Loading additional tokens")
+        if tokens := cot.get("additional_tokens"):
+            from .new_tokens import implement_new_tokens
+
+            special = False
+            if (special_arg := cot.get("special")) and isinstance(special_arg, bool):
+                print("Updating model and tokenizer with new special tokens")
+                special = special_arg
+            else:
+                print("Updating model and tokenizer with new tokens")
+            model, tokenizer = implement_new_tokens(
+                model=model, tokenizer=tokenizer, tokens=tokens, special=special
+            )
+    return model, tokenizer
 
 
 def dequantize(model: nn.Module) -> nn.Module:
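The new branch reads an optional cot section from the adapter's config, for example something like "cot": {"additional_tokens": ["<think>", "</think>"], "special": true} (a made-up example; the exact schema comes from the training side), and hands the tokens to implement_new_tokens from a sibling new_tokens module that is not part of this diff. Purely as a hypothetical sketch of what such a helper typically has to do, extend the tokenizer and then grow the input embedding and any untied output head to the new vocabulary size, assuming a llama-style module layout:

import mlx.core as mx
import mlx.nn as nn

from mlx_lm.tokenizer_utils import TokenizerWrapper


def implement_new_tokens_sketch(
    model: nn.Module,
    tokenizer: TokenizerWrapper,
    tokens: list[str],
    special: bool = False,
) -> tuple[nn.Module, TokenizerWrapper]:
    """Hypothetical sketch, not the module's actual implementation."""
    hf_tok = tokenizer._tokenizer
    if special:
        hf_tok.add_special_tokens({"additional_special_tokens": tokens})
    else:
        hf_tok.add_tokens(tokens)

    new_vocab = len(hf_tok)
    emb = model.model.embed_tokens  # assumption: llama-style layout
    old_vocab, dims = emb.weight.shape
    if new_vocab > old_vocab:
        # Grow the embedding table with zero-initialized rows for the new tokens.
        extra = mx.zeros((new_vocab - old_vocab, dims), dtype=emb.weight.dtype)
        emb.weight = mx.concatenate([emb.weight, extra])
        if hasattr(model, "lm_head"):  # untied output projection, if present
            head = model.lm_head
            extra_out = mx.zeros(
                (new_vocab - old_vocab, head.weight.shape[1]), dtype=head.weight.dtype
            )
            head.weight = mx.concatenate([head.weight, extra_out])
    return model, tokenizer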
@@ -993,6 +993,7 @@ def quantize_model(
 
 def save_config(
     config: dict,
+    tokenizer: TokenizerWrapper,
     config_path: Union[str, Path],
 ) -> None:
     """Save the model configuration to the ``config_path``.
@@ -1009,6 +1010,9 @@ def save_config(
     # sort the config for better readability
     config = dict(sorted(config.items()))
 
+    if config["vocab_size"] != (cur := len(tokenizer._tokenizer)):
+        config["vocab_size"] = cur
+        print("Updated model's config.json to match new tokenizer")
     # write the updated config to the config_path (if provided)
     with open(config_path, "w") as fid:
         json.dump(config, fid, indent=4)
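The effect of the new check is that a tokenizer grown by load_adapters automatically corrects vocab_size in the saved config. A small hedged illustration using the new signature; the model id and config values are placeholders:

from mlx_lm import load
from mlx_lm.utils import save_config

model, tokenizer = load("some-org/some-model")          # placeholder model id
config = {"model_type": "llama", "vocab_size": 32000}   # illustrative values only

tokenizer._tokenizer.add_tokens(["<think>", "</think>"])

# If len(tokenizer._tokenizer) no longer matches config["vocab_size"],
# save_config rewrites it before dumping config.json.
save_config(config, tokenizer, config_path="config.json")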