Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-12-16 02:08:55 +08:00)
Fix num layers in fine tune (#1294)
@@ -181,8 +181,14 @@ def train_model(
     training_callback: TrainingCallback = None,
 ):
     model.freeze()
+    if args.num_layers > len(model.layers):
+        raise ValueError(
+            f"Requested to train {args.num_layers} layers "
+            f"but the model only has {len(model.layers)} layers."
+        )
+
     if args.fine_tune_type == "full":
-        for l in model.layers[-min(args.num_layers, 0) :]:
+        for l in model.layers[-max(args.num_layers, 0) :]:
             l.unfreeze()
     elif args.fine_tune_type in ["lora", "dora"]:
         # Convert linear layers to lora/dora layers and unfreeze in the process
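To see what the one-word change fixes, here is a minimal, self-contained sketch of the slice arithmetic, using a plain Python list to stand in for model.layers. It is not the mlx-examples code itself, and the treatment of num_layers = -1 as an "unfreeze everything" sentinel is an assumption inferred from the slice behavior, not stated in this diff.

    # Sketch: why the slice must use max(num_layers, 0), not min(num_layers, 0).
    layers = ["layer0", "layer1", "layer2", "layer3"]  # stand-in for model.layers

    num_layers = 2
    # Buggy version: -min(2, 0) == 0, and layers[0:] selects ALL layers,
    # so every layer was unfrozen regardless of the requested count.
    print(layers[-min(num_layers, 0):])  # ['layer0', 'layer1', 'layer2', 'layer3']
    # Fixed version: -max(2, 0) == -2, selecting only the last two layers.
    print(layers[-max(num_layers, 0):])  # ['layer2', 'layer3']

    num_layers = -1  # assumed "all layers" sentinel
    # -max(-1, 0) == 0, so layers[0:] still selects every layer, as intended.
    print(layers[-max(num_layers, 0):])  # ['layer0', 'layer1', 'layer2', 'layer3']

The new ValueError guard complements this: with the slice fixed, asking for more layers than the model has would otherwise silently unfreeze the whole model, so the request is now rejected up front.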