# mlx-examples/llms/mlx_lm/tuner/utils.py

# Copyright © 2024 Apple Inc.
import json
import types
from pathlib import Path
from typing import Dict
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt
from mlx.utils import tree_flatten, tree_unflatten
from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
from .dora import DoRAEmbedding, DoRALinear
from .lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
def build_schedule(schedule_config: Dict):
"""
Build a learning rate schedule from the given config.
"""
schedule_fn = getattr(opt.schedulers, schedule_config["name"])
arguments = schedule_config["arguments"]
initial_lr = arguments[0]
bound_schedule_fn = schedule_fn(*arguments)
if warmup_steps := schedule_config.get("warmup", 0):
warmup_init = schedule_config.get("warmup_init", 0.0)
warmup_fn = opt.schedulers.linear_schedule(
warmup_init, initial_lr, warmup_steps
)
return opt.schedulers.join_schedules(
[warmup_fn, bound_schedule_fn], [warmup_steps + 1]
)
else:
return bound_schedule_fn
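

# A minimal, illustrative schedule config for build_schedule; the values are
# placeholders. "name" must match a callable in mlx.optimizers.schedulers
# (e.g. cosine_decay), "arguments" are its positional arguments with the peak
# learning rate first, and "warmup" / "warmup_init" are the optional keys read
# above.
#
#   schedule_config = {
#       "name": "cosine_decay",
#       "arguments": [1e-5, 1000],   # peak LR, decay steps
#       "warmup": 100,               # linear warmup steps
#       "warmup_init": 0.0,          # warmup starting LR
#   }
#   lr_schedule = build_schedule(schedule_config)
#   optimizer = opt.Adam(learning_rate=lr_schedule)
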
def linear_to_lora_layers(
model: nn.Module,
num_layers: int,
config: Dict,
use_dora: bool = False,
):
"""
    Convert some of the model's linear layers to LoRA layers.
Args:
model (nn.Module): The neural network model.
        num_layers (int): The number of blocks to convert to LoRA layers,
            starting from the last layer.
config (dict): More configuration parameters for LoRA, including the
rank, scale, and optional layer keys.
use_dora (bool): If True, uses DoRA instead of LoRA.
Default: ``False``
"""
def to_lora(layer):
if isinstance(layer, (nn.Linear, nn.QuantizedLinear)):
LoRALayer = DoRALinear if use_dora else LoRALinear
elif isinstance(layer, (SwitchLinear, QuantizedSwitchLinear)):
if use_dora:
raise ValueError(f"{type(layer).__name__} doesn't support DoRA yet.")
LoRALayer = LoRASwitchLinear
elif isinstance(layer, (nn.Embedding, nn.QuantizedEmbedding)):
LoRALayer = DoRAEmbedding if use_dora else LoRAEmbedding
else:
raise ValueError(
f"Can't convert layer of type {type(layer).__name__} to LoRA"
)
return LoRALayer.from_base(
layer,
r=config["rank"],
scale=config["scale"],
dropout=config["dropout"],
)
keys = config.get("keys", None)
if keys is not None:
keys = set(keys)
elif model.model_type in [
"mistral",
"llama",
"phi",
"mixtral",
"nemotron",
"stablelm",
"hunyuan",
"qwen2",
"qwen2_moe",
"phimoe",
"gemma",
"gemma2",
"granite",
"helium",
"starcoder2",
"cohere",
"cohere2",
"minicpm",
"deepseek",
"olmo2",
"olmoe",
"internlm3",
]:
keys = set(["self_attn.q_proj", "self_attn.v_proj"])
if model.model_type in ["mixtral", "phimoe"]:
keys.add("block_sparse_moe.gate")
if model.model_type == "qwen2_moe":
keys.add("mlp.gate")
keys.add("mlp.shared_expert_gate")
if model.model_type == "olmoe":
keys.add("mlp.gate")
elif model.model_type == "gpt_bigcode":
keys = set(["attn.c_attn"])
elif model.model_type == "gpt2":
keys = set(["attn.c_attn"])
elif model.model_type == "gpt_neox":
keys = set(["attention.query_key_value"])
elif model.model_type == "olmo":
keys = set(["att_proj"])
elif model.model_type == "openelm":
keys = set(["attn.qkv_proj"])
elif model.model_type == "phi3":
keys = set(["self_attn.qkv_proj"])
elif model.model_type == "phi-msft":
keys = set(["mixer.Wqkv", "moe.gate"])
elif model.model_type == "dbrx":
keys = set(["norm_attn_norm.attn.Wqkv", "ffn.router.layer"])
elif model.model_type == "internlm2":
keys = set(["attention.wqkv", "attention.wo"])
elif model.model_type == "deepseek_v2":
keys = set(
[
"self_attn.q_proj",
"self_attn.q_a_proj",
"self_attn.q_b_proj",
"self_attn.kv_a_proj_with_mqa",
"self_attn.kv_b_proj",
]
)
elif model.model_type == "mamba":
keys = set(
[
"mixer.in_proj",
"mixer.x_proj",
"mixer.dt_proj",
"mixer.out_proj",
]
)
elif model.model_type == "exaone":
keys = set(["attn.attention.q_proj", "attn.attention.v_proj"])
else:
        raise ValueError(f"LoRA does not support {model.model_type}")
for l in model.layers[-max(num_layers, 0) :]:
lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
if lora_layers:
l.update_modules(tree_unflatten(lora_layers))
lora_modules = [(k, to_lora(m)) for k, m in model.named_modules() if k in keys]
if lora_modules:
model.update_modules(tree_unflatten(lora_modules))
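

# Illustrative usage (the values are placeholders): convert the last eight
# blocks of a loaded model in place. "rank", "scale", and "dropout" are the
# config keys required by to_lora above; for a llama-style model the default
# keys are self_attn.q_proj and self_attn.v_proj.
#
#   lora_config = {"rank": 8, "scale": 20.0, "dropout": 0.0}
#   linear_to_lora_layers(model, num_layers=8, config=lora_config)
#
# The matched sub-modules are swapped for LoRA (or DoRA) wrappers via
# update_modules; everything else in the model is left untouched.
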
def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
"""
Load any fine-tuned adapters / layers.
Args:
model (nn.Module): The neural network model.
        adapter_path (str): Path to the directory containing the adapter
            configuration (adapter_config.json) and weights (adapters.safetensors).
Returns:
nn.Module: The updated model with LoRA layers applied.
"""
adapter_path = Path(adapter_path)
if not adapter_path.exists():
raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
with open(adapter_path / "adapter_config.json", "r") as fid:
config = types.SimpleNamespace(**json.load(fid))
fine_tune_type = getattr(config, "fine_tune_type", "lora")
if fine_tune_type != "full":
linear_to_lora_layers(
model,
config.num_layers,
config.lora_parameters,
use_dora=(fine_tune_type == "dora"),
)
model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
return model
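

# Illustrative usage: the path below is a placeholder for a directory written
# by the trainer, containing adapter_config.json and adapters.safetensors.
#
#   model = load_adapters(model, "path/to/adapter_dir")
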
def dequantize(model: nn.Module) -> nn.Module:
"""
    Dequantize the quantized linear and embedding layers in the model.
Args:
model (nn.Module): The model with quantized linear layers.
Returns:
nn.Module: The model with dequantized layers.
"""
de_quantize_layers = []
for name, module in model.named_modules():
if isinstance(module, nn.QuantizedLinear):
bias = "bias" in module
weight = module.weight
weight = mx.dequantize(
weight,
module.scales,
module.biases,
module.group_size,
module.bits,
).astype(mx.float16)
output_dims, input_dims = weight.shape
linear = nn.Linear(input_dims, output_dims, bias=bias)
linear.weight = weight
if bias:
linear.bias = module.bias
de_quantize_layers.append((name, linear))
if isinstance(module, nn.QuantizedEmbedding):
weight = mx.dequantize(
module.weight,
module.scales,
module.biases,
module.group_size,
module.bits,
).astype(mx.float16)
num_embeddings, dims = weight.shape
emb = nn.Embedding(num_embeddings, dims)
emb.weight = weight
de_quantize_layers.append((name, emb))
if len(de_quantize_layers) > 0:
model.update_modules(tree_unflatten(de_quantize_layers))
return model
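

# Illustrative usage: convert a quantized model back to float16 nn.Linear /
# nn.Embedding modules, e.g. when a full-precision copy of the weights is
# needed. The model variable stands in for any quantized mlx_lm model.
#
#   model = dequantize(model)
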
def remove_lora_layers(model: nn.Module) -> nn.Module:
"""
Remove the LoRA layers from the model.
Args:
model (nn.Module): The model with LoRA layers.
Returns:
nn.Module: The model without LoRA layers.
"""
reset_layers = []
for name, module in model.named_modules():
if isinstance(module, LoRALinear):
reset_layers.append((name, module.linear))
if len(reset_layers) > 0:
model.update_modules(tree_unflatten(reset_layers))
return model
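

# Illustrative usage: undo linear_to_lora_layers for plain LoRALinear wrappers,
# restoring the original nn.Linear modules.
#
#   model = remove_lora_layers(model)
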
def nparams(module):
if hasattr(module, "bits"):
n = 0 if not hasattr(module, "bias") else module.bias.size
return n + module.weight.size * 32 // module.bits
return sum(v.size for _, v in tree_flatten(module.parameters()))
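

# Worked example (illustrative numbers): a 4-bit QuantizedLinear stores its
# weights packed into uint32 values, 32 // bits = 8 parameters per element, so
# a packed weight of size 131072 counts as 131072 * 32 // 4 = 1,048,576
# parameters (plus the bias size, if any).
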
def print_trainable_parameters(model):
leaf_modules = tree_flatten(
model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
)
total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
trainable_p = (
sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
)
print(
f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
f"({trainable_p:.3f}M/{total_p:.3f}M)"
)
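

# Illustrative usage (mirrors how the mlx_lm LoRA tuning flow typically uses
# these helpers; the config values are placeholders):
#
#   model.freeze()
#   linear_to_lora_layers(
#       model, num_layers=8, config={"rank": 8, "scale": 20.0, "dropout": 0.0}
#   )
#   print_trainable_parameters(model)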