diff --git a/llms/mlx_lm/fuse.py b/llms/mlx_lm/fuse.py
index 60a13f94..132982d3 100644
--- a/llms/mlx_lm/fuse.py
+++ b/llms/mlx_lm/fuse.py
@@ -8,7 +8,7 @@ from typing import Any, Dict, Union
 
 from mlx.utils import tree_flatten, tree_unflatten
 
 from .tuner.lora import LoRALinear
-from .tuner.utils import apply_lora_layers
+from .tuner.utils import apply_lora_layers, dequantize
 from .utils import fetch_from_hub, get_model_path, save_weights, upload_to_hub
 
@@ -42,6 +42,11 @@ def parse_arguments() -> argparse.Namespace:
         type=str,
         default=None,
     )
+    parser.add_argument(
+        "--de-quantize",
+        help="Generate a de-quantized model.",
+        action="store_true",
+    )
     return parser.parse_args()
 
 
@@ -54,6 +59,7 @@ def main() -> None:
 
     model.freeze()
     model = apply_lora_layers(model, args.adapter_file)
+
     fused_linears = [
         (n, m.to_linear())
         for n, m in model.named_modules()
@@ -61,6 +67,11 @@ def main() -> None:
     ]
 
     model.update_modules(tree_unflatten(fused_linears))
+
+    if args.de_quantize:
+        print("De-quantizing model")
+        model = dequantize(model)
+
     weights = dict(tree_flatten(model.parameters()))
 
     save_path = Path(args.save_path)
@@ -73,6 +84,9 @@ def main() -> None:
 
     tokenizer.save_pretrained(save_path)
 
+    if args.de_quantize:
+        config.pop("quantization", None)
+
     with open(save_path / "config.json", "w") as fid:
         json.dump(config, fid, indent=4)
 
diff --git a/llms/mlx_lm/tuner/lora.py b/llms/mlx_lm/tuner/lora.py
index 2a64e5a0..adc1f8ca 100644
--- a/llms/mlx_lm/tuner/lora.py
+++ b/llms/mlx_lm/tuner/lora.py
@@ -29,7 +29,7 @@ class LoRALinear(nn.Module):
         lora_lin.linear = linear
         return lora_lin
 
-    def to_linear(self):
+    def to_linear(self, de_quantize: bool = False):
         linear = self.linear
         bias = "bias" in linear
         weight = linear.weight
@@ -56,7 +56,7 @@ class LoRALinear(nn.Module):
         if bias:
             fused_linear.bias = linear.bias
 
-        if is_quantized:
+        if is_quantized and not de_quantize:
             fused_linear = nn.QuantizedLinear.from_linear(
                 fused_linear,
                 linear.group_size,
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 079011f0..76408107 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -1,3 +1,5 @@
+import os
+
 import mlx.core as mx
 import mlx.nn as nn
 from mlx.utils import tree_unflatten
@@ -6,18 +8,62 @@ from .lora import LoRALinear
 
 
 def apply_lora_layers(model: nn.Module, adapter_file: str) -> nn.Module:
+    """
+    Apply LoRA layers to the model.
+
+    Args:
+        model (nn.Module): The neural network model.
+        adapter_file (str): Path to the adapter configuration file.
+
+    Returns:
+        nn.Module: The updated model with LoRA layers applied.
+    """
+    if not os.path.exists(adapter_file):
+        raise FileNotFoundError(f"The adapter file does not exist: {adapter_file}")
+
     adapters = list(mx.load(adapter_file).items())
-    linear_replacements = {}
+
+    linear_replacements = []
     lora_layers = set(
         [name.replace(".lora_a", "").replace(".lora_b", "") for name, _ in adapters]
     )
-
     for name, module in model.named_modules():
         if name in lora_layers:
             replacement_module = LoRALinear.from_linear(module)
-            linear_replacements[name] = replacement_module
+            linear_replacements.append((name, replacement_module))
 
-    model.update_modules(tree_unflatten(list(linear_replacements.items())))
-
-    model.update(tree_unflatten(adapters))
+    model.update_modules(tree_unflatten(linear_replacements))
+    return model
+
+
+def dequantize(model: nn.Module) -> nn.Module:
+    """
+    Dequantize the quantized linear layers in the model.
+
+    Args:
+        model (nn.Module): The model with quantized linear layers.
+
+    Returns:
+        nn.Module: The model with dequantized layers.
+    """
+    de_quantize_layers = []
+    for n, m in model.named_modules():
+        if isinstance(m, nn.QuantizedLinear):
+            bias = "bias" in m
+            weight = m.weight
+            weight = mx.dequantize(
+                weight,
+                m.scales,
+                m.biases,
+                m.group_size,
+                m.bits,
+            ).astype(mx.float16)
+            output_dims, input_dims = weight.shape
+            linear = nn.Linear(input_dims, output_dims, bias=bias)
+            linear.weight = weight
+            if bias:
+                linear.bias = m.bias
+            de_quantize_layers.append((n, linear))
+    if len(de_quantize_layers) > 0:
+        model.update_modules(tree_unflatten(de_quantize_layers))
     return model
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index d670ee71..4edc83c0 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -24,7 +24,7 @@ MODEL_MAPPING = {
     "qwen": qwen,
     "plamo": plamo,
 }
-MAX_FILE_SIZE_GB = 15
+MAX_FILE_SIZE_GB = 5
 
 linear_class_predicate = (
     lambda m: isinstance(m, nn.Linear)
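For reference, a minimal usage sketch of the new option (not part of the patch itself). The repo id, adapter file name, and output directory below are placeholders, and the --model, --adapter-file, and --save-path flags are assumed from fuse.py's existing argument parser rather than shown in this diff:

```python
# Hypothetical invocation (placeholder paths): fuse LoRA adapters into a
# quantized model and write a de-quantized, float16 copy of the result.
#
#   python -m mlx_lm.fuse \
#       --model mlx-community/Mistral-7B-v0.1-4bit \
#       --adapter-file adapters.npz \
#       --save-path fused_fp16_model \
#       --de-quantize
#
# The dequantize() helper added in tuner/utils.py can also be called directly
# to turn nn.QuantizedLinear layers back into float16 nn.Linear layers.
from mlx_lm import load
from mlx_lm.tuner.utils import dequantize

model, tokenizer = load("mlx-community/Mistral-7B-v0.1-4bit")  # placeholder repo
model = dequantize(model)  # no-op if the model has no nn.QuantizedLinear layers
```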