Mirror of https://github.com/ml-explore/mlx-examples.git
chore(mlx-lm): add reset lora layers helper (#377)
* chore(mlx-lm): add reset lora layers helper
* chore: rename the func
* chore: update docstring
* Update llms/mlx_lm/tuner/utils.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
parent 20b969b412
commit 614de6652f
@@ -47,23 +47,42 @@ def dequantize(model: nn.Module) -> nn.Module:
         nn.Module: The model with dequantized layers.
     """
     de_quantize_layers = []
-    for n, m in model.named_modules():
-        if isinstance(m, nn.QuantizedLinear):
-            bias = "bias" in m
-            weight = m.weight
+    for name, module in model.named_modules():
+        if isinstance(module, nn.QuantizedLinear):
+            bias = "bias" in module
+            weight = module.weight
             weight = mx.dequantize(
                 weight,
-                m.scales,
-                m.biases,
-                m.group_size,
-                m.bits,
+                module.scales,
+                module.biases,
+                module.group_size,
+                module.bits,
             ).astype(mx.float16)
             output_dims, input_dims = weight.shape
             linear = nn.Linear(input_dims, output_dims, bias=bias)
             linear.weight = weight
             if bias:
-                linear.bias = m.bias
-            de_quantize_layers.append((n, linear))
+                linear.bias = module.bias
+            de_quantize_layers.append((name, linear))
     if len(de_quantize_layers) > 0:
         model.update_modules(tree_unflatten(de_quantize_layers))
     return model
+
+
+def remove_lora_layers(model: nn.Module) -> nn.Module:
+    """
+    Remove the LoRA layers from the model.
+
+    Args:
+        model (nn.Module): The model with LoRA layers.
+
+    Returns:
+        nn.Module: The model without LoRA layers.
+    """
+    reset_layers = []
+    for name, module in model.named_modules():
+        if isinstance(module, LoRALinear):
+            reset_layers.append((name, module.linear))
+    if len(reset_layers) > 0:
+        model.update_modules(tree_unflatten(reset_layers))
+    return model