Mirror of https://github.com/ml-explore/mlx-examples.git
Synced 2025-08-28 00:30:09 +08:00
Add input masking for fine-tuning in documentation
Renamed the batch iteration function (iterate_delineated_batches -> iterate_completion_batches).
This commit is contained in:
parent
71d9f8cc38
commit
3496cbea46
@@ -76,6 +76,11 @@ You can specify the output location with `--adapter-path`.
 You can resume fine-tuning with an existing adapter with
 `--resume-adapter-file <path_to_adapters.safetensors>`.
 
+### Input Masking
+
+There are custom functions for masking the input sequence of tokens during the loss calculation to ensure
+the model is not being penalized for not recreating the prompt. To fine-tune with masked input sequences,
+use the `--mask-inputs` argument.
 
 ### Evaluate
 
 To compute test set perplexity use:
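The masking behavior this documentation describes can be sketched in a few lines. This is a hedged illustration only, not the mlx-lm implementation: `masked_nll` and its arguments are hypothetical names, showing just the idea that prompt tokens are excluded from the loss average so the model is scored only on the completion.

```python
import math

def masked_nll(token_log_probs, prompt_length):
    """Average negative log-likelihood over completion tokens only.

    Hypothetical sketch: tokens before `prompt_length` belong to the
    prompt and are masked out of (i.e. excluded from) the average.
    """
    completion = token_log_probs[prompt_length:]
    if not completion:
        return 0.0
    return -sum(completion) / len(completion)

# Example: 2 prompt tokens (masked) followed by 2 completion tokens.
log_probs = [math.log(0.5), math.log(0.25), math.log(0.5), math.log(0.125)]
print(masked_nll(log_probs, prompt_length=2))
```

Only the last two log-probabilities contribute, so the result equals `-(log 0.5 + log 0.125) / 2`.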
@@ -185,7 +185,7 @@ def train_model(
     default_loss,
     input_masked_loss,
     iterate_batches,
-    iterate_delineated_batches,
+    iterate_completion_batches,
 )
 
 model.freeze()
@@ -249,7 +249,7 @@ def train_model(
         val_dataset=valid_set,
         training_callback=training_callback,
         iterate_batches=(
-            iterate_delineated_batches if args.mask_inputs else iterate_batches
+            iterate_completion_batches if args.mask_inputs else iterate_batches
         ),
         loss=input_masked_loss if args.mask_inputs else default_loss,
     )
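The one-line change in this hunk swaps which iterator the `--mask-inputs` flag selects. As a minimal sketch of the pattern (Python functions are first-class, so a conditional expression can pick the strategy at call time), with stand-in bodies rather than the real mlx-lm iterators:

```python
# Sketch of flag-based strategy selection; the function bodies are
# hypothetical stand-ins, not the mlx-lm batch iterators.
def iterate_batches(dataset):
    # Plain variant: yields (inputs, targets) pairs.
    for item in dataset:
        yield item, item

def iterate_completion_batches(dataset):
    # Mask-aware variant: additionally yields a prompt length
    # so the loss can skip the prompt tokens.
    for item in dataset:
        yield item, item, len(item) // 2

class Args:
    mask_inputs = True

args = Args()
# The same conditional expression as in the diff above.
chosen = iterate_completion_batches if args.mask_inputs else iterate_batches
print(chosen.__name__)  # prints "iterate_completion_batches"
```

With `mask_inputs = False` the expression falls back to the plain `iterate_batches`.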
|
@@ -130,7 +130,7 @@ def input_length(
     return output_begin
 
 
-def iterate_delineated_batches(
+def iterate_completion_batches(
     dataset: CompletionsDataset,
     tokenizer: PreTrainedTokenizer,
     batch_size: int,