diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 76408107..dbcf1acf 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -47,23 +47,42 @@ def dequantize(model: nn.Module) -> nn.Module:
         nn.Module: The model with dequantized layers.
     """
     de_quantize_layers = []
-    for n, m in model.named_modules():
-        if isinstance(m, nn.QuantizedLinear):
-            bias = "bias" in m
-            weight = m.weight
+    for name, module in model.named_modules():
+        if isinstance(module, nn.QuantizedLinear):
+            bias = "bias" in module
+            weight = module.weight
             weight = mx.dequantize(
                 weight,
-                m.scales,
-                m.biases,
-                m.group_size,
-                m.bits,
+                module.scales,
+                module.biases,
+                module.group_size,
+                module.bits,
             ).astype(mx.float16)
             output_dims, input_dims = weight.shape
             linear = nn.Linear(input_dims, output_dims, bias=bias)
             linear.weight = weight
             if bias:
-                linear.bias = m.bias
-            de_quantize_layers.append((n, linear))
+                linear.bias = module.bias
+            de_quantize_layers.append((name, linear))
     if len(de_quantize_layers) > 0:
         model.update_modules(tree_unflatten(de_quantize_layers))
     return model
+
+
+def remove_lora_layers(model: nn.Module) -> nn.Module:
+    """
+    Remove the LoRA layers from the model.
+
+    Args:
+        model (nn.Module): The model with LoRA layers.
+
+    Returns:
+        nn.Module: The model without LoRA layers.
+    """
+    reset_layers = []
+    for name, module in model.named_modules():
+        if isinstance(module, LoRALinear):
+            reset_layers.append((name, module.linear))
+    if len(reset_layers) > 0:
+        model.update_modules(tree_unflatten(reset_layers))
+    return model
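
Usage note (not part of the patch): a minimal sketch of how the new remove_lora_layers helper could be called after LoRA fine-tuning to restore plain nn.Linear modules. The model identifier and the load call are assumptions about mlx_lm's public API and may differ between versions; remove_lora_layers itself comes from the patched mlx_lm.tuner.utils module.

    # Hypothetical usage sketch for the new helper added in this diff.
    from mlx_lm import load
    from mlx_lm.tuner.utils import remove_lora_layers

    # Model id is illustrative only.
    model, tokenizer = load("mlx-community/Mistral-7B-v0.1-4bit")

    # ... wrap layers with LoRA adapters, train, then discard the adapters ...

    # Each LoRALinear wrapper is replaced by the nn.Linear it wraps,
    # leaving the model with its original (unadapted) linear layers.
    model = remove_lora_layers(model)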