more nits

commit bb2c8bcf96
parent 6e9542a934
@@ -76,26 +76,13 @@ You can specify the output location with `--adapter-path`.
You can resume fine-tuning with an existing adapter with
`--resume-adapter-file <path_to_adapters.safetensors>`.

### Input Masking

There are custom functions for masking the sequence of tokens associated with the `prompt` in a completion dataset
during the loss calculation, so the model is not penalized for failing to recreate the prompt. To fine-tune
with masked input sequences, use the `--mask-inputs` argument.

This functionality expects a `response_template` parameter in the configuration that is either a
[string indicating the start of the model's response](https://huggingface.co/docs/transformers/en/chat_templating#what-are-generation-prompts)
or its corresponding tokens. This is used to create the mask that excludes the tokens associated with the
prompt from the loss calculation. For example (ChatML):

```yaml
response_template: "<|im_start|>assistant"
```

or (for the corresponding tokens of Gemma's response template)

```yaml
response_template: [106, 2516]
```
```
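To make the mechanics concrete, the sketch below shows one way a `response_template` could be turned into a per-token mask: tokenize the template (if it is given as a string), locate it in the example's token sequence, and exclude everything up to and including it from the loss. The names `find_subsequence` and `input_mask` are illustrative, not the repository's API; the actual implementation (see the `contains` helper later in this diff) may differ in details such as whether the template tokens themselves are masked.

```python
# Illustrative sketch only, not the repository's implementation.

def find_subsequence(needle, haystack):
    # Return (start, end_inclusive) of the first occurrence of needle in haystack, or None.
    n = len(needle)
    for i in range(len(haystack) - n + 1):
        if haystack[i : i + n] == needle:
            return i, i + n - 1
    return None


def input_mask(tokens, response_template, tokenizer):
    # `response_template` may be a string (tokenized here) or a list of token ids.
    template = (
        tokenizer.encode(response_template, add_special_tokens=False)
        if isinstance(response_template, str)
        else list(response_template)
    )
    span = find_subsequence(template, tokens)
    if span is None:
        # Template not found: keep every token in the loss as a conservative fallback.
        return [1] * len(tokens)
    _, end = span
    # 0 = excluded from the loss (prompt and template), 1 = included (response).
    return [0] * (end + 1) + [1] * (len(tokens) - end - 1)
```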
#### Prompt Masking

The default training computes a loss for every token in the sample. You can
ignore the prompt and compute loss for just the completion by passing
`--mask-prompt`. Note this is only supported for `chat` and `completion`
datasets. For `chat` datasets the final message in the message list is
considered the completion. See the [dataset section](#Data) for more details.
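For intuition, here is a minimal sketch of what a prompt-masked loss looks like in MLX. It mirrors the shape of `default_loss(model, batch, lengths)` shown further down in this diff, but the extra `prompt_lengths` argument and the exact batching and shifting conventions are assumptions for illustration, not the repository's trainer code.

```python
import mlx.core as mx
import mlx.nn as nn


def masked_loss(model, batch, lengths, prompt_lengths):
    # `batch` holds padded token ids; `model` is assumed to return logits.
    inputs, targets = batch[:, :-1], batch[:, 1:]
    logits = model(inputs)

    # Per-token cross entropy, shape (batch, seq_len - 1).
    ce = nn.losses.cross_entropy(logits, targets, reduction="none")

    # Keep positions that are inside the sequence and past the prompt
    # (off-by-one details depend on the shift applied above).
    positions = mx.arange(targets.shape[1])[None, :]
    mask = mx.logical_and(
        positions < (lengths[:, None] - 1),
        positions >= (prompt_lengths[:, None] - 1),
    )
    mask = mask.astype(ce.dtype)

    ntoks = mask.sum()
    return (ce * mask).sum() / ntoks, ntoks
```

Masking rather than dropping the prompt keeps the full context visible to the model while restricting the gradient signal to the completion tokens.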
### Evaluate
@@ -185,6 +185,8 @@ def load_hf_dataset(
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
    import datasets

    mask_prompt = getattr(args, "mask_prompt", False)

    def create_hf_dataset(
        dataset_name,
        text_feature,
@@ -201,7 +203,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
        )
        if prompt_feature and completion_feature:
            return CompletionsDataset(
-                data, tokenizer, prompt_feature, completion_feature, mask_prompt
+                ds, tokenizer, prompt_feature, completion_feature, mask_prompt
            )
        elif chat_feature:
            return ChatDataset(
@@ -83,16 +83,6 @@ def default_loss(model, batch, lengths):
    return ce, ntoks


def contains(small_list: List, big_list: List) -> Tuple[int, int]:
    """
    Returns the beginning and end index of the first occurrence of small_list in big_list.
    """
    small_list_length = len(small_list)
    for ind in (i for i, e in enumerate(big_list) if e == small_list[0]):
        if big_list[ind : ind + small_list_length] == small_list:
            return ind, ind + small_list_length - 1
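# For example (with illustrative token values), contains([106, 2516], [2, 106, 2516, 42])
# returns (1, 2): the inclusive start and end indices of the response-template tokens
# within the tokenized sequence.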


def iterate_batches(
    dataset,
    tokenizer,