Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-28 00:30:09 +08:00)
Updates the command-line LoRA tuner with optional input masking; default_loss (and iterate_batches) are still used by default.
parent 84fc1bde48
commit 27cd361d76
@@ -94,6 +94,15 @@ def build_parser():
         choices=["lora", "dora", "full"],
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
+
+    parser.add_argument(
+        "--mask-inputs",
+        dest="mask_inputs",
+        action="store_true",
+        help="Whether to mask the inputs when training. Default is False.",
+        default=False,
+    )
+
     parser.add_argument(
         "--num-layers",
         type=int,
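For reference, the new flag is opt-in; a typical invocation would look roughly like the line below. The --model/--data/--train flags are assumed from the existing parser and the paths are placeholders, shown only to illustrate where --mask-inputs fits.

    python -m mlx_lm.lora --model <path-or-hf-repo> --data <data-dir> --train --mask-inputs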
@@ -172,6 +181,13 @@ def train_model(
     valid_set,
     training_callback: TrainingCallback = None,
 ):
+    from .tuner.trainer import (
+        default_loss,
+        input_masked_loss,
+        iterate_batches,
+        iterate_delineated_batches,
+    )
+
     model.freeze()
     if args.fine_tune_type == "full":
         for l in model.layers[-min(args.num_layers, 0) :]:
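For context, the sketch below illustrates what an input-masked loss typically computes: per-token cross-entropy averaged only over completion tokens, so prompt tokens contribute nothing to the gradient. This is an illustration, not the repository's input_masked_loss; the mask convention, shapes, and (loss, ntoks) return value are assumptions.

    import mlx.nn as nn

    def masked_loss_sketch(model, inputs, targets, mask):
        # inputs/targets: (batch, seq_len) token ids; mask is 1.0 on completion
        # tokens and 0.0 on prompt/padding tokens (hypothetical convention).
        logits = model(inputs)
        # Per-token cross-entropy, zeroed on masked-out (prompt) positions.
        ce = nn.losses.cross_entropy(logits, targets, reduction="none") * mask
        ntoks = mask.sum()
        # Average only over the tokens that actually count.
        return ce.sum() / ntoks, ntoks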
@@ -228,6 +244,10 @@ def train_model(
         train_dataset=train_set,
         val_dataset=valid_set,
         training_callback=training_callback,
+        iterate_batches=(
+            iterate_delineated_batches if args.mask_inputs else iterate_batches
+        ),
+        loss=input_masked_loss if args.mask_inputs else default_loss,
     )
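Because train() only receives iterate_delineated_batches and input_masked_loss when --mask-inputs is passed, existing runs keep the default_loss / iterate_batches behavior unchanged. To make the "delineated" part concrete, here is a hypothetical sketch of what such a batch iterator does conceptually: alongside each token batch it yields a mask marking which positions belong to the completion, which a masked loss like the one sketched above can consume. The function name, example format, and yielded shapes are illustrative only, not the actual iterate_delineated_batches.

    import numpy as np

    def delineated_batches_sketch(examples, batch_size):
        # examples: list of (prompt_ids, completion_ids) pairs (hypothetical format).
        for i in range(0, len(examples), batch_size):
            chunk = examples[i : i + batch_size]
            max_len = max(len(p) + len(c) for p, c in chunk)
            tokens = np.zeros((len(chunk), max_len), dtype=np.int32)
            mask = np.zeros((len(chunk), max_len), dtype=np.float32)
            for row, (prompt, completion) in enumerate(chunk):
                seq = list(prompt) + list(completion)
                tokens[row, : len(seq)] = seq
                # Only completion positions count toward the loss.
                mask[row, len(prompt) : len(seq)] = 1.0
            yield tokens, mask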