From 3496cbea46f1a555c0c09a59d28b3e9b4397345c Mon Sep 17 00:00:00 2001
From: Chime Ogbuji
Date: Sun, 10 Nov 2024 09:54:32 -0500
Subject: [PATCH] Add input masking for fine-tuning in documentation

Renamed the batch iteration function (iterate_delineated_batches ->
iterate_completion_batches).
---
 llms/mlx_lm/LORA.md          | 5 +++++
 llms/mlx_lm/lora.py          | 4 ++--
 llms/mlx_lm/tuner/trainer.py | 2 +-
 3 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index 9eac9d7f..4714c282 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -76,6 +76,11 @@ You can specify the output location with `--adapter-path`.
 You can resume fine-tuning with an existing adapter with
 `--resume-adapter-file <path_to_adapters.safetensors>`.
 
+### Input Masking
+There are custom functions for masking the input sequence of tokens during the loss
+calculation, so that the model is not penalized for failing to reproduce the prompt.
+To fine-tune with masked input sequences, use the `--mask-inputs` argument.
+
 ### Evaluate
 
 To compute test set perplexity use:
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index bbb574b4..ea6ee973 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -185,7 +185,7 @@ def train_model(
         default_loss,
         input_masked_loss,
         iterate_batches,
-        iterate_delineated_batches,
+        iterate_completion_batches,
     )
 
     model.freeze()
@@ -249,7 +249,7 @@ def train_model(
         val_dataset=valid_set,
         training_callback=training_callback,
         iterate_batches=(
-            iterate_delineated_batches if args.mask_inputs else iterate_batches
+            iterate_completion_batches if args.mask_inputs else iterate_batches
         ),
         loss=input_masked_loss if args.mask_inputs else default_loss,
     )
diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 392030cb..99cab169 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -130,7 +130,7 @@ def input_length(
     return output_begin
 
 
-def iterate_delineated_batches(
+def iterate_completion_batches(
     dataset: CompletionsDataset,
     tokenizer: PreTrainedTokenizer,
     batch_size: int,
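
For readers of the patch, the masking idea is straightforward to see in code. The
sketch below is a minimal, hypothetical version of an input-masked loss in MLX; it
is not the exact `input_masked_loss` referenced above, and the `prompt_lengths`
argument is an assumed name for the per-example prompt sizes that a batch iterator
such as `iterate_completion_batches` would need to supply.

```python
# Minimal sketch of an input-masked loss in MLX. Illustrative only: not the
# implementation added by this patch, and `prompt_lengths` is a hypothetical
# argument carrying each example's prompt length in tokens.
import mlx.core as mx
import mlx.nn as nn


def sketch_input_masked_loss(model, inputs, prompt_lengths, lengths):
    # Next-token prediction: logits at position j score the token at j + 1.
    logits = model(inputs[:, :-1]).astype(mx.float32)
    targets = inputs[:, 1:]

    # targets[j] is inputs[j + 1], so the first completion token sits at
    # target position prompt_length - 1 and the last at length - 2. Zero out
    # everything else (prompt tokens and padding) so only the completion
    # contributes to the loss.
    positions = mx.arange(targets.shape[1])[None, :]
    mask = (positions >= prompt_lengths[:, None] - 1) & (
        positions < lengths[:, None] - 1
    )

    ce = nn.losses.cross_entropy(logits, targets) * mask
    ntoks = mask.sum()
    return ce.sum() / ntoks, ntoks
```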
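
And a usage example in the style of LORA.md's existing fine-tuning invocations,
with placeholder paths; only the `--mask-inputs` flag is new in this patch:

```shell
mlx_lm.lora \
    --model <path_to_model> \
    --train \
    --data <path_to_data> \
    --mask-inputs
```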