mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 15:08:37 +08:00
Add input masking for fine-tuning in documentation
Rename the batch iteration function (iterate_delineated_batches -> iterate_completion_batches).
This commit is contained in:
parent
71d9f8cc38
commit
3496cbea46
@ -76,6 +76,11 @@ You can specify the output location with `--adapter-path`.
|
|||||||
You can resume fine-tuning with an existing adapter with
|
You can resume fine-tuning with an existing adapter with
|
||||||
`--resume-adapter-file <path_to_adapters.safetensors>`.
|
`--resume-adapter-file <path_to_adapters.safetensors>`.
|
||||||
|
|
||||||
|
### Input Masking
|
||||||
|
Custom functions mask the input (prompt) tokens during the loss calculation, so that
|
||||||
|
the model is not penalized for failing to reproduce the prompt. To fine-tune with masked input sequences,
|
||||||
|
pass the `--mask-inputs` argument.
|
||||||
|
|
||||||
### Evaluate
|
### Evaluate
|
||||||
|
|
||||||
To compute test set perplexity use:
|
To compute test set perplexity use:
|
||||||
|
@ -185,7 +185,7 @@ def train_model(
|
|||||||
default_loss,
|
default_loss,
|
||||||
input_masked_loss,
|
input_masked_loss,
|
||||||
iterate_batches,
|
iterate_batches,
|
||||||
iterate_delineated_batches,
|
iterate_completion_batches,
|
||||||
)
|
)
|
||||||
|
|
||||||
model.freeze()
|
model.freeze()
|
||||||
@ -249,7 +249,7 @@ def train_model(
|
|||||||
val_dataset=valid_set,
|
val_dataset=valid_set,
|
||||||
training_callback=training_callback,
|
training_callback=training_callback,
|
||||||
iterate_batches=(
|
iterate_batches=(
|
||||||
iterate_delineated_batches if args.mask_inputs else iterate_batches
|
iterate_completion_batches if args.mask_inputs else iterate_batches
|
||||||
),
|
),
|
||||||
loss=input_masked_loss if args.mask_inputs else default_loss,
|
loss=input_masked_loss if args.mask_inputs else default_loss,
|
||||||
)
|
)
|
||||||
|
@ -130,7 +130,7 @@ def input_length(
|
|||||||
return output_begin
|
return output_begin
|
||||||
|
|
||||||
|
|
||||||
def iterate_delineated_batches(
|
def iterate_completion_batches(
|
||||||
dataset: CompletionsDataset,
|
dataset: CompletionsDataset,
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
Loading…
Reference in New Issue
Block a user