mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 15:50:57 +08:00
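Updates the command-line (CL) LoRA tuner with optional input masking: passing --mask-inputs selects input_masked_loss and iterate_delineated_batches, while default_loss and iterate_batches are still used by default.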
This commit is contained in:
parent
84fc1bde48
commit
27cd361d76
@ -94,6 +94,15 @@ def build_parser():
|
|||||||
choices=["lora", "dora", "full"],
|
choices=["lora", "dora", "full"],
|
||||||
help="Type of fine-tuning to perform: lora, dora, or full.",
|
help="Type of fine-tuning to perform: lora, dora, or full.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--mask-inputs",
|
||||||
|
dest="mask_inputs",
|
||||||
|
action="store_true",
|
||||||
|
help="Whether to mask the inputs when training. Default is False.",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-layers",
|
"--num-layers",
|
||||||
type=int,
|
type=int,
|
||||||
@ -172,6 +181,13 @@ def train_model(
|
|||||||
valid_set,
|
valid_set,
|
||||||
training_callback: TrainingCallback = None,
|
training_callback: TrainingCallback = None,
|
||||||
):
|
):
|
||||||
|
from .tuner.trainer import (
|
||||||
|
default_loss,
|
||||||
|
input_masked_loss,
|
||||||
|
iterate_batches,
|
||||||
|
iterate_delineated_batches,
|
||||||
|
)
|
||||||
|
|
||||||
model.freeze()
|
model.freeze()
|
||||||
if args.fine_tune_type == "full":
|
if args.fine_tune_type == "full":
|
||||||
for l in model.layers[-min(args.num_layers, 0) :]:
|
for l in model.layers[-min(args.num_layers, 0) :]:
|
||||||
@ -228,6 +244,10 @@ def train_model(
|
|||||||
train_dataset=train_set,
|
train_dataset=train_set,
|
||||||
val_dataset=valid_set,
|
val_dataset=valid_set,
|
||||||
training_callback=training_callback,
|
training_callback=training_callback,
|
||||||
|
iterate_batches=(
|
||||||
|
iterate_delineated_batches if args.mask_inputs else iterate_batches
|
||||||
|
),
|
||||||
|
loss=input_masked_loss if args.mask_inputs else default_loss,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user