From 614de6652faf7747737af991e34008618598da43 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Tue, 30 Jan 2024 15:54:49 +1100
Subject: [PATCH] chore(mlx-lm): add reset lora layers helper (#377)
* chore(mlx-lm): add reset lora layers helper
* chore: rename the func
* chore: update docstring
* Update llms/mlx_lm/tuner/utils.py
Co-authored-by: Awni Hannun
---------
Co-authored-by: Awni Hannun
---
llms/mlx_lm/tuner/utils.py | 39 ++++++++++++++++++++++++++++----------
1 file changed, 29 insertions(+), 10 deletions(-)
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 76408107..dbcf1acf 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -47,23 +47,42 @@ def dequantize(model: nn.Module) -> nn.Module:
nn.Module: The model with dequantized layers.
"""
de_quantize_layers = []
- for n, m in model.named_modules():
- if isinstance(m, nn.QuantizedLinear):
- bias = "bias" in m
- weight = m.weight
+ for name, module in model.named_modules():
+ if isinstance(module, nn.QuantizedLinear):
+ bias = "bias" in module
+ weight = module.weight
weight = mx.dequantize(
weight,
- m.scales,
- m.biases,
- m.group_size,
- m.bits,
+ module.scales,
+ module.biases,
+ module.group_size,
+ module.bits,
).astype(mx.float16)
output_dims, input_dims = weight.shape
linear = nn.Linear(input_dims, output_dims, bias=bias)
linear.weight = weight
if bias:
- linear.bias = m.bias
- de_quantize_layers.append((n, linear))
+ linear.bias = module.bias
+ de_quantize_layers.append((name, linear))
if len(de_quantize_layers) > 0:
model.update_modules(tree_unflatten(de_quantize_layers))
return model
+
+
+def remove_lora_layers(model: nn.Module) -> nn.Module:
+ """
+ Remove the LoRA layers from the model.
+
+ Args:
+ model (nn.Module): The model with LoRA layers.
+
+ Returns:
+ nn.Module: The model without LoRA layers.
+ """
+ reset_layers = []
+ for name, module in model.named_modules():
+ if isinstance(module, LoRALinear):
+ reset_layers.append((name, module.linear))
+ if len(reset_layers) > 0:
+ model.update_modules(tree_unflatten(reset_layers))
+ return model
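
Usage note (not part of the patch): a minimal sketch of how the new helper can pair with the existing LoRA tooling in mlx-lm. Only remove_lora_layers comes from this change; the load helper, LoRALinear.from_linear, the rank argument, and the model/layer attribute names below are assumptions about the surrounding mlx-lm code and may differ between versions.

    # Minimal sketch, not part of the patch. Import paths and attribute
    # names are assumptions about mlx-lm at the time of this change.
    from mlx_lm import load
    from mlx_lm.tuner.lora import LoRALinear
    from mlx_lm.tuner.utils import remove_lora_layers

    model, tokenizer = load("mistralai/Mistral-7B-v0.1")

    # Wrap a few attention projections with LoRA adapters for a quick
    # fine-tuning experiment (layer/attribute names assume a Llama-style
    # model as shipped in mlx-lm; adjust for other architectures).
    for layer in model.model.layers[-4:]:
        layer.self_attn.q_proj = LoRALinear.from_linear(layer.self_attn.q_proj)
        layer.self_attn.v_proj = LoRALinear.from_linear(layer.self_attn.v_proj)

    # ... train the adapters ...

    # Undo the wrapping: every LoRALinear is replaced by the layer it
    # wraps (its .linear attribute), restoring the base model structure.
    model = remove_lora_layers(model)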