diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index 2d9a2553..8aec89ec 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -160,8 +160,8 @@ For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a `valid.jsonl` to be in the data directory. For evaluation (`--test`), the data loader expects a `test.jsonl` in the data directory. -Currently, `*.jsonl` files support three data formats: `chat`, -`completions`, and `text`. Here are three examples of these formats: +Currently, `*.jsonl` files support `chat`, `tools`, `completions`, and `text` +data formats. Here are examples of these formats: `chat`: @@ -169,6 +169,58 @@ Currently, `*.jsonl` files support three data formats: `chat`, {"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello."}, {"role": "assistant", "content": "How can I assistant you today."}]} ``` +`tools`: + +```jsonl +{"messages":[{"role":"user","content":"What is the weather in San Francisco?"},{"role":"assistant","tool_calls":[{"id":"call_id","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"}}]}],"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and country, eg. San Francisco, USA"},"format":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location","format"]}}}]} +``` + +
+<details>
+<summary>View the expanded single data tool format</summary>
+
+```jsonl
+{
+  "messages": [
+    { "role": "user", "content": "What is the weather in San Francisco?" },
+    {
+      "role": "assistant",
+      "tool_calls": [
+        {
+          "id": "call_id",
+          "type": "function",
+          "function": {
+            "name": "get_current_weather",
+            "arguments": "{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"
+          }
+        }
+      ]
+    }
+  ],
+  "tools": [
+    {
+      "type": "function",
+      "function": {
+        "name": "get_current_weather",
+        "description": "Get the current weather",
+        "parameters": {
+          "type": "object",
+          "properties": {
+            "location": {
+              "type": "string",
+              "description": "The city and country, eg. San Francisco, USA"
+            },
+            "format": { "type": "string", "enum": ["celsius", "fahrenheit"] }
+          },
+          "required": ["location", "format"]
+        }
+      }
+    }
+  ]
+}
+```
+
+</details>
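+
+Before training, you can preview how a `tools` example will be rendered. The
+snippet below is a minimal sketch (the model name is only an example; any
+tokenizer whose chat template supports tools will work) that applies the
+template to the first line of a `train.jsonl`:
+
+```python
+import json
+
+from transformers import AutoTokenizer
+
+# Example model; substitute any model whose chat template accepts tools.
+tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
+
+with open("data/train.jsonl") as fid:
+    example = json.loads(fid.readline())
+
+# Mirrors how the training data loader renders a `tools` example.
+text = tokenizer.apply_chat_template(
+    example["messages"],
+    tools=example.get("tools", None),
+    tokenize=False,
+    add_generation_prompt=True,
+)
+print(text)
+```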
+ `completions`: ```jsonl @@ -215,11 +267,13 @@ hf_dataset: - Arguments specified in `config` will be passed as keyword arguments to [`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset). -In general, for the `chat` and `completions` formats, Hugging Face [chat -templates](https://huggingface.co/blog/chat-templates) are used. This applies -the model's chat template by default. If the model does not have a chat -template, then Hugging Face will use a default. For example, the final text in -the `chat` example above with Hugging Face's default template becomes: +In general, for the `chat`, `tools` and `completions` formats, Hugging Face +[chat +templates](https://huggingface.co/docs/transformers/main/en/chat_templating) +are used. This applies the model's chat template by default. If the model does +not have a chat template, then Hugging Face will use a default. For example, +the final text in the `chat` example above with Hugging Face's default template +becomes: ```text <|im_start|>system diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 3d99894c..2b8abf43 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -36,7 +36,10 @@ class ChatDataset(Dataset): def __getitem__(self, idx: int): messages = self._data[idx]["messages"] text = self._tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + messages, + tools=self._data[idx].get("tools", None), + tokenize=False, + add_generation_prompt=True, ) return text diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 24fcc5c6..b15801a5 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -93,9 +93,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) # Encode batch batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]] for b in batch: - if b[-1] == tokenizer.eos_token_id: - print("[WARNING] Example already has an EOS token appended") - else: + if b[-1] != tokenizer.eos_token_id: b.append(tokenizer.eos_token_id) lengths = [len(x) for x in batch]
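
The `ChatDataset` change above is what threads the optional `tools` field
through to the chat template, and the `iterate_batches` change appends the EOS
token only when it is missing. A quick end-to-end check, sketched under the
assumption that `ChatDataset` is constructed from the raw records and a
tokenizer as in this version of `datasets.py` (model and data path are
placeholders):

```python
import json

from transformers import AutoTokenizer

from mlx_lm.tuner.datasets import ChatDataset

# Placeholder model and data path; adjust to your setup.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
with open("data/train.jsonl") as fid:
    data = [json.loads(line) for line in fid]

dataset = ChatDataset(data, tokenizer)
print(dataset[0])  # rendered text; tools (when present) are injected

# Same EOS handling as iterate_batches: append only when missing.
tokens = tokenizer.encode(dataset[0])
if tokens[-1] != tokenizer.eos_token_id:
    tokens.append(tokenizer.eos_token_id)
```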