From bb2c8bcf968f5d17ed763eca3d18b3370316416c Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Sun, 9 Feb 2025 17:58:15 -0800
Subject: [PATCH] more nits

---
 llms/mlx_lm/LORA.md           | 25 ++++++-------------------
 llms/mlx_lm/tuner/datasets.py |  4 +++-
 llms/mlx_lm/tuner/trainer.py  | 10 ----------
 3 files changed, 9 insertions(+), 30 deletions(-)

diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index d51bce82..e863abc4 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -76,26 +76,13 @@ You can specify the output location with `--adapter-path`.
 You can resume fine-tuning with an existing adapter with
 `--resume-adapter-file <path_to_adapters.safetensors>`.
 
-### Input Masking
+#### Prompt Masking
 
-There are custom functions for masking the sequence of tokens associated with the `prompt` in a completion dataset
-during the loss calculation to ensure the model is not being penalized for not recreating the prompt. To fine-tune
-with masked input sequences, use the `--mask-inputs` argument.
-
-This functionality expects a ```response_template``` parameter in the configuration that is either a string representing
-a [string that indicate the start of the model's response](https://huggingface.co/docs/transformers/en/chat_templating#what-are-generation-prompts)
-or its corresopnding tokens. This is used to create the mask that excludes the tokens associated from the rest of
-the sequence from loss calculations. For example (ChatML):
-
-```yaml
-response_template: "<|im_start|>assistant"
-```
-
-or (for the corresponding tokens of Gemma's response template)
-
-```yaml
-response_template: [106, 2516]
-```
+The default training computes a loss for every token in the sample. You can
+ignore the prompt and compute loss for just the completion by passing
+`--mask-prompt`. Note this is only supported for `chat` and `completion`
+datasets. For `chat` datasets the final message in the message list is
+considered the completion. See the [dataset section](#Data) for more details.
 
 ### Evaluate
 
diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 174d05ca..44c78450 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -185,6 +185,8 @@ def load_hf_dataset(
 def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
     import datasets
 
+    mask_prompt = getattr(args, "mask_prompt", False)
+
     def create_hf_dataset(
         dataset_name,
         text_feature,
@@ -201,7 +203,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
         )
         if prompt_feature and completion_feature:
             return CompletionsDataset(
-                data, tokenizer, prompt_feature, completion_feature, mask_prompt
+                ds, tokenizer, prompt_feature, completion_feature, mask_prompt
             )
         elif chat_feature:
             return ChatDataset(
diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index f2f0cb5d..d675f9b6 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -83,16 +83,6 @@ def default_loss(model, batch, lengths):
     return ce, ntoks
 
 
-def contains(small_list: List, big_list: List) -> Tuple[int, int]:
-    """
-    Returns the beginning and end index of the first occurrence of small_list in big_list.
-    """
-    small_list_length = len(small_list)
-    for ind in (i for i, e in enumerate(big_list) if e == small_list[0]):
-        if big_list[ind : ind + small_list_length] == small_list:
-            return ind, ind + small_list_length - 1
-
-
 def iterate_batches(
     dataset,
     tokenizer,
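
Note on the `--mask-prompt` behavior documented in the LORA.md hunk above: conceptually, prompt tokens are zeroed out of the per-token loss average so that only completion tokens contribute. The sketch below is illustrative only and is not the mlx_lm trainer code; the function name, the NumPy dependency, and the assumption that per-token losses and prompt lengths are already available are all assumptions made for the example.

```python
import numpy as np

def prompt_masked_loss(token_losses: np.ndarray, prompt_lengths: np.ndarray) -> float:
    """Average the per-token loss over completion tokens only (illustrative).

    token_losses: (batch, seq_len) cross-entropy for each target token.
    prompt_lengths: (batch,) number of prompt tokens in each example.
    """
    _, seq_len = token_losses.shape
    positions = np.arange(seq_len)[None, :]  # (1, seq_len)
    # 1.0 where the token belongs to the completion, 0.0 where it is prompt.
    mask = (positions >= prompt_lengths[:, None]).astype(token_losses.dtype)
    ntoks = mask.sum()
    return float((token_losses * mask).sum() / ntoks)

# One example with a 3-token prompt and a 2-token completion:
losses = np.array([[0.9, 1.1, 0.8, 0.4, 0.6]])
print(prompt_masked_loss(losses, np.array([3])))  # averages only 0.4 and 0.6 -> 0.5
```

Without prompt masking, the mask would simply cover every target token, which is the default behavior the new LORA.md text describes.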
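
Relatedly, for `completion`-style records the trainer needs to know where the prompt ends in token space. A minimal sketch of one way to compute that boundary is below; it is not taken from `datasets.py`, the function name is made up, and it assumes the tokenization of the prompt is a prefix of the tokenization of the concatenated text, which is not guaranteed for every tokenizer.

```python
from typing import List, Tuple

def tokenize_with_prompt_boundary(tokenizer, prompt: str, completion: str) -> Tuple[List[int], int]:
    """Return (token_ids, num_prompt_tokens) for a prompt/completion pair."""
    prompt_ids = tokenizer.encode(prompt)
    full_ids = tokenizer.encode(prompt + completion)
    # Tokens before the boundary are the prompt; with prompt masking enabled
    # they would be excluded from the loss (e.g. via a mask like the one above).
    return full_ids, len(prompt_ids)
```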