import os
from typing import Dict

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt
from mlx.utils import tree_unflatten

from .lora import LoRALinear


def build_schedule(schedule_config: Dict):
    """
    Build a learning rate schedule from the given config.
    """
    schedule_fn = getattr(opt.schedulers, schedule_config["name"])
    arguments = schedule_config["arguments"]
    initial_lr = arguments[0]
    bound_schedule_fn = schedule_fn(*arguments)
    if warmup_steps := schedule_config.get("warmup", 0):
        warmup_init = schedule_config.get("warmup_init", 0.0)
        warmup_fn = opt.schedulers.linear_schedule(
            warmup_init, initial_lr, warmup_steps
        )
        return opt.schedulers.join_schedules(
            [warmup_fn, bound_schedule_fn], [warmup_steps + 1]
        )
    else:
        return bound_schedule_fn


def linear_to_lora_layers(
    model: nn.Module,
    num_lora_layers: int,
    config: Dict,
):
    """
    Convert some of the model's linear layers to LoRA layers.

    Args:
        model (nn.Module): The neural network model.
        num_lora_layers (int): The number of blocks to convert to LoRA layers,
            starting from the last layer.
        config (dict): Configuration parameters for LoRA, including the rank,
            alpha, scale, dropout, and optional layer keys.
    """
    num_layers = len(model.layers)
    if num_lora_layers > num_layers:
        raise ValueError(
            f"Requested {num_lora_layers} LoRA layers "
            f"but the model only has {num_layers} layers."
        )

    to_lora = lambda lin: LoRALinear.from_linear(
        lin,
        r=config["rank"],
        alpha=config["alpha"],
        scale=config["scale"],
        dropout=config["dropout"],
    )

    keys = config.get("keys", None)
    if keys is not None:
        keys = set(keys)
    elif model.model_type in [
        "mistral",
        "llama",
        "phi",
        "mixtral",
        "stablelm",
        "qwen2",
        "gemma",
        "starcoder2",
        "cohere",
    ]:
        keys = set(["self_attn.q_proj", "self_attn.v_proj"])
        if model.model_type == "mixtral":
            keys.add("block_sparse_moe.gate")
    elif model.model_type == "olmo":
        keys = set(["att_proj"])
    elif model.model_type == "phi-msft":
        keys = set(["mixer.Wqkv", "moe.gate"])
    elif model.model_type == "dbrx":
        keys = set(["norm_attn_norm.attn.Wqkv", "ffn.router.layer"])
    else:
        raise ValueError(f"Lora does not support {model.model_type}")

    # Only the last `num_lora_layers` transformer blocks get LoRA adapters.
    for l in model.layers[num_layers - num_lora_layers :]:
        lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
        l.update_modules(tree_unflatten(lora_layers))


def apply_lora_layers(model: nn.Module, adapter_file: str) -> nn.Module:
    """
    Apply LoRA layers to the model.

    Args:
        model (nn.Module): The neural network model.
        adapter_file (str): Path to the file with the trained adapter weights.

    Returns:
        nn.Module: The updated model with LoRA layers applied.
    """
    if not os.path.exists(adapter_file):
        raise FileNotFoundError(f"The adapter file does not exist: {adapter_file}")

    adapters = list(mx.load(adapter_file).items())

    linear_replacements = []
    lora_layers = set(
        [name.replace(".lora_a", "").replace(".lora_b", "") for name, _ in adapters]
    )
    for name, module in model.named_modules():
        if name in lora_layers:
            replacement_module = LoRALinear.from_linear(module)
            linear_replacements.append((name, replacement_module))

    model.update_modules(tree_unflatten(linear_replacements))
    model.update(tree_unflatten(adapters))

    return model
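# Illustrative sketch (kept in comments, not executed at import time): the dict
# shapes the helpers above expect. "cosine_decay" is just one scheduler from
# mlx.optimizers; the model, step counts, and LoRA hyperparameters below are
# placeholders, not values prescribed by this module.
#
#   lr_schedule = build_schedule(
#       {
#           "name": "cosine_decay",      # any schedule in mlx.optimizers.schedulers
#           "arguments": [1e-5, 1000],   # positional args; arguments[0] is the initial lr
#           "warmup": 100,               # optional linear warmup steps
#           "warmup_init": 0.0,          # optional warmup starting lr
#       }
#   )
#   optimizer = opt.Adam(learning_rate=lr_schedule)
#
#   linear_to_lora_layers(
#       model,
#       num_lora_layers=16,
#       config={"rank": 8, "alpha": 16, "scale": 10.0, "dropout": 0.0, "keys": None},
#   )  # keys=None falls back to the per-architecture defaults above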
""" de_quantize_layers = [] for name, module in model.named_modules(): if isinstance(module, nn.QuantizedLinear): bias = "bias" in module weight = module.weight weight = mx.dequantize( weight, module.scales, module.biases, module.group_size, module.bits, ).astype(mx.float16) output_dims, input_dims = weight.shape linear = nn.Linear(input_dims, output_dims, bias=bias) linear.weight = weight if bias: linear.bias = module.bias de_quantize_layers.append((name, linear)) if len(de_quantize_layers) > 0: model.update_modules(tree_unflatten(de_quantize_layers)) return model def remove_lora_layers(model: nn.Module) -> nn.Module: """ Remove the LoRA layers from the model. Args: model (nn.Module): The model with LoRA layers. Returns: nn.Module: The model without LoRA layers. """ reset_layers = [] for name, module in model.named_modules(): if isinstance(module, LoRALinear): reset_layers.append((name, module.linear)) if len(reset_layers) > 0: model.update_modules(tree_unflatten(reset_layers)) return model