Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 01:17:28 +08:00
Fix num layers in fine tune (#1294)
commit 85669451d0 (parent 1cbf5cdac7)
@@ -181,8 +181,14 @@ def train_model(
     training_callback: TrainingCallback = None,
 ):
     model.freeze()
+    if args.num_layers > len(model.layers):
+        raise ValueError(
+            f"Requested to train {args.num_layers} layers "
+            f"but the model only has {len(model.layers)} layers."
+        )
+
     if args.fine_tune_type == "full":
-        for l in model.layers[-min(args.num_layers, 0) :]:
+        for l in model.layers[-max(args.num_layers, 0) :]:
             l.unfreeze()
     elif args.fine_tune_type in ["lora", "dora"]:
         # Convert linear layers to lora/dora layers and unfreeze in the process
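The core of the fix is the slice bound above: for a positive num_layers, -min(num_layers, 0) evaluates to 0, so model.layers[0:] selects every layer and full fine-tuning unfreezes the whole model rather than the last num_layers layers. A minimal sketch of the difference, using a toy list in place of model.layers (not the repository's code):

# Toy stand-in for model.layers; not mlx_lm code.
layers = ["layer0", "layer1", "layer2", "layer3"]
num_layers = 2

# Old behaviour: -min(2, 0) == 0, so the slice starts at index 0
# and selects every layer.
print(layers[-min(num_layers, 0):])  # ['layer0', 'layer1', 'layer2', 'layer3']

# Fixed behaviour: -max(2, 0) == -2, so only the last two layers are selected.
print(layers[-max(num_layers, 0):])  # ['layer2', 'layer3']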
@@ -52,11 +52,6 @@ def linear_to_lora_layers(
         use_dora (bool): If True, uses DoRA instead of LoRA.
             Default: ``False``
     """
-    if num_layers > len(model.layers):
-        raise ValueError(
-            f"Requested {num_layers} LoRA layers "
-            f"but the model only has {len(model.layers)} layers."
-        )
 
     def to_lora(layer):
         if isinstance(layer, (nn.Linear, nn.QuantizedLinear)):
@@ -154,7 +149,7 @@ def linear_to_lora_layers(
     else:
         raise ValueError(f"Lora does not support {model.model_type}")
 
-    for l in model.layers[-min(num_layers, 0) :]:
+    for l in model.layers[-max(num_layers, 0) :]:
         lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
         if lora_layers:
             l.update_modules(tree_unflatten(lora_layers))
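With the bounds check now living in train_model (and the duplicate check dropped from linear_to_lora_layers), the max() slice also keeps a negative num_layers working, presumably the "adapt all layers" convention: -max(-1, 0) is 0 and layers[0:] is the whole list. A hedged sketch of the resulting selection logic, using a hypothetical select_layers helper that is not part of mlx_lm:

# Hypothetical helper illustrating the selection logic after this commit;
# mlx_lm performs these steps inline in train_model / linear_to_lora_layers.
def select_layers(layers, num_layers):
    if num_layers > len(layers):
        raise ValueError(
            f"Requested to train {num_layers} layers "
            f"but the model only has {len(layers)} layers."
        )
    # A negative num_layers falls through to layers[0:], i.e. every layer;
    # a positive value selects the last num_layers layers.
    return layers[-max(num_layers, 0):]

layers = list(range(32))
assert len(select_layers(layers, 16)) == 16   # last 16 layers
assert len(select_layers(layers, -1)) == 32   # all layers
# select_layers(layers, 64) raises ValueError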