From 614de6652faf7747737af991e34008618598da43 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Tue, 30 Jan 2024 15:54:49 +1100
Subject: [PATCH] chore(mlx-lm): add reset lora layers helper (#377)

* chore(mlx-lm): add reset lora layers helper

* chore: rename the func

* chore: update docstring

* Update llms/mlx_lm/tuner/utils.py

Co-authored-by: Awni Hannun

---------

Co-authored-by: Awni Hannun
---
 llms/mlx_lm/tuner/utils.py | 39 ++++++++++++++++++++++++++++----------
 1 file changed, 29 insertions(+), 10 deletions(-)

diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 76408107..dbcf1acf 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -47,23 +47,42 @@ def dequantize(model: nn.Module) -> nn.Module:
         nn.Module: The model with dequantized layers.
     """
     de_quantize_layers = []
-    for n, m in model.named_modules():
-        if isinstance(m, nn.QuantizedLinear):
-            bias = "bias" in m
-            weight = m.weight
+    for name, module in model.named_modules():
+        if isinstance(module, nn.QuantizedLinear):
+            bias = "bias" in module
+            weight = module.weight
             weight = mx.dequantize(
                 weight,
-                m.scales,
-                m.biases,
-                m.group_size,
-                m.bits,
+                module.scales,
+                module.biases,
+                module.group_size,
+                module.bits,
             ).astype(mx.float16)
             output_dims, input_dims = weight.shape
             linear = nn.Linear(input_dims, output_dims, bias=bias)
             linear.weight = weight
             if bias:
-                linear.bias = m.bias
-            de_quantize_layers.append((n, linear))
+                linear.bias = module.bias
+            de_quantize_layers.append((name, linear))
     if len(de_quantize_layers) > 0:
         model.update_modules(tree_unflatten(de_quantize_layers))
     return model
+
+
+def remove_lora_layers(model: nn.Module) -> nn.Module:
+    """
+    Remove the LoRA layers from the model.
+
+    Args:
+        model (nn.Module): The model with LoRA layers.
+
+    Returns:
+        nn.Module: The model without LoRA layers.
+    """
+    reset_layers = []
+    for name, module in model.named_modules():
+        if isinstance(module, LoRALinear):
+            reset_layers.append((name, module.linear))
+    if len(reset_layers) > 0:
+        model.update_modules(tree_unflatten(reset_layers))
+    return model
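
Usage note (not part of the patch): a minimal sketch of the new remove_lora_layers helper on a toy module. It assumes LoRALinear is importable from mlx_lm.tuner.lora and that LoRALinear.from_linear wraps a plain nn.Linear with default LoRA settings; adjust both to the actual tuner API if they differ.

import mlx.nn as nn

from mlx_lm.tuner.lora import LoRALinear
from mlx_lm.tuner.utils import remove_lora_layers


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)
        self.out = nn.Linear(16, 4)


model = ToyModel()

# Wrap one layer for LoRA fine-tuning (assumed from_linear signature with defaults).
model.proj = LoRALinear.from_linear(model.proj)
assert isinstance(model.proj, LoRALinear)

# Strip the LoRA wrappers again: each LoRALinear is swapped back for its
# underlying .linear module via model.update_modules().
model = remove_lora_layers(model)
assert isinstance(model.proj, nn.Linear)

This is the inverse of the wrap step, which is handy when the adapter weights should be discarded and the base model restored as-is.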