more nits

This commit is contained in:
Awni Hannun 2025-02-09 17:58:15 -08:00
parent 6e9542a934
commit bb2c8bcf96
3 changed files with 9 additions and 30 deletions

View File

@ -76,26 +76,13 @@ You can specify the output location with `--adapter-path`.
You can resume fine-tuning with an existing adapter with
`--resume-adapter-file <path_to_adapters.safetensors>`.
### Input Masking
There are custom functions for masking the sequence of tokens associated with the `prompt` in a completion dataset
during the loss calculation to ensure the model is not being penalized for not recreating the prompt. To fine-tune
with masked input sequences, use the `--mask-inputs` argument.
This functionality expects a ```response_template``` parameter in the configuration that is either a string representing
a [string that indicate the start of the model's response](https://huggingface.co/docs/transformers/en/chat_templating#what-are-generation-prompts)
or its corresopnding tokens. This is used to create the mask that excludes the tokens associated from the rest of
the sequence from loss calculations. For example (ChatML):
```yaml
response_template: "<|im_start|>assistant"
```
or (for the corresponding tokens of Gemma's response template)
```yaml
response_template: [106, 2516]
```
#### Prompt Masking
The default training computes a loss for every token in the sample. You can
ignore the prompt and compute loss for just the completion by passing
`--mask-prompt`. Note this is only supported for `chat` and `completion`
datasets. For `chat` datasets the final message in the message list is
considered the completion. See the [dataset section](#Data) for more details.
### Evaluate

View File

@ -185,6 +185,8 @@ def load_hf_dataset(
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
import datasets
mask_prompt = getattr(args, "mask_prompt", False)
def create_hf_dataset(
dataset_name,
text_feature,
@ -201,7 +203,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
)
if prompt_feature and completion_feature:
return CompletionsDataset(
data, tokenizer, prompt_feature, completion_feature, mask_prompt
ds, tokenizer, prompt_feature, completion_feature, mask_prompt
)
elif chat_feature:
return ChatDataset(

View File

@ -83,16 +83,6 @@ def default_loss(model, batch, lengths):
return ce, ntoks
def contains(small_list: List, big_list: List) -> Tuple[int, int]:
"""
Returns the beginning and end index of the first occurrence of small_list in big_list.
"""
small_list_length = len(small_list)
for ind in (i for i, e in enumerate(big_list) if e == small_list[0]):
if big_list[ind : ind + small_list_length] == small_list:
return ind, ind + small_list_length - 1
def iterate_batches(
dataset,
tokenizer,