mirror of https://github.com/ml-explore/mlx-examples.git

commit bb2c8bcf96 ("more nits"), parent 6e9542a934
`llms/mlx_lm/LORA.md`:

````diff
@@ -76,26 +76,13 @@ You can specify the output location with `--adapter-path`.
 
 You can resume fine-tuning with an existing adapter with
 `--resume-adapter-file <path_to_adapters.safetensors>`.
 
-### Input Masking
+#### Prompt Masking
 
-There are custom functions for masking the sequence of tokens associated with the `prompt` in a completion dataset
-during the loss calculation to ensure the model is not being penalized for not recreating the prompt. To fine-tune
-with masked input sequences, use the `--mask-inputs` argument.
-
-This functionality expects a ```response_template``` parameter in the configuration that is either a string representing
-a [string that indicate the start of the model's response](https://huggingface.co/docs/transformers/en/chat_templating#what-are-generation-prompts)
-or its corresopnding tokens. This is used to create the mask that excludes the tokens associated from the rest of
-the sequence from loss calculations. For example (ChatML):
-
-```yaml
-response_template: "<|im_start|>assistant"
-```
-
-or (for the corresponding tokens of Gemma's response template)
-
-```yaml
-response_template: [106, 2516]
-```
+The default training computes a loss for every token in the sample. You can
+ignore the prompt and compute loss for just the completion by passing
+`--mask-prompt`. Note this is only supported for `chat` and `completion`
+datasets. For `chat` datasets the final message in the message list is
+considered the completion. See the [dataset section](#Data) for more details.
 
 ### Evaluate
````
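The gist of the new `--mask-prompt` behavior described above: the loss is averaged over completion tokens only, so the model is not penalized for failing to reproduce the prompt it was fed. A minimal standalone sketch of that masking logic in NumPy, not the repo's MLX implementation; `masked_ce` and `prompt_len` are illustrative names:

```python
import numpy as np


def masked_ce(logits, targets, prompt_len):
    """Cross-entropy averaged over completion tokens only; prompt tokens are masked out."""
    # log-softmax over the vocabulary axis
    logps = logits - logits.max(axis=-1, keepdims=True)
    logps = logps - np.log(np.exp(logps).sum(axis=-1, keepdims=True))
    # per-token negative log-likelihood of the target token
    token_losses = -np.take_along_axis(logps, targets[:, None], axis=-1)[:, 0]
    # True only for positions past the prompt
    mask = np.arange(len(targets)) >= prompt_len
    return (token_losses * mask).sum() / mask.sum()


# toy example: 6 tokens, the first 4 are prompt and contribute nothing to the loss
rng = np.random.default_rng(0)
print(masked_ce(rng.normal(size=(6, 32)), rng.integers(0, 32, size=6), prompt_len=4))
```

From the command line this would look something like `mlx_lm.lora --train --data <path> --mask-prompt`, assuming the `mlx_lm.lora` entry point this README uses elsewhere.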
`llms/mlx_lm/tuner/datasets.py`:

```diff
@@ -185,6 +185,8 @@ def load_hf_dataset(
 def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
     import datasets
 
+    mask_prompt = getattr(args, "mask_prompt", False)
+
     def create_hf_dataset(
         dataset_name,
         text_feature,
@@ -201,7 +203,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
         )
         if prompt_feature and completion_feature:
             return CompletionsDataset(
-                data, tokenizer, prompt_feature, completion_feature, mask_prompt
+                ds, tokenizer, prompt_feature, completion_feature, mask_prompt
             )
         elif chat_feature:
             return ChatDataset(
```
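Why the dataset layer takes `mask_prompt` at all: the trainer can only mask the prompt if each example records where the prompt ends. A toy sketch of that contract follows; `ToyCompletionsDataset` is hypothetical and not the actual `CompletionsDataset` API:

```python
class ToyCompletionsDataset:
    """Toy stand-in: pairs each tokenized example with a prompt offset."""

    def __init__(self, data, tokenizer, prompt_key, completion_key, mask_prompt):
        self.data = data
        self.tokenizer = tokenizer
        self.prompt_key = prompt_key
        self.completion_key = completion_key
        self.mask_prompt = mask_prompt

    def __getitem__(self, idx):
        item = self.data[idx]
        prompt = self.tokenizer.encode(item[self.prompt_key])
        completion = self.tokenizer.encode(item[self.completion_key])
        # offset 0 trains on every token; otherwise the leading prompt
        # tokens are excluded from the loss
        offset = len(prompt) if self.mask_prompt else 0
        return prompt + completion, offset
```

This also explains the `getattr(args, "mask_prompt", False)` default above: a config written before the flag existed falls back to `False`, i.e. the old train-on-every-token behavior.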
`llms/mlx_lm/tuner/trainer.py`:

```diff
@@ -83,16 +83,6 @@ def default_loss(model, batch, lengths):
     return ce, ntoks
 
 
-def contains(small_list: List, big_list: List) -> Tuple[int, int]:
-    """
-    Returns the beginning and end index of the first occurrence of small_list in big_list.
-    """
-    small_list_length = len(small_list)
-    for ind in (i for i, e in enumerate(big_list) if e == small_list[0]):
-        if big_list[ind : ind + small_list_length] == small_list:
-            return ind, ind + small_list_length - 1
-
-
 def iterate_batches(
     dataset,
     tokenizer,
```
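The deleted `contains` helper located a response template's token IDs inside a tokenized sequence; it goes away together with the `response_template` docs removed above. For reference, a standalone runnable version with the `typing` imports the original got from elsewhere in the module, plus a usage example (the `[106, 2516]` IDs are Gemma's response-template tokens from the removed docs):

```python
from typing import List, Optional, Tuple


def contains(small_list: List, big_list: List) -> Optional[Tuple[int, int]]:
    """Return the begin/end index of the first occurrence of small_list in big_list."""
    small_list_length = len(small_list)
    for ind in (i for i, e in enumerate(big_list) if e == small_list[0]):
        if big_list[ind : ind + small_list_length] == small_list:
            return ind, ind + small_list_length - 1
    return None  # made explicit here; the original fell through and returned None implicitly


# e.g. find Gemma's response-template tokens in a tokenized chat
print(contains([106, 2516], [2, 106, 2516, 3456, 107]))  # -> (1, 2)
```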